diff --git a/webserver/certificate.go b/webserver/certificate.go index a35c28e..12d972a 100644 --- a/webserver/certificate.go +++ b/webserver/certificate.go @@ -10,6 +10,7 @@ import ( "math/big" "os" "path/filepath" + "sync" "sync/atomic" "time" ) @@ -20,9 +21,14 @@ type certInfo struct { certTime time.Time } +// certMu protects writing to certificate +var certMu sync.Mutex + +// certificate holds our current certificate, of type certInfo var certificate atomic.Value -func generateCertificate(dataDir string) (tls.Certificate, error) { +// generateCertificate generates a self-signed certficate +func generateCertificate() (tls.Certificate, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return tls.Certificate{}, err @@ -53,7 +59,7 @@ func generateCertificate(dataDir string) (tls.Certificate, error) { }, nil } -func fileTime(filename string) time.Time { +func modTime(filename string) time.Time { fi, err := os.Stat(filename) if err != nil { if !os.IsNotExist(err) { @@ -64,40 +70,76 @@ func fileTime(filename string) time.Time { return fi.ModTime() } -func getCertificate(dataDir string) (*tls.Certificate, error) { +// loadCertificate returns the current certificate if it is still valid. +func loadCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) *certInfo { info, ok := certificate.Load().(*certInfo) + if !ok { + return nil + } + if !info.certTime.Equal(certTime) || !info.keyTime.Equal(keyTime) { + return nil + } + + return info +} + +// storeCertificate returns the current certificate if it is still valid, +// and either reads or generates a new one otherwise. +func storeCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) (info *certInfo, err error) { + certMu.Lock() + defer certMu.Unlock() + + // the certificate may have been updated since we checked + info = loadCertificate(certFile, certTime, keyFile, keyTime) + if info != nil { + return + } + + var cert tls.Certificate + nocert := certTime.Equal(time.Time{}) + nokey := keyTime.Equal(time.Time{}) + + if nocert != nokey { + err = errors.New("only one of cert.pem and key.pem exists") + return + } else if nokey { + log.Printf("Generating self-signed certificate") + cert, err = generateCertificate() + if err != nil { + return + } + } else { + cert, err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + } + info = &certInfo{ + certificate: &cert, + certTime: certTime, + keyTime: keyTime, + } + certificate.Store(info) + return +} + +func getCertificate(dataDir string) (*tls.Certificate, error) { certFile := filepath.Join(dataDir, "cert.pem") keyFile := filepath.Join(dataDir, "key.pem") - certTime := fileTime(certFile) - keyTime := fileTime(keyFile) + certTime := modTime(certFile) + keyTime := modTime(keyFile) - if !ok || !info.certTime.Equal(certTime) || !info.keyTime.Equal(keyTime) { - var cert tls.Certificate - nocert := certTime.Equal(time.Time{}) - nokey := keyTime.Equal(time.Time{}) - if nocert != nokey { - return nil, errors.New("only one of cert.pem and key.pem exists") - } else if nokey { - log.Printf("Generating self-signed certificate") - var err error - cert, err = generateCertificate(dataDir) - if err != nil { - return nil, err - } - } else { - var err error - cert, err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, err - } + info := loadCertificate(certFile, certTime, keyFile, keyTime) + + if info == nil { + var err error + info, err = storeCertificate( + certFile, certTime, keyFile, keyTime, + ) + if info == nil || err != nil { + return nil, err } - info = &certInfo{ - certificate: &cert, - certTime: certTime, - keyTime: keyTime, - } - certificate.Store(info) } return info.certificate, nil }