From 599c610d502a2219f4ad08a61128e9e1cd1d9aa7 Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Thu, 19 Sep 2024 21:52:08 -0400 Subject: [PATCH 01/11] Added option to use custom TLS certificates instead of ACME --- README.md | 9 +++++++++ internal/cmd/deploy.go | 20 +++++++++++++++++--- internal/server/cert.go | 37 +++++++++++++++++++++++++++++++++++++ internal/server/service.go | 18 ++++++++++++------ 4 files changed, 75 insertions(+), 9 deletions(-) create mode 100644 internal/server/cert.go diff --git a/README.md b/README.md index 527acb9..a41c58b 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,15 @@ applications. To enable this, add the `--tls` flag when deploying an instance: kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls +### Custom TLS certificate + +When you obtained your TLS certificate manually, manage your own certificate authority, +or need to install Cloudflare origin certificate, you can manually specify path to +your certificate file and the corresponding private key: + + kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls --tls-certificate-path cert.pem --tls-private-key-path key.pem + + ## Specifying `run` options with environment variables In some environments, like when running a Docker container, it can be convenient diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index bb23f56..3271b07 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -10,9 +10,11 @@ import ( ) type deployCommand struct { - cmd *cobra.Command - args server.DeployArgs - tlsStaging bool + cmd *cobra.Command + args server.DeployArgs + tlsStaging bool + tlsCertificatePath string + tlsPrivateKeyPath string } func newDeployCommand() *deployCommand { @@ -31,6 +33,8 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)") deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning") + deployCommand.cmd.Flags().StringVar(&deployCommand.tlsCertificatePath, "tls-certificate-path", "", "") + deployCommand.cmd.Flags().StringVar(&deployCommand.tlsPrivateKeyPath, "tls-private-key-path", "", "") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target") @@ -62,6 +66,8 @@ func (c *deployCommand) run(cmd *cobra.Command, args []string) error { if c.args.ServiceOptions.TLSEnabled { c.args.ServiceOptions.ACMECachePath = globalConfig.CertificatePath() + c.args.ServiceOptions.TLSCertificatePath = c.tlsCertificatePath + c.args.ServiceOptions.TLSPrivateKeyPath = c.tlsPrivateKeyPath if c.tlsStaging { c.args.ServiceOptions.ACMEDirectory = server.ACMEStagingDirectoryURL @@ -87,6 +93,14 @@ func (c *deployCommand) preRun(cmd *cobra.Command, args []string) error { return fmt.Errorf("host must be set when using TLS") } + if cmd.Flags().Changed("tls-certificate-path") && !cmd.Flags().Changed("tls-private-key-path") { + return fmt.Errorf("tls-private-key-path must be set when specified tls-certificate-path") + } + + if cmd.Flags().Changed("tls-private-key-path") && !cmd.Flags().Changed("tls-certificate-path") { + return fmt.Errorf("tls-certificate-path must be set when specified tls-private-key-path") + } + if !cmd.Flags().Changed("forward-headers") { c.args.TargetOptions.ForwardHeaders = !c.args.ServiceOptions.TLSEnabled } diff --git a/internal/server/cert.go b/internal/server/cert.go new file mode 100644 index 0000000..bb1580d --- /dev/null +++ b/internal/server/cert.go @@ -0,0 +1,37 @@ +package server + +import ( + "crypto/tls" + "log/slog" +) + +type CertManager interface { + GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) +} + +type StaticCertManager struct { + tlsCertificateFilePath string + tlsPrivateKeyFilePath string +} + +func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) *StaticCertManager { + return &StaticCertManager{ + tlsCertificateFilePath: tlsCertificateFilePath, + tlsPrivateKeyFilePath: tlsPrivateKeyFilePath, + } +} + +func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + slog.Info( + "Loading custom TLS certificate", + "tls-certificate-path", m.tlsCertificateFilePath, + "tls-private-key-path", m.tlsPrivateKeyFilePath, + ) + + cert, err := tls.LoadX509KeyPair(m.tlsCertificateFilePath, m.tlsPrivateKeyFilePath) + if err != nil { + return nil, err + } + + return &cert, nil +} diff --git a/internal/server/service.go b/internal/server/service.go index 1c96de0..fe32427 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -60,10 +60,12 @@ type HealthCheckConfig struct { } type ServiceOptions struct { - TLSEnabled bool `json:"tls_enabled"` - ACMEDirectory string `json:"acme_directory"` - ACMECachePath string `json:"acme_cache_path"` - ErrorPagePath string `json:"error_page_path"` + TLSEnabled bool `json:"tls_enabled"` + TLSCertificatePath string `json:"tls_certificate_path"` + TLSPrivateKeyPath string `json:"tls_private_key_path"` + ACMEDirectory string `json:"acme_directory"` + ACMECachePath string `json:"acme_cache_path"` + ErrorPagePath string `json:"error_page_path"` } func (so ServiceOptions) ScopedCachePath() string { @@ -90,7 +92,7 @@ type Service struct { pauseController *PauseController rolloutController *RolloutController - certManager *autocert.Manager + certManager CertManager middleware http.Handler } @@ -284,11 +286,15 @@ func (s *Service) initialize() { s.middleware = s.createMiddleware() } -func (s *Service) createCertManager() *autocert.Manager { +func (s *Service) createCertManager() CertManager { if !s.options.TLSEnabled { return nil } + if s.options.TLSCertificatePath != "" && s.options.TLSPrivateKeyPath != "" { + return NewStaticCertManager(s.options.TLSCertificatePath, s.options.TLSPrivateKeyPath) + } + return &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(s.options.ScopedCachePath()), From f1be70566294ec0dce24619b3c52e48f4b88eecf Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Tue, 24 Sep 2024 11:55:47 -0400 Subject: [PATCH 02/11] Use ServiceOptions from deploy command directly --- internal/cmd/deploy.go | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index 3271b07..9276245 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -10,11 +10,9 @@ import ( ) type deployCommand struct { - cmd *cobra.Command - args server.DeployArgs - tlsStaging bool - tlsCertificatePath string - tlsPrivateKeyPath string + cmd *cobra.Command + args server.DeployArgs + tlsStaging bool } func newDeployCommand() *deployCommand { @@ -33,8 +31,8 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)") deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning") - deployCommand.cmd.Flags().StringVar(&deployCommand.tlsCertificatePath, "tls-certificate-path", "", "") - deployCommand.cmd.Flags().StringVar(&deployCommand.tlsPrivateKeyPath, "tls-private-key-path", "", "") + deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSCertificatePath, "tls-certificate-path", "", "") + deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSPrivateKeyPath, "tls-private-key-path", "", "") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target") @@ -57,6 +55,7 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().BoolVar(&deployCommand.args.TargetOptions.ForwardHeaders, "forward-headers", false, "Forward X-Forwarded headers to target (default false if TLS enabled; otherwise true)") deployCommand.cmd.MarkFlagRequired("target") + deployCommand.cmd.MarkFlagsRequiredTogether("tls-certificate-path", "tls-private-key-path") return deployCommand } @@ -66,8 +65,6 @@ func (c *deployCommand) run(cmd *cobra.Command, args []string) error { if c.args.ServiceOptions.TLSEnabled { c.args.ServiceOptions.ACMECachePath = globalConfig.CertificatePath() - c.args.ServiceOptions.TLSCertificatePath = c.tlsCertificatePath - c.args.ServiceOptions.TLSPrivateKeyPath = c.tlsPrivateKeyPath if c.tlsStaging { c.args.ServiceOptions.ACMEDirectory = server.ACMEStagingDirectoryURL @@ -93,14 +90,6 @@ func (c *deployCommand) preRun(cmd *cobra.Command, args []string) error { return fmt.Errorf("host must be set when using TLS") } - if cmd.Flags().Changed("tls-certificate-path") && !cmd.Flags().Changed("tls-private-key-path") { - return fmt.Errorf("tls-private-key-path must be set when specified tls-certificate-path") - } - - if cmd.Flags().Changed("tls-private-key-path") && !cmd.Flags().Changed("tls-certificate-path") { - return fmt.Errorf("tls-certificate-path must be set when specified tls-private-key-path") - } - if !cmd.Flags().Changed("forward-headers") { c.args.TargetOptions.ForwardHeaders = !c.args.ServiceOptions.TLSEnabled } From 085b924c13bf94d10d553c494d97593a18864430 Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Tue, 24 Sep 2024 11:56:47 -0400 Subject: [PATCH 03/11] Added test to make sure service uses static certificate manager when configured --- internal/server/service_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 8ae8b0c..967e136 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -42,6 +42,21 @@ func TestService_RedirectToHTTPWhenTLSRequired(t *testing.T) { require.Equal(t, http.StatusOK, w.Result().StatusCode) } +func TestService_UseStaticTLSCertificateWhenConfigured(t *testing.T) { + service := testCreateService( + t, + []string{"example.com"}, + ServiceOptions{ + TLSEnabled: true, + TLSCertificatePath: "cert.pem", + TLSPrivateKeyPath: "key.pem", + }, + defaultTargetOptions, + ) + + require.NotNil(t, service.certManager.(*StaticCertManager)) +} + func TestService_RejectTLSRequestsWhenNotConfigured(t *testing.T) { service := testCreateService(t, defaultEmptyHosts, defaultServiceOptions, defaultTargetOptions) From aceaedea577ed60d5e99b03f52897ed2fa2869aa Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Tue, 24 Sep 2024 11:57:06 -0400 Subject: [PATCH 04/11] Cache a certificate loaded from disk in the StaticCertManager --- internal/server/cert.go | 9 ++++- internal/server/cert_test.go | 75 ++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 internal/server/cert_test.go diff --git a/internal/server/cert.go b/internal/server/cert.go index bb1580d..afaf740 100644 --- a/internal/server/cert.go +++ b/internal/server/cert.go @@ -9,9 +9,11 @@ type CertManager interface { GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) } +// StaticCertManager is a certificate manager that loads certificates from disk. type StaticCertManager struct { tlsCertificateFilePath string tlsPrivateKeyFilePath string + cert *tls.Certificate } func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) *StaticCertManager { @@ -22,6 +24,10 @@ func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) } func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + if m.cert != nil { + return m.cert, nil + } + slog.Info( "Loading custom TLS certificate", "tls-certificate-path", m.tlsCertificateFilePath, @@ -32,6 +38,7 @@ func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certifica if err != nil { return nil, err } + m.cert = &cert - return &cert, nil + return m.cert, nil } diff --git a/internal/server/cert_test.go b/internal/server/cert_test.go new file mode 100644 index 0000000..b3de207 --- /dev/null +++ b/internal/server/cert_test.go @@ -0,0 +1,75 @@ +package server + +import ( + "crypto/tls" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +const certPem = `-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----` + +const keyPem = `-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----` + +func TestCertificateLoading(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles() + require.NoError(t, err) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + manager := NewStaticCertManager(certPath, keyPath) + cert, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert) +} + +func TestCachesLoadedCertificate(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles() + require.NoError(t, err) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + manager := NewStaticCertManager(certPath, keyPath) + cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert1) + + os.Remove(certPath) + os.Remove(keyPath) + + cert2, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.Equal(t, cert1, cert2) +} + +func prepareTestCertificateFiles() (string, string, error) { + certFile, err := os.CreateTemp("", "example-cert-*.pem") + if err != nil { + return "", "", err + } + defer certFile.Close() + certFile.Write([]byte(certPem)) + + keyFile, err := os.CreateTemp("", "example-key-*.pem") + if err != nil { + return "", "", err + } + defer keyFile.Close() + keyFile.Write([]byte(keyPem)) + + return certFile.Name(), keyFile.Name(), nil +} From 793779168d8fd682d07bb330613040ead1b0bf46 Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Tue, 24 Sep 2024 12:01:59 -0400 Subject: [PATCH 05/11] Added tests for invalid scenarios in the StaticCertManager --- internal/server/cert_test.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/internal/server/cert_test.go b/internal/server/cert_test.go index b3de207..6599d11 100644 --- a/internal/server/cert_test.go +++ b/internal/server/cert_test.go @@ -56,6 +56,25 @@ func TestCachesLoadedCertificate(t *testing.T) { require.Equal(t, cert1, cert2) } +func TestErrorWhenFileDoesNotExist(t *testing.T) { + manager := NewStaticCertManager("testdata/cert.pem", "testdata/key.pem") + cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.ErrorContains(t, err, "no such file or directory") + require.Nil(t, cert1) +} + +func TestErrorWhenKeyFormatIsInvalid(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles() + require.NoError(t, err) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + manager := NewStaticCertManager(keyPath, certPath) + cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.ErrorContains(t, err, "failed to find certificate PEM data in certificate input") + require.Nil(t, cert1) +} + func prepareTestCertificateFiles() (string, string, error) { certFile, err := os.CreateTemp("", "example-cert-*.pem") if err != nil { From 233625e45bc3d8f99c5a675621900a3071f17ed7 Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Tue, 24 Sep 2024 12:04:32 -0400 Subject: [PATCH 06/11] More expressive test for the cert manager type in the service test --- internal/server/service_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 967e136..ae82027 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -54,7 +54,7 @@ func TestService_UseStaticTLSCertificateWhenConfigured(t *testing.T) { defaultTargetOptions, ) - require.NotNil(t, service.certManager.(*StaticCertManager)) + require.IsType(t, &StaticCertManager{}, service.certManager) } func TestService_RejectTLSRequestsWhenNotConfigured(t *testing.T) { From 3d24ed2ca0b647830e6a462c27e0dfd360b3c0d3 Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Tue, 24 Sep 2024 12:41:29 -0400 Subject: [PATCH 07/11] Addressed a data race condition in the StaticCertManager, and added a test to confirm it addressed --- internal/server/cert.go | 11 +++++++++++ internal/server/cert_test.go | 15 +++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/internal/server/cert.go b/internal/server/cert.go index afaf740..90e8b48 100644 --- a/internal/server/cert.go +++ b/internal/server/cert.go @@ -3,6 +3,7 @@ package server import ( "crypto/tls" "log/slog" + "sync" ) type CertManager interface { @@ -14,6 +15,7 @@ type StaticCertManager struct { tlsCertificateFilePath string tlsPrivateKeyFilePath string cert *tls.Certificate + lock sync.RWMutex } func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) *StaticCertManager { @@ -24,7 +26,16 @@ func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) } func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + m.lock.RLock() if m.cert != nil { + defer m.lock.RUnlock() + return m.cert, nil + } + m.lock.RUnlock() + + m.lock.Lock() + defer m.lock.Unlock() + if m.cert != nil { // Double-check locking return m.cert, nil } diff --git a/internal/server/cert_test.go b/internal/server/cert_test.go index 6599d11..9ce422a 100644 --- a/internal/server/cert_test.go +++ b/internal/server/cert_test.go @@ -38,6 +38,21 @@ func TestCertificateLoading(t *testing.T) { require.NotNil(t, cert) } +func TestCertificateLoadingRaceCondition(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles() + require.NoError(t, err) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + manager := NewStaticCertManager(certPath, keyPath) + go func() { + manager.GetCertificate(&tls.ClientHelloInfo{}) + }() + cert, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert) +} + func TestCachesLoadedCertificate(t *testing.T) { certPath, keyPath, err := prepareTestCertificateFiles() require.NoError(t, err) From 749e110bc3a14c438e32396a4d88e3ec11612735 Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Wed, 25 Sep 2024 13:43:53 -0400 Subject: [PATCH 08/11] Use t.Cleanup() to remove temporary certificate files after the test is completed --- internal/server/cert_test.go | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/internal/server/cert_test.go b/internal/server/cert_test.go index 9ce422a..d6c02ff 100644 --- a/internal/server/cert_test.go +++ b/internal/server/cert_test.go @@ -27,10 +27,8 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== -----END EC PRIVATE KEY-----` func TestCertificateLoading(t *testing.T) { - certPath, keyPath, err := prepareTestCertificateFiles() + certPath, keyPath, err := prepareTestCertificateFiles(t) require.NoError(t, err) - defer os.Remove(certPath) - defer os.Remove(keyPath) manager := NewStaticCertManager(certPath, keyPath) cert, err := manager.GetCertificate(&tls.ClientHelloInfo{}) @@ -39,10 +37,8 @@ func TestCertificateLoading(t *testing.T) { } func TestCertificateLoadingRaceCondition(t *testing.T) { - certPath, keyPath, err := prepareTestCertificateFiles() + certPath, keyPath, err := prepareTestCertificateFiles(t) require.NoError(t, err) - defer os.Remove(certPath) - defer os.Remove(keyPath) manager := NewStaticCertManager(certPath, keyPath) go func() { @@ -54,10 +50,8 @@ func TestCertificateLoadingRaceCondition(t *testing.T) { } func TestCachesLoadedCertificate(t *testing.T) { - certPath, keyPath, err := prepareTestCertificateFiles() + certPath, keyPath, err := prepareTestCertificateFiles(t) require.NoError(t, err) - defer os.Remove(certPath) - defer os.Remove(keyPath) manager := NewStaticCertManager(certPath, keyPath) cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{}) @@ -79,10 +73,8 @@ func TestErrorWhenFileDoesNotExist(t *testing.T) { } func TestErrorWhenKeyFormatIsInvalid(t *testing.T) { - certPath, keyPath, err := prepareTestCertificateFiles() + certPath, keyPath, err := prepareTestCertificateFiles(t) require.NoError(t, err) - defer os.Remove(certPath) - defer os.Remove(keyPath) manager := NewStaticCertManager(keyPath, certPath) cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{}) @@ -90,13 +82,16 @@ func TestErrorWhenKeyFormatIsInvalid(t *testing.T) { require.Nil(t, cert1) } -func prepareTestCertificateFiles() (string, string, error) { +func prepareTestCertificateFiles(t *testing.T) (string, string, error) { + t.Helper() + certFile, err := os.CreateTemp("", "example-cert-*.pem") if err != nil { return "", "", err } defer certFile.Close() certFile.Write([]byte(certPem)) + t.Cleanup(func() { os.Remove(certFile.Name()) }) keyFile, err := os.CreateTemp("", "example-key-*.pem") if err != nil { @@ -104,6 +99,7 @@ func prepareTestCertificateFiles() (string, string, error) { } defer keyFile.Close() keyFile.Write([]byte(keyPem)) + t.Cleanup(func() { os.Remove(keyFile.Name()) }) return certFile.Name(), keyFile.Name(), nil } From 7e7126ca221508ba318bff882e7dbd0b87130b7e Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Wed, 25 Sep 2024 14:03:13 -0400 Subject: [PATCH 09/11] Ensure all IO errors are handled in the StaticCertManager test --- internal/server/cert_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/server/cert_test.go b/internal/server/cert_test.go index d6c02ff..96ed045 100644 --- a/internal/server/cert_test.go +++ b/internal/server/cert_test.go @@ -3,6 +3,7 @@ package server import ( "crypto/tls" "os" + "path" "testing" "github.com/stretchr/testify/require" @@ -42,7 +43,8 @@ func TestCertificateLoadingRaceCondition(t *testing.T) { manager := NewStaticCertManager(certPath, keyPath) go func() { - manager.GetCertificate(&tls.ClientHelloInfo{}) + _, err2 := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err2) }() cert, err := manager.GetCertificate(&tls.ClientHelloInfo{}) require.NoError(t, err) @@ -58,8 +60,8 @@ func TestCachesLoadedCertificate(t *testing.T) { require.NoError(t, err) require.NotNil(t, cert1) - os.Remove(certPath) - os.Remove(keyPath) + require.Nil(t, os.Remove(certPath)) + require.Nil(t, os.Remove(keyPath)) cert2, err := manager.GetCertificate(&tls.ClientHelloInfo{}) require.Equal(t, cert1, cert2) @@ -85,21 +87,19 @@ func TestErrorWhenKeyFormatIsInvalid(t *testing.T) { func prepareTestCertificateFiles(t *testing.T) (string, string, error) { t.Helper() - certFile, err := os.CreateTemp("", "example-cert-*.pem") + dir := t.TempDir() + certFile := path.Join(dir, "example-cert.pem") + keyFile := path.Join(dir, "example-key.pem") + + err := os.WriteFile(certFile, []byte(certPem), 0644) if err != nil { return "", "", err } - defer certFile.Close() - certFile.Write([]byte(certPem)) - t.Cleanup(func() { os.Remove(certFile.Name()) }) - keyFile, err := os.CreateTemp("", "example-key-*.pem") + err = os.WriteFile(keyFile, []byte(keyPem), 0644) if err != nil { return "", "", err } - defer keyFile.Close() - keyFile.Write([]byte(keyPem)) - t.Cleanup(func() { os.Remove(keyFile.Name()) }) - return certFile.Name(), keyFile.Name(), nil + return certFile, keyFile, nil } From f2f6f9c0240898dcc11204aff93f9afe676ef145 Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Mon, 30 Sep 2024 10:14:24 -0400 Subject: [PATCH 10/11] Added missing help for --tls-certificate-path and --tls-private-key-path arguments --- internal/cmd/deploy.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index 9276245..82ae5c4 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -31,8 +31,8 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)") deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning") - deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSCertificatePath, "tls-certificate-path", "", "") - deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSPrivateKeyPath, "tls-private-key-path", "", "") + deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSCertificatePath, "tls-certificate-path", "", "Configure custom TLS certificate path (PEM format)") + deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSPrivateKeyPath, "tls-private-key-path", "", "Configure custom TLS private key path (PEM format)") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target") From 30b8c2048ab9f0430417a552030e6d8279f70f28 Mon Sep 17 00:00:00 2001 From: Dmytro Shteflyuk Date: Mon, 30 Sep 2024 10:22:34 -0400 Subject: [PATCH 11/11] Added HTTPHandler() to the CertManager interface after support for HTTP-01 challenges was added --- internal/server/cert.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/internal/server/cert.go b/internal/server/cert.go index 90e8b48..493a5ff 100644 --- a/internal/server/cert.go +++ b/internal/server/cert.go @@ -3,11 +3,13 @@ package server import ( "crypto/tls" "log/slog" + "net/http" "sync" ) type CertManager interface { GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) + HTTPHandler(handler http.Handler) http.Handler } // StaticCertManager is a certificate manager that loads certificates from disk. @@ -53,3 +55,7 @@ func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certifica return m.cert, nil } + +func (m *StaticCertManager) HTTPHandler(handler http.Handler) http.Handler { + return handler +}