Skip to content

Commit

Permalink
Merge pull request #17 from kpumuk/custom-tls
Browse files Browse the repository at this point in the history
Added option to use custom TLS certificates instead of ACME
  • Loading branch information
kevinmcconnell authored Oct 2, 2024
2 parents bd188f2 + 30b8c20 commit f6086a7
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 6 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ applications. To enable this, add the `--tls` flag when deploying an instance:
kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls


### Custom TLS certificate

When you obtained your TLS certificate manually, manage your own certificate authority,
or need to install Cloudflare origin certificate, you can manually specify path to
your certificate file and the corresponding private key:

kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls --tls-certificate-path cert.pem --tls-private-key-path key.pem


## Specifying `run` options with environment variables

In some environments, like when running a Docker container, it can be convenient
Expand Down
3 changes: 3 additions & 0 deletions internal/cmd/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ func newDeployCommand() *deployCommand {

deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)")
deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning")
deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSCertificatePath, "tls-certificate-path", "", "Configure custom TLS certificate path (PEM format)")
deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSPrivateKeyPath, "tls-private-key-path", "", "Configure custom TLS private key path (PEM format)")

deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy")
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target")
Expand All @@ -53,6 +55,7 @@ func newDeployCommand() *deployCommand {
deployCommand.cmd.Flags().BoolVar(&deployCommand.args.TargetOptions.ForwardHeaders, "forward-headers", false, "Forward X-Forwarded headers to target (default false if TLS enabled; otherwise true)")

deployCommand.cmd.MarkFlagRequired("target")
deployCommand.cmd.MarkFlagsRequiredTogether("tls-certificate-path", "tls-private-key-path")

return deployCommand
}
Expand Down
61 changes: 61 additions & 0 deletions internal/server/cert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package server

import (
"crypto/tls"
"log/slog"
"net/http"
"sync"
)

type CertManager interface {
GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
HTTPHandler(handler http.Handler) http.Handler
}

// StaticCertManager is a certificate manager that loads certificates from disk.
type StaticCertManager struct {
tlsCertificateFilePath string
tlsPrivateKeyFilePath string
cert *tls.Certificate
lock sync.RWMutex
}

func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) *StaticCertManager {
return &StaticCertManager{
tlsCertificateFilePath: tlsCertificateFilePath,
tlsPrivateKeyFilePath: tlsPrivateKeyFilePath,
}
}

func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
m.lock.RLock()
if m.cert != nil {
defer m.lock.RUnlock()
return m.cert, nil
}
m.lock.RUnlock()

m.lock.Lock()
defer m.lock.Unlock()
if m.cert != nil { // Double-check locking
return m.cert, nil
}

slog.Info(
"Loading custom TLS certificate",
"tls-certificate-path", m.tlsCertificateFilePath,
"tls-private-key-path", m.tlsPrivateKeyFilePath,
)

cert, err := tls.LoadX509KeyPair(m.tlsCertificateFilePath, m.tlsPrivateKeyFilePath)
if err != nil {
return nil, err
}
m.cert = &cert

return m.cert, nil
}

func (m *StaticCertManager) HTTPHandler(handler http.Handler) http.Handler {
return handler
}
105 changes: 105 additions & 0 deletions internal/server/cert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package server

import (
"crypto/tls"
"os"
"path"
"testing"

"github.com/stretchr/testify/require"
)

const certPem = `-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`

const keyPem = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`

func TestCertificateLoading(t *testing.T) {
certPath, keyPath, err := prepareTestCertificateFiles(t)
require.NoError(t, err)

manager := NewStaticCertManager(certPath, keyPath)
cert, err := manager.GetCertificate(&tls.ClientHelloInfo{})
require.NoError(t, err)
require.NotNil(t, cert)
}

func TestCertificateLoadingRaceCondition(t *testing.T) {
certPath, keyPath, err := prepareTestCertificateFiles(t)
require.NoError(t, err)

manager := NewStaticCertManager(certPath, keyPath)
go func() {
_, err2 := manager.GetCertificate(&tls.ClientHelloInfo{})
require.NoError(t, err2)
}()
cert, err := manager.GetCertificate(&tls.ClientHelloInfo{})
require.NoError(t, err)
require.NotNil(t, cert)
}

func TestCachesLoadedCertificate(t *testing.T) {
certPath, keyPath, err := prepareTestCertificateFiles(t)
require.NoError(t, err)

manager := NewStaticCertManager(certPath, keyPath)
cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{})
require.NoError(t, err)
require.NotNil(t, cert1)

require.Nil(t, os.Remove(certPath))
require.Nil(t, os.Remove(keyPath))

cert2, err := manager.GetCertificate(&tls.ClientHelloInfo{})
require.Equal(t, cert1, cert2)
}

func TestErrorWhenFileDoesNotExist(t *testing.T) {
manager := NewStaticCertManager("testdata/cert.pem", "testdata/key.pem")
cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{})
require.ErrorContains(t, err, "no such file or directory")
require.Nil(t, cert1)
}

func TestErrorWhenKeyFormatIsInvalid(t *testing.T) {
certPath, keyPath, err := prepareTestCertificateFiles(t)
require.NoError(t, err)

manager := NewStaticCertManager(keyPath, certPath)
cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{})
require.ErrorContains(t, err, "failed to find certificate PEM data in certificate input")
require.Nil(t, cert1)
}

func prepareTestCertificateFiles(t *testing.T) (string, string, error) {
t.Helper()

dir := t.TempDir()
certFile := path.Join(dir, "example-cert.pem")
keyFile := path.Join(dir, "example-key.pem")

err := os.WriteFile(certFile, []byte(certPem), 0644)
if err != nil {
return "", "", err
}

err = os.WriteFile(keyFile, []byte(keyPem), 0644)
if err != nil {
return "", "", err
}

return certFile, keyFile, nil
}
18 changes: 12 additions & 6 deletions internal/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ type HealthCheckConfig struct {
}

type ServiceOptions struct {
TLSEnabled bool `json:"tls_enabled"`
ACMEDirectory string `json:"acme_directory"`
ACMECachePath string `json:"acme_cache_path"`
ErrorPagePath string `json:"error_page_path"`
TLSEnabled bool `json:"tls_enabled"`
TLSCertificatePath string `json:"tls_certificate_path"`
TLSPrivateKeyPath string `json:"tls_private_key_path"`
ACMEDirectory string `json:"acme_directory"`
ACMECachePath string `json:"acme_cache_path"`
ErrorPagePath string `json:"error_page_path"`
}

func (so ServiceOptions) ScopedCachePath() string {
Expand All @@ -90,7 +92,7 @@ type Service struct {

pauseController *PauseController
rolloutController *RolloutController
certManager *autocert.Manager
certManager CertManager
middleware http.Handler
}

Expand Down Expand Up @@ -284,11 +286,15 @@ func (s *Service) initialize() {
s.middleware = s.createMiddleware()
}

func (s *Service) createCertManager() *autocert.Manager {
func (s *Service) createCertManager() CertManager {
if !s.options.TLSEnabled {
return nil
}

if s.options.TLSCertificatePath != "" && s.options.TLSPrivateKeyPath != "" {
return NewStaticCertManager(s.options.TLSCertificatePath, s.options.TLSPrivateKeyPath)
}

return &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(s.options.ScopedCachePath()),
Expand Down
15 changes: 15 additions & 0 deletions internal/server/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ func TestService_RedirectToHTTPWhenTLSRequired(t *testing.T) {
require.Equal(t, http.StatusOK, w.Result().StatusCode)
}

func TestService_UseStaticTLSCertificateWhenConfigured(t *testing.T) {
service := testCreateService(
t,
[]string{"example.com"},
ServiceOptions{
TLSEnabled: true,
TLSCertificatePath: "cert.pem",
TLSPrivateKeyPath: "key.pem",
},
defaultTargetOptions,
)

require.IsType(t, &StaticCertManager{}, service.certManager)
}

func TestService_RejectTLSRequestsWhenNotConfigured(t *testing.T) {
service := testCreateService(t, defaultEmptyHosts, defaultServiceOptions, defaultTargetOptions)

Expand Down

0 comments on commit f6086a7

Please sign in to comment.