Skip to content

Commit

Permalink
Allow forcing cert reissuance
Browse files Browse the repository at this point in the history
Refreshing the cert should force reissuance as opposed to returning
early if the SANs aren't changing. This is currently breaking refresh
of expired certs as per:
k3s-io/k3s#1621 (comment)

Signed-off-by: Brad Davidson <[email protected]>
  • Loading branch information
brandond committed Aug 6, 2020
1 parent ce09000 commit d834fdf
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 57 deletions.
20 changes: 15 additions & 5 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ const (
duration365d = time.Hour * 24 * 365
)

var (
ErrStaticCert = errors.New("cannot renew static certificate")
)

// Config contains the basic fields required for creating a certificate
type Config struct {
CommonName string
Expand Down Expand Up @@ -119,7 +123,13 @@ func NewSignedCert(cfg Config, key crypto.Signer, caCert *x509.Certificate, caKe
if err != nil {
return nil, err
}
return x509.ParseCertificate(certDERBytes)

parsedCert, err := x509.ParseCertificate(certDERBytes)
if err == nil {
logrus.Infof("certificate %v signed by %v: notBefore=%v notAfter=%v",
parsedCert.Subject, caCert.Subject, parsedCert.NotBefore, parsedCert.NotAfter)
}
return parsedCert, err
}

// MakeEllipticPrivateKeyPEM creates an ECDSA private key
Expand Down Expand Up @@ -271,11 +281,11 @@ func ipsToStrings(ips []net.IP) []string {
}

// IsCertExpired checks if the certificate about to expire
func IsCertExpired(cert *x509.Certificate) bool {
func IsCertExpired(cert *x509.Certificate, days int) bool {
expirationDate := cert.NotAfter
diffDays := expirationDate.Sub(time.Now()).Hours() / 24.0
if diffDays <= 90 {
logrus.Infof("certificate will expire in %f days", diffDays)
diffDays := time.Until(expirationDate).Hours() / 24.0
if diffDays <= float64(days) {
logrus.Infof("certificate %v will expire in %f days at %v", cert.Subject, diffDays, cert.NotAfter)
return true
}
return false
Expand Down
6 changes: 3 additions & 3 deletions cert/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ func CanReadCertAndKey(certPath, keyPath string) (bool, error) {
certReadable := canReadFile(certPath)
keyReadable := canReadFile(keyPath)

if certReadable == false && keyReadable == false {
if !certReadable && !keyReadable {
return false, nil
}

if certReadable == false {
if !certReadable {
return false, fmt.Errorf("error reading %s, certificate and key must be supplied as a pair", certPath)
}

if keyReadable == false {
if !keyReadable {
return false, fmt.Errorf("error reading %s, certificate and key must be supplied as a pair", keyPath)
}

Expand Down
9 changes: 8 additions & 1 deletion factory/cert_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"math/big"
"net"
"time"

"github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -92,7 +94,12 @@ func NewSignedCert(signer crypto.Signer, caCert *x509.Certificate, caKey crypto.
return nil, err
}

return x509.ParseCertificate(cert)
parsedCert, err := x509.ParseCertificate(cert)
if err == nil {
logrus.Infof("certificate %v signed by %v: notBefore=%v notAfter=%v",
parsedCert.Subject, caCert.Subject, parsedCert.NotBefore, parsedCert.NotAfter)
}
return parsedCert, err
}

func ParseCertPEM(pemCerts []byte) (*x509.Certificate, error) {
Expand Down
53 changes: 31 additions & 22 deletions factory/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/sha1"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"net"
"regexp"
"sort"
Expand Down Expand Up @@ -49,16 +49,14 @@ func cns(secret *v1.Secret) (cns []string) {
return
}

func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, err error) {
func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, err error) {
var (
cns = cns(secret)
digest = sha256.New()
cns = cns(secret)
)

sort.Strings(cns)

for _, v := range cns {
digest.Write([]byte(v))
ip := net.ParseIP(v)
if ip == nil {
domains = append(domains, v)
Expand All @@ -67,40 +65,51 @@ func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string,
}
}

hash = hex.EncodeToString(digest.Sum(nil))
return
}

func (t *TLS) Merge(target, additional *v1.Secret) (*v1.Secret, bool, error) {
return t.AddCN(target, cns(additional)...)
secret, updated, err := t.AddCN(target, cns(additional)...)
// AddCN returns early if the CNs are the same, but we also need to handle the case
// where the secret has been renewed with the same CNs. Since the kubernetes storage backend
// uses Merge to detect changes, we return the second secret and note that it has been updated.
if !updated {
if target.Annotations[hashKey] != additional.Annotations[hashKey] {
secret = additional
updated = true
}
}
return secret, updated, err
}

func (t *TLS) Refresh(secret *v1.Secret) (*v1.Secret, error) {
func (t *TLS) Renew(secret *v1.Secret) (*v1.Secret, error) {
if IsStatic(secret) {
return secret, cert.ErrStaticCert
}
cns := cns(secret)
secret = secret.DeepCopy()
secret.Annotations = map[string]string{}
secret, _, err := t.AddCN(secret, cns...)
secret, _, err := t.generateCert(secret, cns...)
return secret, err
}

func (t *TLS) Filter(cn ...string) []string {
if t.FilterCN == nil {
if len(cn) == 0 || t.FilterCN == nil {
return cn
}
return t.FilterCN(cn...)
}

func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
var (
err error
)

cn = t.Filter(cn...)

if !NeedsUpdate(0, secret, cn...) {
if IsStatic(secret) || !NeedsUpdate(0, secret, cn...) {
return secret, false, nil
}
return t.generateCert(secret, cn...)
}

func (t *TLS) generateCert(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
secret = secret.DeepCopy()
if secret == nil {
secret = &v1.Secret{}
Expand All @@ -113,7 +122,7 @@ func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
return nil, false, err
}

domains, ips, hash, err := collectCNs(secret)
domains, ips, err := collectCNs(secret)
if err != nil {
return nil, false, err
}
Expand All @@ -133,7 +142,7 @@ func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
}
secret.Data[v1.TLSCertKey] = certBytes
secret.Data[v1.TLSPrivateKeyKey] = keyBytes
secret.Annotations[hashKey] = hash
secret.Annotations[hashKey] = fmt.Sprintf("SHA1=%X", sha1.Sum(newCert.Raw))

return secret, true, nil
}
Expand All @@ -157,15 +166,15 @@ func populateCN(secret *v1.Secret, cn ...string) *v1.Secret {
return secret
}

func IsStatic(secret *v1.Secret) bool {
return secret.Annotations[Static] == "true"
}

func NeedsUpdate(maxSANs int, secret *v1.Secret, cn ...string) bool {
if secret == nil {
return true
}

if secret.Annotations[Static] == "true" {
return false
}

for _, cn := range cn {
if secret.Annotations[cnPrefix+cn] == "" {
if maxSANs > 0 && len(cns(secret)) >= maxSANs {
Expand Down
41 changes: 24 additions & 17 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"sync"
"time"

"github.com/rancher/dynamiclistener/cert"
"github.com/rancher/dynamiclistener/factory"
"github.com/sirupsen/logrus"
v1 "k8s.io/api/core/v1"
Expand All @@ -22,7 +23,7 @@ type TLSStorage interface {
}

type TLSFactory interface {
Refresh(secret *v1.Secret) (*v1.Secret, error)
Renew(secret *v1.Secret) (*v1.Secret, error)
AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error)
Merge(target *v1.Secret, additional *v1.Secret) (*v1.Secret, bool, error)
Filter(cn ...string) []string
Expand Down Expand Up @@ -152,13 +153,13 @@ type listener struct {
func (l *listener) WrapExpiration(days int) net.Listener {
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(5 * time.Minute)
time.Sleep(30 * time.Second)

for {
wait := 6 * time.Hour
if err := l.checkExpiration(days); err != nil {
if err := l.checkExpiration(days); err != nil && err != cert.ErrStaticCert {
logrus.Errorf("failed to check and refresh dynamic cert: %v", err)
wait = 5 + time.Minute
wait = 30 * time.Second
}
select {
case <-ctx.Done():
Expand Down Expand Up @@ -191,22 +192,26 @@ func (l *listener) checkExpiration(days int) error {
return err
}

cert, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey])
keyPair, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey])
if err != nil {
return err
}

certParsed, err := x509.ParseCertificate(cert.Certificate[0])
certParsed, err := x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
return err
}

if time.Now().UTC().Add(time.Hour * 24 * time.Duration(days)).After(certParsed.NotAfter) {
secret, err := l.factory.Refresh(secret)
if cert.IsCertExpired(certParsed, days) {
secret, err := l.factory.Renew(secret)
if err != nil {
return err
}
return l.storage.Update(secret)
if err := l.storage.Update(secret); err != nil {
return err
}
// clear version to force cert reload
l.version = ""
}

return nil
Expand Down Expand Up @@ -304,7 +309,7 @@ func (l *listener) updateCert(cn ...string) error {
return err
}

if !factory.NeedsUpdate(l.maxSANs, secret, cn...) {
if !factory.IsStatic(secret) && !factory.NeedsUpdate(l.maxSANs, secret, cn...) {
return nil
}

Expand All @@ -324,13 +329,6 @@ func (l *listener) updateCert(cn ...string) error {
}
// clear version to force cert reload
l.version = ""
if l.conns != nil {
l.connLock.Lock()
for _, conn := range l.conns {
_ = conn.close()
}
l.connLock.Unlock()
}
}

return nil
Expand Down Expand Up @@ -366,6 +364,15 @@ func (l *listener) loadCert() (*tls.Certificate, error) {
return nil, err
}

// cert has changed, close closeWrapper wrapped connections
if l.conns != nil {
l.connLock.Lock()
for _, conn := range l.conns {
_ = conn.close()
}
l.connLock.Unlock()
}

l.cert = &cert
l.version = secret.ResourceVersion
return l.cert, nil
Expand Down
5 changes: 2 additions & 3 deletions storage/kubernetes/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,9 @@ func (s *storage) saveInK8s(secret *v1.Secret) (*v1.Secret, error) {
if targetSecret.UID == "" {
logrus.Infof("Creating new TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations)
return s.secrets.Create(targetSecret)
} else {
logrus.Infof("Updating TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations)
return s.secrets.Update(targetSecret)
}
logrus.Infof("Updating TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations)
return s.secrets.Update(targetSecret)
}

func (s *storage) Update(secret *v1.Secret) (err error) {
Expand Down
14 changes: 8 additions & 6 deletions storage/memory/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ func (m *memory) Get() (*v1.Secret, error) {
}

func (m *memory) Update(secret *v1.Secret) error {
if m.storage != nil {
if err := m.storage.Update(secret); err != nil {
return err
if m.secret == nil || m.secret.ResourceVersion != secret.ResourceVersion {
if m.storage != nil {
if err := m.storage.Update(secret); err != nil {
return err
}
}
}

logrus.Infof("Active TLS secret %s (ver=%s) (count %d): %v", secret.Name, secret.ResourceVersion, len(secret.Annotations)-1, secret.Annotations)
m.secret = secret
logrus.Infof("Active TLS secret %s (ver=%s) (count %d): %v", secret.Name, secret.ResourceVersion, len(secret.Annotations)-1, secret.Annotations)
m.secret = secret
}
return nil
}

0 comments on commit d834fdf

Please sign in to comment.