mirror of
https://github.com/jech/galene.git
synced 2024-12-22 15:25:48 +01:00
Protect against simultaneous generation of certificates.
This commit is contained in:
parent
c19b356e54
commit
b3727824b3
1 changed files with 72 additions and 30 deletions
|
@ -10,6 +10,7 @@ import (
|
|||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
@ -20,9 +21,14 @@ type certInfo struct {
|
|||
certTime time.Time
|
||||
}
|
||||
|
||||
// certMu protects writing to certificate
|
||||
var certMu sync.Mutex
|
||||
|
||||
// certificate holds our current certificate, of type certInfo
|
||||
var certificate atomic.Value
|
||||
|
||||
func generateCertificate(dataDir string) (tls.Certificate, error) {
|
||||
// generateCertificate generates a self-signed certficate
|
||||
func generateCertificate() (tls.Certificate, error) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
|
@ -53,7 +59,7 @@ func generateCertificate(dataDir string) (tls.Certificate, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func fileTime(filename string) time.Time {
|
||||
func modTime(filename string) time.Time {
|
||||
fi, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
|
@ -64,40 +70,76 @@ func fileTime(filename string) time.Time {
|
|||
return fi.ModTime()
|
||||
}
|
||||
|
||||
func getCertificate(dataDir string) (*tls.Certificate, error) {
|
||||
// loadCertificate returns the current certificate if it is still valid.
|
||||
func loadCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) *certInfo {
|
||||
info, ok := certificate.Load().(*certInfo)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !info.certTime.Equal(certTime) || !info.keyTime.Equal(keyTime) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// storeCertificate returns the current certificate if it is still valid,
|
||||
// and either reads or generates a new one otherwise.
|
||||
func storeCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) (info *certInfo, err error) {
|
||||
certMu.Lock()
|
||||
defer certMu.Unlock()
|
||||
|
||||
// the certificate may have been updated since we checked
|
||||
info = loadCertificate(certFile, certTime, keyFile, keyTime)
|
||||
if info != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var cert tls.Certificate
|
||||
nocert := certTime.Equal(time.Time{})
|
||||
nokey := keyTime.Equal(time.Time{})
|
||||
|
||||
if nocert != nokey {
|
||||
err = errors.New("only one of cert.pem and key.pem exists")
|
||||
return
|
||||
} else if nokey {
|
||||
log.Printf("Generating self-signed certificate")
|
||||
cert, err = generateCertificate()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
cert, err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
info = &certInfo{
|
||||
certificate: &cert,
|
||||
certTime: certTime,
|
||||
keyTime: keyTime,
|
||||
}
|
||||
certificate.Store(info)
|
||||
return
|
||||
}
|
||||
|
||||
func getCertificate(dataDir string) (*tls.Certificate, error) {
|
||||
certFile := filepath.Join(dataDir, "cert.pem")
|
||||
keyFile := filepath.Join(dataDir, "key.pem")
|
||||
certTime := fileTime(certFile)
|
||||
keyTime := fileTime(keyFile)
|
||||
certTime := modTime(certFile)
|
||||
keyTime := modTime(keyFile)
|
||||
|
||||
if !ok || !info.certTime.Equal(certTime) || !info.keyTime.Equal(keyTime) {
|
||||
var cert tls.Certificate
|
||||
nocert := certTime.Equal(time.Time{})
|
||||
nokey := keyTime.Equal(time.Time{})
|
||||
if nocert != nokey {
|
||||
return nil, errors.New("only one of cert.pem and key.pem exists")
|
||||
} else if nokey {
|
||||
log.Printf("Generating self-signed certificate")
|
||||
var err error
|
||||
cert, err = generateCertificate(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
cert, err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info := loadCertificate(certFile, certTime, keyFile, keyTime)
|
||||
|
||||
if info == nil {
|
||||
var err error
|
||||
info, err = storeCertificate(
|
||||
certFile, certTime, keyFile, keyTime,
|
||||
)
|
||||
if info == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info = &certInfo{
|
||||
certificate: &cert,
|
||||
certTime: certTime,
|
||||
keyTime: keyTime,
|
||||
}
|
||||
certificate.Store(info)
|
||||
}
|
||||
return info.certificate, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue