Skip to content

Commit

Permalink
add memory cert reloader (#517)
Browse files Browse the repository at this point in the history
Co-authored-by: Po-Yao Chen <[email protected]>
  • Loading branch information
py4chen and Po-Yao Chen authored Nov 8, 2024
1 parent 8ff1340 commit a6d065a
Show file tree
Hide file tree
Showing 8 changed files with 417 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/scorecards.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright 2024 Yahoo Inc.
# Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.

name: Scorecards supply-chain security
on:
# Only the default branch is supported.
Expand Down
2 changes: 2 additions & 0 deletions license_comment
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Copyright 2024 Yahoo Inc.
Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.
167 changes: 167 additions & 0 deletions utils/cert_reload.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright 2024 Yahoo Inc.
// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.

package utils

import (
"bytes"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log"
"sync"
"time"
)

const defaultMemPollInterval = 60 * time.Minute

// MemCertReloader reloads the (key, cert) pair by invoking the callback functions
// getter.
type MemCertReloader struct {
mu sync.RWMutex
getter func() ([]byte, []byte, error)
cert *tls.Certificate

logger func(fmt string, args ...interface{})
once sync.Once
stop chan struct{}
pollInterval time.Duration
}

// GetCertificate returns the latest known certificate and can be assigned to the
// GetCertificate member of the TLS config. For http.server use.
func (w *MemCertReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return w.GetLatestCertificate()
}

// GetClientCertificate returns the latest known certificate and can be assigned to the
// GetClientCertificate member of the TLS config. For http.client use.
func (w *MemCertReloader) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return w.GetLatestCertificate()
}

// GetLatestCertificate returns the latest known certificate.
func (w *MemCertReloader) GetLatestCertificate() (*tls.Certificate, error) {
w.mu.RLock()
c := w.cert
w.mu.RUnlock()
return c, nil
}

// Close stops the background refresh.
func (w *MemCertReloader) Close() error {
w.once.Do(func() {
close(w.stop)
})
return nil
}

// Reload reloads the certificate into the memory cache when the certificate is updated and valid.
func (w *MemCertReloader) Reload() error {
cb, kb, err := w.getter()
if err != nil {
return fmt.Errorf("failed to get the certificate and private key, %v", err)
}

if err := ValidateCertExpiry(cb, time.Now()); err != nil {
return fmt.Errorf("failed to validate certicate, %v", err)
}

cert, err := tls.X509KeyPair(cb, kb)
if err != nil {
return fmt.Errorf("failed to parse the certificate and private key, %v", err)
}

if w.cert != nil {
if subtle.ConstantTimeCompare(cert.Certificate[0], w.cert.Certificate[0]) == 1 {
return nil
}
}

w.mu.Lock()
w.cert = &cert
w.mu.Unlock()
w.logger("certs reloaded at %v", time.Now())
return nil
}

func (w *MemCertReloader) pollRefresh() {
poll := time.NewTicker(w.pollInterval)
defer poll.Stop()
for {
select {
case <-poll.C:
if err := w.Reload(); err != nil {
w.logger("cert reload error: %v\n", err)
}
case <-w.stop:
return
}
}
}

// CertReloadConfig contains the config for cert reload.
type CertReloadConfig struct {
// CertKeyGetter gets the certificate and the private key.
CertKeyGetter func() ([]byte, []byte, error)
Logger func(fmt string, args ...interface{})
PollInterval time.Duration
}

// NewCertReloader returns a MemCertReloader that reloads the (key, cert) pair whenever
// the cert file changes on the filesystem.
func NewCertReloader(config CertReloadConfig) (*MemCertReloader, error) {
if config.Logger == nil {
config.Logger = log.Printf
}
if config.PollInterval == 0 {
config.PollInterval = defaultMemPollInterval
}

var getter func() (cert []byte, key []byte, _ error)

if config.CertKeyGetter == nil {
return nil, errors.New("no getter function found in the config")
}

if config.CertKeyGetter != nil {
getter = config.CertKeyGetter
}

r := &MemCertReloader{
getter: getter,
logger: config.Logger,
pollInterval: config.PollInterval,
stop: make(chan struct{}, 10),
}
// load once to ensure cert is good.
if err := r.Reload(); err != nil {
return nil, err
}
go r.pollRefresh()
return r, nil
}

// ValidateCertExpiry validates the certificate expiry.
func ValidateCertExpiry(certPEM []byte, now time.Time) error {
if len(bytes.TrimSpace(certPEM)) == 0 {
return errors.New("certificate is empty")
}
for {
der, rest := pem.Decode(certPEM)
cp, err := x509.ParseCertificate(der.Bytes)
if err != nil {
return err
}
if now.Before(cp.NotBefore) || now.After(cp.NotAfter) {
return fmt.Errorf("invalid certificate, NotBefore: %v, NotAfter: %v, Now: %v", cp.NotBefore, cp.NotAfter, now)
}
if len(bytes.TrimSpace(rest)) == 0 {
return nil
}
certPEM = rest
}
}
155 changes: 155 additions & 0 deletions utils/cert_reload_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright 2024 Yahoo Inc.
// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.

package utils

import (
"crypto/tls"
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestMemCertReloader_Reload(t *testing.T) {
t.Parallel()
type expect struct {
cert *tls.Certificate
wantErr assert.ErrorAssertionFunc
}

tests := []struct {
name string
setup func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect)
certPath string
keyPath string
wantCert *tls.Certificate
wantErr assert.ErrorAssertionFunc
}{
{
name: "happy path",
certPath: "testdata/client.crt",
keyPath: "testdata/client.key",
setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}

reloader, err := NewCertReloader(
CertReloadConfig{
CertKeyGetter: func() ([]byte, []byte, error) {
return certPEM, keyPEM, nil
},
},
)
if err != nil {
t.Fatal(err)
}
wantCrt, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatal(err)
}
want := &expect{
cert: &wantCrt,
wantErr: assert.NoError,
}
return reloader, want
},
},
{
name: "getter error",
certPath: "testdata/invalid.crt",
keyPath: "testdata/invalid.key",
setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) {
reloader := &MemCertReloader{
getter: func() ([]byte, []byte, error) {
return nil, nil, fmt.Errorf("get error")
},
}
want := &expect{
wantErr: assert.Error,
}
return reloader, want
},
},
{
name: "unchanged cert",
certPath: "testdata/client.crt",
keyPath: "testdata/client.key",
setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}

reloader, err := NewCertReloader(
CertReloadConfig{
CertKeyGetter: func() ([]byte, []byte, error) {
return certPEM, keyPEM, nil
},
},
)
if err != nil {
t.Fatal(err)
}
wantCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatal(err)
}
reloader.cert = &wantCert
want := &expect{
cert: &wantCert,
wantErr: assert.NoError,
}
return reloader, want
},
},
{
name: "invalid key pair",
certPath: "testdata/ca.crt",
keyPath: "testdata/client.key",
setup: func(t *testing.T, certPath, keyPath string) (*MemCertReloader, *expect) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
t.Fatal(err)
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
t.Fatal(err)
}
reloader := &MemCertReloader{
getter: func() ([]byte, []byte, error) {
return certPEM, keyPEM, nil
},
}
if err != nil {
t.Fatal(err)
}
want := &expect{
wantErr: assert.Error,
}
return reloader, want
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reloader, want := tt.setup(t, tt.certPath, tt.keyPath)
gotErr := reloader.Reload()
if !want.wantErr(t, gotErr, "unexpected error") {
return
}
assert.Equal(t, reloader.cert, want.cert, "unexpected result")
})
}
}
11 changes: 11 additions & 0 deletions utils/generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright 2022 Yahoo Inc.
// Licensed under the terms of the Apache License 2.0. Please see LICENSE file in project root for terms.

package utils

//go:generate certstrap init --passphrase "" --common-name "ca" --years 80
//go:generate certstrap request-cert --passphrase "" --common-name client
//go:generate certstrap sign client --passphrase "" --CA ca --years 80
//go:generate mkdir -p ./testdata
//go:generate mv -f ./out/ca.crt ./out/client.crt ./out/client.key ./testdata
//go:generate rm -rf ./out
28 changes: 28 additions & 0 deletions utils/testdata/ca.crt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-----BEGIN CERTIFICATE-----
MIIE3DCCAsSgAwIBAgIBATANBgkqhkiG9w0BAQsFADANMQswCQYDVQQDEwJjYTAg
Fw0yNDExMDcxNzQ4NTFaGA8yMTA2MDUwNzE3NTg1MVowDTELMAkGA1UEAxMCY2Ew
ggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDA3OVLRgACNl6b1IDiAq0c
pL5FncGJC/w5/01LUkgy+9rAk0lwwjnZiXf3aMOC2Q6267uB9BtpaoTLR5h8GRM7
25GNCps2x8gbo8GaBqN8UfH/Cn+yGg652tZI4ikD5HE1rIGYAnhll3esEV+zCiFr
Nuh+RFyabcprWr0FQ/N6ysrrMdFQNo17WEIp0L3nevznLU1d7uc7h6z2lKU2DBrT
ghFKvwSO724YHhQsvCZOtNcIPsYcwTEHiugLEhZrcYQ2OjgiygCmg71OiPgHoATa
lrUGnv7tibyjvQ/XIZqRu3iL3GJAJJV3S6owHl8eSur0u8RW3mHneEZKqZqF4fXY
isSzmO4SQDJibtiWboQZP2NmkEUR7ar7Y9z6SgyFDie5GH5kvP7g3eHcIs/6sz+s
yM0DY5FO9YsNBuQbfxfEtXSbQ8Y2ZC+0NWgifYCG0DAmaoyRSZjNjBMBm3naNJrQ
V6TPtsANJxhp4b8nTa9W1Fh04w8yH6ROTPYb9LWyYuWuoV6BCE6rj5mmLWRqoT2n
FXxiqAiftg3qrco+ZCh7KZ8ht4+dlxeDC3ki2jpCx33GZZSEQhXX3+ZcizVK24d6
BOwcbn6NKda+xW/7xaK9dLkYHeQBSxXE+U1X7RmsVfEQr/B7vigwWWzefOG+ECOE
fq6Qq9RaPcTvuFdXuVi8tQIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAQYwEgYDVR0T
AQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU/wf4Qp+YW2ZZOdUXkbzT6gluGNkwDQYJ
KoZIhvcNAQELBQADggIBAAg/nnhO2YaIrLs89BZAtTwe0UCJVaSx9kt5wTRjsAwV
A149MuGstrootq31mt4k90a7t3X63tUzvAOpDV+/JZAei3ASz2KN13C4BKXToaBW
IcMTNXYwZiL2dF1uqkMxPVo0w+NRyQtVwdYDdqRCqHKY7TFz7wJPjDrYetm5OwfQ
JritO2h5mqxr+ubg2mWECxhHT9D4N6w0dhUCJX/zbH7QF8mvuEarlQB3ct+Ew868
DoWbvWD3pDRqD8Fjt1CraXm1FWhR84uPLga8XpOQ+NvJWnQFXWLFaqAtBs3a8eaj
nmptRl0Iue/esTSQBRpeqzu73dzCtkeFrR2Nst1Ycpmbl7cFat1m8tvBwAM7Pbku
0Bom1qT6daOZrvbIDXKAlaBAseT6o892PswWUjRSC7ZqhrUHMQTq1oJs4lxbtQTQ
pnOuVwQLWOv+vlaoCufnysP65zxHAvzMt25L7/yyTmp4f+eixn7YReQg+4px9DeE
2loWjq4YTbEcXCzgJ8HR4uhppHKZXJhB/vx7Qg386zgtRtpa/QeGAtfeUFyCHlmu
P35g9wKaonBtN9DQyFN9sBJ9ugLGZ/YeXwCkPzT3OoylNRM+rr2h1E80fiF76jOz
Gx5Uv7/b5ie5/917MPaJfdk99AZ/VOb5m4HBNqB4O8CmgTdbYJjDJAuK8YBz5rxO
-----END CERTIFICATE-----
24 changes: 24 additions & 0 deletions utils/testdata/client.crt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
-----BEGIN CERTIFICATE-----
MIIEGzCCAgOgAwIBAgIQK2zwL4S3fDViQnDt0QMyqzANBgkqhkiG9w0BAQsFADAN
MQswCQYDVQQDEwJjYTAgFw0yNDExMDcxNzQ4NTFaGA8yMTA0MTEwNzE3NTg1MVow
ETEPMA0GA1UEAxMGY2xpZW50MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
AQEAxKRuswos/E15VoODmFJnDKkmgwB4pfdOb1GBi8r/+stRWldkX72DkOYQVJ/T
qshJocMelEItDI/HwP3kQ3tHTjVklp8ekrqxkFxh3lpxAKTMKzwbPFsLpdBm+jB5
E5Rz7DA9d9mWVwxzrwOmRM9FWwiRM8NqjELSvigyf+Q3ZuZsgFIxHpO5vQ9h8wHr
J4Xx4MJAAltNe8wD5GGFoZ3S+gJaEOqilPl5RXFu6jUbj6tmFTEKowCoWMeBi9ZQ
EHavkTbOEdSqVSe4mkvM3Hznn//wgnx1Wxd8BYEpeAYYOAqTg5A0hdNOrv67smGc
sJpb6TNNEqzcJpvdHwA38FyUlQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMCA7gwHQYD
VR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBSn7pDoI+w7gybp
E6Zt+78Sl0uR1DAfBgNVHSMEGDAWgBT/B/hCn5hbZlk51ReRvNPqCW4Y2TANBgkq
hkiG9w0BAQsFAAOCAgEARndBgdqO5f4M9vTcbhBfU1qN2CB7TdKqMr1jOsfHMHaS
hu00sw2Y0MKPSsSkV3FndmJUy5sjYHbxEAmjsUVnsaW/1hXCgvnHMl0JbOVdtkiW
qrIMKudTKo49hEk3jl32dmgR0EWj8PRF8blgl7j3SdmixYIpuoJ6zxccET5SxvSc
c2Srl3QP16pBc+OnaHZoEhiUOLRogP+Gn+4daH1iPpTIC5TGvpz9aK2iexoH1wJa
sj6JykZpsfT7pCv5wl2JhNtAKSjEAhRz2gv6Md3lps/0PjG3/cEKxWKdZWtZg0Jk
5iSbAVzi2E7xcfNM3Gmp4f7xiAWN70HH0c/HlcJ/jjlQ9/pt8BBpEItlpW2oGQKc
EtkGvoBWrfPq6WRHhFSAamIL3aHCsqXa2y9CtQ6Wk4eXBIMJ1uU+zfCTEm1DMFJh
JHNENPc8eYwEPloAQDQbwkTRKDKP+FjyhRiMi071X0oVw5byIzRZegQM5We9XV2k
PGGxyqdYxQ2Xv/DHKQpiEnmDjQ5j8xHUPJaT81Do3x8L1oBXsUvPIAQZfV7tgIKk
xl89dyPLYCjAQc9bnAp+YiqS2CRTQDMVmPe+8SpRFLgGjc7YTwTKMdm+lZWj169/
56KGiU0neQozv92gKPygazzqZ/W/BBTXxsRtSwOHzdQ0SpONeyIdNKqyh+xuc6M=
-----END CERTIFICATE-----
Loading

0 comments on commit a6d065a

Please sign in to comment.