Skip to content

Commit

Permalink
fixup: handle err and lint
Browse files Browse the repository at this point in the history
Signed-off-by: Todd Baert <[email protected]>
  • Loading branch information
toddbaert committed Sep 19, 2024
1 parent 1674c78 commit 71943e7
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 14 deletions.
6 changes: 5 additions & 1 deletion core/pkg/telemetry/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ func buildTransportCredentials(_ context.Context, cfg CollectorConfig) (credenti
RootCAs: capool,
MinVersion: tls.VersionTLS13,
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
return reloader.GetCertificate()
certs, err := reloader.GetCertificate()
if err != nil {
return nil, fmt.Errorf("failed to reload certs: %w", err)
}
return certs, nil
},
}

Expand Down
3 changes: 2 additions & 1 deletion flagd/cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ func init() {
flags.StringP(otelCertPathFlagName, "D", "", "tls certificate path to use with OpenTelemetry collector")
flags.StringP(otelKeyPathFlagName, "K", "", "tls key path to use with OpenTelemetry collector")
flags.StringP(otelCAPathFlagName, "A", "", "tls certificate authority path to use with OpenTelemetry collector")
flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate from disk")
flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate"+
"from disk")

_ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName))
_ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName))
Expand Down
12 changes: 6 additions & 6 deletions flagd/pkg/certreloader/certreloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ type Config struct {
ReloadInterval time.Duration
}

type certReloader struct {
type CertReloader struct {
cert *tls.Certificate
mu sync.RWMutex
nextReload time.Time
Config
}

func NewCertReloader(config Config) (*certReloader, error) {
reloader := certReloader{
func NewCertReloader(config Config) (*CertReloader, error) {
reloader := CertReloader{
Config: config,
}

Expand All @@ -36,7 +36,7 @@ func NewCertReloader(config Config) (*certReloader, error) {
return &reloader, nil
}

func (r *certReloader) GetCertificate() (*tls.Certificate, error) {
func (r *CertReloader) GetCertificate() (*tls.Certificate, error) {
now := time.Now()
// Read locking here before we do the time comparison
// If a reload is in progress this will block and we will skip reloading in the current
Expand All @@ -59,8 +59,8 @@ func (r *certReloader) GetCertificate() (*tls.Certificate, error) {
return r.cert, nil
}

func (c *certReloader) loadCertificate() (tls.Certificate, error) {
newCert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath)
func (r *CertReloader) loadCertificate() (tls.Certificate, error) {
newCert, err := tls.LoadX509KeyPair(r.CertPath, r.KeyPath)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to load key pair: %w", err)
}
Expand Down
29 changes: 23 additions & 6 deletions flagd/pkg/certreloader/certreloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,22 @@ func generateValidCertificate(t *testing.T) (*bytes.Buffer, *bytes.Buffer) {

// pem encode
caPEM := new(bytes.Buffer)
pem.Encode(caPEM, &pem.Block{
err = pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})
if err != nil {
t.Fatal(err)
}

caPrivKeyPEM := new(bytes.Buffer)
pem.Encode(caPrivKeyPEM, &pem.Block{
err = pem.Encode(caPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
})
if err != nil {
t.Fatal(err)
}

// set up our server certificate
cert := &x509.Certificate{
Expand Down Expand Up @@ -228,16 +234,22 @@ func generateValidCertificate(t *testing.T) (*bytes.Buffer, *bytes.Buffer) {
}

certPEM := new(bytes.Buffer)
pem.Encode(certPEM, &pem.Block{
err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
if err != nil {
t.Fatal(err)
}

certPrivKeyPEM := new(bytes.Buffer)
pem.Encode(certPrivKeyPEM, &pem.Block{
err = pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
if err != nil {
t.Fatal(err)
}

return certPEM, certPrivKeyPEM
}
Expand Down Expand Up @@ -272,15 +284,20 @@ func generateValidCertificateFiles(t *testing.T) (string, string, func()) {
func copyFile(src, dst string) error {
data, err := os.ReadFile(src)
if err != nil {
return err
return fmt.Errorf("failed to load key pair: %w", err)
}

return os.WriteFile(dst, data, 0o777)
err = os.WriteFile(dst, data, 0o0600)
if err != nil {
return fmt.Errorf("failed to load key pair: %w", err)
}
return nil
}

func randString(n int) string {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, n)
//nolint:errcheck
rand.Read(bytes)
for i, b := range bytes {
bytes[i] = alphanum[b%byte(len(alphanum))]
Expand Down

0 comments on commit 71943e7

Please sign in to comment.