1
Fork 0

Protect against simultaneous generation of certificates.

This commit is contained in:
Juliusz Chroboczek 2021-02-26 12:38:18 +01:00
parent c19b356e54
commit b3727824b3
1 changed files with 72 additions and 30 deletions

View File

@ -10,6 +10,7 @@ import (
"math/big" "math/big"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
@ -20,9 +21,14 @@ type certInfo struct {
certTime time.Time certTime time.Time
} }
// certMu protects writing to certificate
var certMu sync.Mutex
// certificate holds our current certificate, of type certInfo
var certificate atomic.Value var certificate atomic.Value
func generateCertificate(dataDir string) (tls.Certificate, error) { // generateCertificate generates a self-signed certficate
func generateCertificate() (tls.Certificate, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048) priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return tls.Certificate{}, err return tls.Certificate{}, err
@ -53,7 +59,7 @@ func generateCertificate(dataDir string) (tls.Certificate, error) {
}, nil }, nil
} }
func fileTime(filename string) time.Time { func modTime(filename string) time.Time {
fi, err := os.Stat(filename) fi, err := os.Stat(filename)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
@ -64,32 +70,49 @@ func fileTime(filename string) time.Time {
return fi.ModTime() return fi.ModTime()
} }
func getCertificate(dataDir string) (*tls.Certificate, error) { // loadCertificate returns the current certificate if it is still valid.
func loadCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) *certInfo {
info, ok := certificate.Load().(*certInfo) info, ok := certificate.Load().(*certInfo)
if !ok {
return nil
}
certFile := filepath.Join(dataDir, "cert.pem") if !info.certTime.Equal(certTime) || !info.keyTime.Equal(keyTime) {
keyFile := filepath.Join(dataDir, "key.pem") return nil
certTime := fileTime(certFile) }
keyTime := fileTime(keyFile)
return info
}
// storeCertificate returns the current certificate if it is still valid,
// and either reads or generates a new one otherwise.
func storeCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) (info *certInfo, err error) {
certMu.Lock()
defer certMu.Unlock()
// the certificate may have been updated since we checked
info = loadCertificate(certFile, certTime, keyFile, keyTime)
if info != nil {
return
}
if !ok || !info.certTime.Equal(certTime) || !info.keyTime.Equal(keyTime) {
var cert tls.Certificate var cert tls.Certificate
nocert := certTime.Equal(time.Time{}) nocert := certTime.Equal(time.Time{})
nokey := keyTime.Equal(time.Time{}) nokey := keyTime.Equal(time.Time{})
if nocert != nokey { if nocert != nokey {
return nil, errors.New("only one of cert.pem and key.pem exists") err = errors.New("only one of cert.pem and key.pem exists")
return
} else if nokey { } else if nokey {
log.Printf("Generating self-signed certificate") log.Printf("Generating self-signed certificate")
var err error cert, err = generateCertificate()
cert, err = generateCertificate(dataDir)
if err != nil { if err != nil {
return nil, err return
} }
} else { } else {
var err error
cert, err = tls.LoadX509KeyPair(certFile, keyFile) cert, err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
return nil, err return
} }
} }
info = &certInfo{ info = &certInfo{
@ -98,6 +121,25 @@ func getCertificate(dataDir string) (*tls.Certificate, error) {
keyTime: keyTime, keyTime: keyTime,
} }
certificate.Store(info) certificate.Store(info)
return
}
func getCertificate(dataDir string) (*tls.Certificate, error) {
certFile := filepath.Join(dataDir, "cert.pem")
keyFile := filepath.Join(dataDir, "key.pem")
certTime := modTime(certFile)
keyTime := modTime(keyFile)
info := loadCertificate(certFile, certTime, keyFile, keyTime)
if info == nil {
var err error
info, err = storeCertificate(
certFile, certTime, keyFile, keyTime,
)
if info == nil || err != nil {
return nil, err
}
} }
return info.certificate, nil return info.certificate, nil
} }