2021-02-24 20:01:48 +01:00
|
|
|
package webserver
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/rand"
|
|
|
|
"crypto/rsa"
|
|
|
|
"crypto/tls"
|
|
|
|
"crypto/x509"
|
2021-02-24 22:23:38 +01:00
|
|
|
"errors"
|
2021-02-24 20:01:48 +01:00
|
|
|
"log"
|
|
|
|
"math/big"
|
|
|
|
"os"
|
|
|
|
"path/filepath"
|
2021-02-26 12:38:18 +01:00
|
|
|
"sync"
|
2021-02-24 20:01:48 +01:00
|
|
|
"sync/atomic"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// certInfo pairs a loaded or generated certificate with the modification
// times of the files it was built from, so staleness can be detected by
// comparing those times against the files' current ones.
type certInfo struct {
	// certificate is the TLS certificate (chain plus private key).
	certificate *tls.Certificate
	// keyTime and certTime are the modification times of the key and
	// certificate files observed when certificate was built; both are the
	// zero time when the certificate was generated rather than read.
	keyTime, certTime time.Time
}
|
|
|
|
|
2021-02-26 12:38:18 +01:00
|
|
|
var (
	// certMu serialises writers of certificate: it is held while a new
	// certificate is built and stored. Readers (loadCertificate) do not
	// take it.
	certMu sync.Mutex

	// certificate caches the current certificate. When set, it holds a
	// *certInfo.
	certificate atomic.Value
)
|
|
|
|
|
2021-02-26 12:38:18 +01:00
|
|
|
// generateCertificate generates a self-signed certificate backed by a fresh
// 2048-bit RSA key. The certificate is its own issuer, is valid from now
// until 365 days from now, is marked for server authentication, and carries
// a random serial number below 2^128. No subject or host names are set.
//
// It returns an error if key generation, serial-number generation, or
// certificate creation fails.
func generateCertificate() (tls.Certificate, error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return tls.Certificate{}, err
	}

	// Draw the serial number uniformly from [0, 2^128).
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		// Do not continue with a nil serial number.
		return tls.Certificate{}, err
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber:          serialNumber,
		NotBefore:             now,
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	// Passing the template as both certificate and parent makes it
	// self-signed by priv.
	der, err := x509.CreateCertificate(
		rand.Reader, &template, &template, &priv.PublicKey, priv,
	)
	if err != nil {
		return tls.Certificate{}, err
	}

	return tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  priv,
	}, nil
}
|
|
|
|
|
2021-02-26 12:38:18 +01:00
|
|
|
// modTime returns the modification time of filename. If the file cannot be
// stat'ed it returns the zero time; a failure other than "does not exist"
// is also logged.
func modTime(filename string) time.Time {
	info, statErr := os.Stat(filename)
	if statErr == nil {
		return info.ModTime()
	}
	if !os.IsNotExist(statErr) {
		log.Printf("%v: %v", filename, statErr)
	}
	return time.Time{}
}
|
|
|
|
|
2021-02-26 12:38:18 +01:00
|
|
|
// loadCertificate returns the current certificate if it is still valid.
|
|
|
|
func loadCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) *certInfo {
|
2021-02-24 20:01:48 +01:00
|
|
|
info, ok := certificate.Load().(*certInfo)
|
2021-02-26 12:38:18 +01:00
|
|
|
if !ok {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
if !info.certTime.Equal(certTime) || !info.keyTime.Equal(keyTime) {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return info
|
|
|
|
}
|
|
|
|
|
|
|
|
// storeCertificate returns the current certificate if it is still valid,
|
|
|
|
// and either reads or generates a new one otherwise.
|
|
|
|
func storeCertificate(certFile string, certTime time.Time, keyFile string, keyTime time.Time) (info *certInfo, err error) {
|
|
|
|
certMu.Lock()
|
|
|
|
defer certMu.Unlock()
|
|
|
|
|
|
|
|
// the certificate may have been updated since we checked
|
|
|
|
info = loadCertificate(certFile, certTime, keyFile, keyTime)
|
|
|
|
if info != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
var cert tls.Certificate
|
|
|
|
nocert := certTime.Equal(time.Time{})
|
|
|
|
nokey := keyTime.Equal(time.Time{})
|
2021-02-24 20:01:48 +01:00
|
|
|
|
2021-02-26 12:38:18 +01:00
|
|
|
if nocert != nokey {
|
|
|
|
err = errors.New("only one of cert.pem and key.pem exists")
|
|
|
|
return
|
|
|
|
} else if nokey {
|
|
|
|
log.Printf("Generating self-signed certificate")
|
|
|
|
cert, err = generateCertificate()
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
cert, err = tls.LoadX509KeyPair(certFile, keyFile)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
info = &certInfo{
|
|
|
|
certificate: &cert,
|
|
|
|
certTime: certTime,
|
|
|
|
keyTime: keyTime,
|
|
|
|
}
|
|
|
|
certificate.Store(info)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func getCertificate(dataDir string) (*tls.Certificate, error) {
|
2021-02-24 20:01:48 +01:00
|
|
|
certFile := filepath.Join(dataDir, "cert.pem")
|
|
|
|
keyFile := filepath.Join(dataDir, "key.pem")
|
2021-02-26 12:38:18 +01:00
|
|
|
certTime := modTime(certFile)
|
|
|
|
keyTime := modTime(keyFile)
|
|
|
|
|
|
|
|
info := loadCertificate(certFile, certTime, keyFile, keyTime)
|
|
|
|
|
|
|
|
if info == nil {
|
|
|
|
var err error
|
|
|
|
info, err = storeCertificate(
|
|
|
|
certFile, certTime, keyFile, keyTime,
|
|
|
|
)
|
|
|
|
if info == nil || err != nil {
|
|
|
|
return nil, err
|
2021-02-24 20:01:48 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return info.certificate, nil
|
|
|
|
}
|