Skip to content

Commit

Permalink
Use GetClientCertificate to allow client cert to be reloaded
Browse files Browse the repository at this point in the history
Signed-off-by: Ruben Vargas <[email protected]>
  • Loading branch information
rubenvp8510 committed Jul 19, 2024
1 parent f5bd383 commit bcfd8c4
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@
* [ENHANCEMENT] memberlist: use separate queue for broadcast messages that are result of local updates, and prioritize locally-generated messages when sending broadcasts. On stopping, only wait for queue with locally-generated messages to be empty. #539
* [ENHANCEMENT] memberlist: Added `-<prefix>memberlist.broadcast-timeout-for-local-updates-on-shutdown` option to set timeout for sending locally-generated updates on shutdown, instead of previously hardcoded 10s (which is still the default). #539
* [ENHANCEMENT] tracing: add ExtractTraceSpanID function.
* [EHNANCEMENT] Supporting reloading client certificates #549
* [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538
* [BUGFIX] spanlogger: Support multiple tenant IDs. #59
* [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85
Expand Down
1 change: 1 addition & 0 deletions crypto/tls/test/tls_integration_go_1_20_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ package test

// The error message changed in Go 1.21.
const badCertificateErrorMessage = "remote error: tls: bad certificate"
const mismatchCAAndCerts = "remote error: tls: unknown certificate authority"
1 change: 1 addition & 0 deletions crypto/tls/test/tls_integration_go_1_21_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ package test

// The error message changed in Go 1.21.
const badCertificateErrorMessage = "remote error: tls: certificate required"
const mismatchCAAndCerts = "remote error: tls: unknown certificate authority"
4 changes: 3 additions & 1 deletion crypto/tls/test/tls_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@ func TestTLSServerWithLocalhostCertWithClientCertificateEnforcementUsingClientCA
// bad certificate from the server side and just see connection
// closed/reset instead
badCertErr := errorContainsString(badCertificateErrorMessage)
mismatchCAAndCertsErr := errorContainsString(mismatchCAAndCerts)

newIntegrationClientServer(
t,
cfg,
Expand Down Expand Up @@ -411,7 +413,7 @@ func TestTLSServerWithLocalhostCertWithClientCertificateEnforcementUsingClientCA
CertPath: certs.client2CertFile,
KeyPath: certs.client2KeyFile,
},
httpExpectError: badCertErr,
httpExpectError: mismatchCAAndCertsErr,
grpcExpectError: unavailableDescErr,
},
},
Expand Down
62 changes: 44 additions & 18 deletions crypto/tls/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ func (cfg *ClientConfig) GetTLSCipherSuitesLongDescription() string {
return text
}

func (cfg *ClientConfig) validateCertificatePaths() (error, bool) {

Check failure on line 86 in crypto/tls/tls.go

View workflow job for this annotation

GitHub Actions / Check

error-return: error should be the last type when returning multiple items (revive)
if cfg.CertPath != "" || cfg.KeyPath != "" {
if cfg.CertPath == "" {
return errCertMissing, true
}
if cfg.KeyPath == "" {
return errKeyMissing, true
}
return nil, true
}
return nil, false
}

// GetTLSConfig initialises tls.Config from config options
func (cfg *ClientConfig) GetTLSConfig() (*tls.Config, error) {
config := &tls.Config{
Expand All @@ -109,29 +122,42 @@ func (cfg *ClientConfig) GetTLSConfig() (*tls.Config, error) {
config.RootCAs = caCertPool
}

// Read Client Certificate
if cfg.CertPath != "" || cfg.KeyPath != "" {
if cfg.CertPath == "" {
return nil, errCertMissing
}
if cfg.KeyPath == "" {
return nil, errKeyMissing
loadCert := func() (*tls.Certificate, error) {
// not used boolean, assumed if this is called is because we already configured TLS Client certificates.
err, _ := cfg.validateCertificatePaths()
if err == nil {
cert, err := reader.ReadSecret(cfg.CertPath)
if err != nil {
return nil, errors.Wrapf(err, "error loading client cert: %s", cfg.CertPath)
}
key, err := reader.ReadSecret(cfg.KeyPath)
if err != nil {
return nil, errors.Wrapf(err, "error loading client key: %s", cfg.KeyPath)
}

clientCert, err := tls.X509KeyPair(cert, key)
if err != nil {
return nil, errors.Wrapf(err, "failed to load TLS certificate %s,%s", cfg.CertPath, cfg.KeyPath)
}
return &clientCert, nil
}
return nil, err
}

cert, err := reader.ReadSecret(cfg.CertPath)
if err != nil {
return nil, errors.Wrapf(err, "error loading client cert: %s", cfg.CertPath)
}
key, err := reader.ReadSecret(cfg.KeyPath)
if err != nil {
return nil, errors.Wrapf(err, "error loading client key: %s", cfg.KeyPath)
err, useClientCerts := cfg.validateCertificatePaths()
if err != nil {
return nil, err
}

if useClientCerts {
// Confirm that certificate and key paths are valid.
if _, err := loadCert(); err != nil {
return nil, err
}

clientCert, err := tls.X509KeyPair(cert, key)
if err != nil {
return nil, errors.Wrapf(err, "failed to load TLS certificate %s,%s", cfg.CertPath, cfg.KeyPath)
config.GetClientCertificate = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return loadCert()
}
config.Certificates = []tls.Certificate{clientCert}
}

if cfg.MinVersion != "" {
Expand Down
2 changes: 1 addition & 1 deletion crypto/tls/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestGetTLSConfig_ClientCerts(t *testing.T) {
tlsConfig, err := c.GetTLSConfig()
assert.NoError(t, err)
assert.Equal(t, false, tlsConfig.InsecureSkipVerify, "make sure we default to not skip verification")
assert.Equal(t, 1, len(tlsConfig.Certificates), "ensure a certificate is returned")
assert.NotNil(t, tlsConfig.GetClientCertificate, "ensure a get certificate function is returned")

// expect error with key and cert swapped passed along
c = &ClientConfig{
Expand Down

0 comments on commit bcfd8c4

Please sign in to comment.