From f837c59d60039403c86454b2119ee9557e34ee22 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Thu, 19 Aug 2021 23:46:04 +0200 Subject: [PATCH] Move certificate handling into a separate module. --- go.mod | 1 + go.sum | 4 + webserver/certificate.go | 145 ---------------------------------- webserver/certificate_test.go | 51 ------------ webserver/webserver.go | 7 +- 5 files changed, 11 insertions(+), 197 deletions(-) delete mode 100644 webserver/certificate.go delete mode 100644 webserver/certificate_test.go diff --git a/go.mod b/go.mod index 19f4ac3..28eda31 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.13 require ( github.com/at-wat/ebml-go v0.16.0 github.com/gorilla/websocket v1.4.2 + github.com/jech/cert v0.0.0-20210819231831-aca735647728 github.com/jech/samplebuilder v0.0.0-20210711185346-d34c6dd315fb github.com/pion/ice/v2 v2.1.10 github.com/pion/rtcp v1.2.6 diff --git a/go.sum b/go.sum index e0569b4..ab54b13 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,10 @@ github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jech/cert v0.0.0-20210819214059-b84cf5e78f7a h1:3VzfD0n6o2Y3FW6cluxqj535fT9Uq3ANaYkG1JuiADI= +github.com/jech/cert v0.0.0-20210819214059-b84cf5e78f7a/go.mod h1:FXUA/zpiQfV4uBVN2kAwkf3X7pU7l1l2ovS45CsSYZs= +github.com/jech/cert v0.0.0-20210819231831-aca735647728 h1:tN+W1ll2oKuJGMCaO1CRK4rr+xSRjVSfWmnKlACdx38= +github.com/jech/cert v0.0.0-20210819231831-aca735647728/go.mod h1:FXUA/zpiQfV4uBVN2kAwkf3X7pU7l1l2ovS45CsSYZs= github.com/jech/samplebuilder v0.0.0-20210711185346-d34c6dd315fb h1:ctDbFqRHUmxTThbYcTP2cdOVOkECsvhNyQKqeEK7RQQ= github.com/jech/samplebuilder v0.0.0-20210711185346-d34c6dd315fb/go.mod h1:PXhvo7PKy8CVqirCgoNG2BIjwow2Zd6LwCTScabl584= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= diff --git a/webserver/certificate.go b/webserver/certificate.go deleted file mode 100644 index 12d972a..0000000 --- a/webserver/certificate.go +++ /dev/null @@ -1,145 +0,0 @@ -package webserver - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "errors" - "log" - "math/big" - "os" - "path/filepath" - "sync" - "sync/atomic" - "time" -) - -type certInfo struct { - certificate *tls.Certificate - keyTime time.Time - certTime time.Time -} - -// certMu protects writing to certificate -var certMu sync.Mutex - -// certificate holds our current certificate, of type certInfo -var certificate atomic.Value - -// generateCertificate generates a self-signed certficate -func generateCertificate() (tls.Certificate, error) { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return tls.Certificate{}, err - } - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - now := time.Now() - - template := x509.Certificate{ - SerialNumber: serialNumber, - NotBefore: now, - NotAfter: now.Add(365 * 24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - bytes, err := x509.CreateCertificate( - rand.Reader, &template, &template, &priv.PublicKey, priv, - ) - if err != nil { - return tls.Certificate{}, err - } - - return tls.Certificate{ - Certificate: [][]byte{bytes}, - PrivateKey: priv, - }, nil -} - -func modTime(filename string) time.Time { - fi, err := os.Stat(filename) - if err != nil { - if !os.IsNotExist(err) { - log.Printf("%v: %v", filename, err) - } - return time.Time{} - } - return fi.ModTime() -} - -// loadCertificate returns the current certificate if it is still valid. -func loadCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) *certInfo { - info, ok := certificate.Load().(*certInfo) - if !ok { - return nil - } - - if !info.certTime.Equal(certTime) || !info.keyTime.Equal(keyTime) { - return nil - } - - return info -} - -// storeCertificate returns the current certificate if it is still valid, -// and either reads or generates a new one otherwise. -func storeCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) (info *certInfo, err error) { - certMu.Lock() - defer certMu.Unlock() - - // the certificate may have been updated since we checked - info = loadCertificate(certFile, certTime, keyFile, keyTime) - if info != nil { - return - } - - var cert tls.Certificate - nocert := certTime.Equal(time.Time{}) - nokey := keyTime.Equal(time.Time{}) - - if nocert != nokey { - err = errors.New("only one of cert.pem and key.pem exists") - return - } else if nokey { - log.Printf("Generating self-signed certificate") - cert, err = generateCertificate() - if err != nil { - return - } - } else { - cert, err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return - } - } - info = &certInfo{ - certificate: &cert, - certTime: certTime, - keyTime: keyTime, - } - certificate.Store(info) - return -} - -func getCertificate(dataDir string) (*tls.Certificate, error) { - certFile := filepath.Join(dataDir, "cert.pem") - keyFile := filepath.Join(dataDir, "key.pem") - certTime := modTime(certFile) - keyTime := modTime(keyFile) - - info := loadCertificate(certFile, certTime, keyFile, keyTime) - - if info == nil { - var err error - info, err = storeCertificate( - certFile, certTime, keyFile, keyTime, - ) - if info == nil || err != nil { - return nil, err - } - } - return info.certificate, nil -} diff --git a/webserver/certificate_test.go b/webserver/certificate_test.go deleted file mode 100644 index 0441caa..0000000 --- a/webserver/certificate_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package webserver - -import ( - "testing" -) - -func TestGenerateCertificate(t *testing.T) { - _, err := generateCertificate() - if err != nil { - t.Errorf("generateCertificate: %v", err) - } -} - -func BenchmarkGenerateCertificate(b *testing.B) { - for i := 0; i < b.N; i++ { - _, err := generateCertificate() - if err != nil { - b.Errorf("generateCertificate: %v", err) - } - } -} - -func TestGetCertificate(t *testing.T) { - cert1, err := getCertificate("/tmp/no/such/file") - if err != nil { - t.Errorf("getCertificate: %v", err) - } - - cert2, err := getCertificate("/tmp/no/such/file") - if err != nil { - t.Errorf("getCertificate: %v", err) - } - - if cert1 != cert2 { - t.Errorf("cert1 != cert2") - } -} - -func BenchmarkGetCertificate(b *testing.B) { - _, err := getCertificate("/tmp/no/such/file") - if err != nil { - b.Errorf("getCertificate: %v", err) - } - b.StartTimer() - for i := 0; i < b.N; i++ { - _, err := getCertificate("/tmp/no/such/file") - if err != nil { - b.Errorf("getCertificate: %v", err) - } - } -} diff --git a/webserver/webserver.go b/webserver/webserver.go index b579893..132d757 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -22,6 +22,7 @@ import ( "github.com/gorilla/websocket" + "github.com/jech/cert" "github.com/jech/galene/diskwriter" "github.com/jech/galene/group" "github.com/jech/galene/rtpconn" @@ -58,9 +59,13 @@ func Serve(address string, dataDir string) error { IdleTimeout: 120 * time.Second, } if !Insecure { + certificate := cert.New( + filepath.Join(dataDir, "cert.pem"), + filepath.Join(dataDir, "key.pem"), + ) s.TLSConfig = &tls.Config{ GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return getCertificate(dataDir) + return certificate.Get() }, } }