From cf8281a89be333e568f33e6052c63b7cbef91e14 Mon Sep 17 00:00:00 2001 From: Max Lambrecht Date: Thu, 4 May 2023 11:41:16 -0500 Subject: [PATCH] Add TLS communication between Galadriel Server and Harvester (#146) Add TLS communication between Server and Harvester Signed-off-by: Max Lambrecht --- cmd/harvester/cli/config.go | 10 +- cmd/harvester/cli/config_test.go | 2 + conf/harvester/dummy_root_ca.crt | 32 ++++++ conf/harvester/harvester.conf | 3 + pkg/common/constants/constants.go | 5 + pkg/common/x509ca/disk/disk.go | 7 ++ pkg/harvester/client/server.go | 53 +++++++-- pkg/harvester/config.go | 5 +- pkg/harvester/controller/controller.go | 8 +- pkg/harvester/harvester.go | 10 +- pkg/server/endpoints/config.go | 4 +- pkg/server/endpoints/run.go | 151 ++++++++++++++++++++++--- pkg/server/endpoints/run_test.go | 96 +++++++++++++++- pkg/server/server.go | 21 ++-- 14 files changed, 361 insertions(+), 46 deletions(-) create mode 100644 conf/harvester/dummy_root_ca.crt create mode 100644 pkg/common/constants/constants.go diff --git a/cmd/harvester/cli/config.go b/cmd/harvester/cli/config.go index 82d4ece0e..ac7150b75 100644 --- a/cmd/harvester/cli/config.go +++ b/cmd/harvester/cli/config.go @@ -3,6 +3,7 @@ package cli import ( "fmt" "io" + "net" "time" "github.com/HewlettPackard/galadriel/pkg/common/telemetry" @@ -28,6 +29,7 @@ type Config struct { type harvesterConfig struct { SpireSocketPath string `hcl:"spire_socket_path"` ServerAddress string `hcl:"server_address"` + ServerTrustBundlePath string `hcl:"server_trust_bundle_path"` BundleUpdatesInterval string `hcl:"bundle_updates_interval"` LogLevel string `hcl:"log_level"` } @@ -61,8 +63,14 @@ func NewHarvesterConfig(c *Config) (*harvester.Config, error) { return nil, fmt.Errorf("failed to parse bundle updates interval: %v", err) } + serverTCPAddress, err := net.ResolveTCPAddr("tcp", c.Harvester.ServerAddress) + if err != nil { + return nil, fmt.Errorf("failed to resolve server address: %v", err) + } + hc.SpireAddress = spireAddr - hc.ServerAddress = c.Harvester.ServerAddress + hc.ServerAddress = serverTCPAddress + hc.ServerTrustBundlePath = c.Harvester.ServerTrustBundlePath hc.BundleUpdatesInterval = buInt hc.Logger = logrus.WithField(telemetry.SubsystemName, telemetry.Harvester) diff --git a/cmd/harvester/cli/config_test.go b/cmd/harvester/cli/config_test.go index 7f1e458cd..e2988291f 100644 --- a/cmd/harvester/cli/config_test.go +++ b/cmd/harvester/cli/config_test.go @@ -1 +1,3 @@ package cli + +// TODO: add tests diff --git a/conf/harvester/dummy_root_ca.crt b/conf/harvester/dummy_root_ca.crt new file mode 100644 index 000000000..78da737a1 --- /dev/null +++ b/conf/harvester/dummy_root_ca.crt @@ -0,0 +1,32 @@ +-----BEGIN CERTIFICATE----- +MIIFkjCCA3qgAwIBAgIUaCjXS2q5FyByi//FJRIoVoxwNLowDQYJKoZIhvcNAQEL +BQAwMTELMAkGA1UEBhMCVVMxDzANBgNVBAoMBlNQSUZGRTERMA8GA1UEAwwIcGVy +ZnRlc3QwHhcNMjMwMzA4MTgwMjAxWhcNMzQwNTI1MTgwMjAxWjAxMQswCQYDVQQG +EwJVUzEPMA0GA1UECgwGU1BJRkZFMREwDwYDVQQDDAhwZXJmdGVzdDCCAiIwDQYJ +KoZIhvcNAQEBBQADggIPADCCAgoCggIBAN3NNPUn6NK7PthpCcGMboyWU/jusjHc +DALVbPsnGaSq2fup/d+7bQ7ElY4R26fK7KibIh8S1s6uGgrjBnOKkJRmVhxxtqZC +C8devw/zWEdEAY5VeVHlN7gzKBJpmy3U3HBwhx6eFqwB85enz/9Y6FgWU5jbTkw5 +SyPGMo+jiUPWBhaKH2tGoa3kujbrTcCjyRWYu6jyk6ZjSyJXWriY526HZwaFl6/c +C64AYGdnRSPAFYKi0s4FILd11kbgno2UqZqUb+LkuYyrMBMd9VPjISQKGmKprcIU +6KY79LTELse5lOHFmJWbELpUtyWOglWEpTlwWdr4AmJwcfS9po6FY2CLmDIi9Xgy +aLYJwLh4xllqSg1GCx44FIeOqIVjJYNbWxF1EHPk9ZCaRNemFV6PRns4ePCZnjtp +hZjSXxkIXsJfu5CHjn88OIEby0rxNCFREMZZ/o22UXI9Iuq72/SjS4nT5W7Wo/K3 +m4YvrbfIBVxPTwoK4w5fLXJPa5rcPgRbtch3UQX6/gLY3akzfFH+mGlv2V06fxRx +hbrQIcZGyPf5TvugW0UIcFLv3RW1ScxQjcvajPDEZzpLArLiC/lnD3rmXUgau7SQ +4nUqHT/STL7KPDu0XMa7omSQfqljwed978Rg1gqNiKOB0WSu9+phGxLakad5/k+e +mbtLkG0RZiA5AgMBAAGjgaEwgZ4wDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU +E13a4pFoH0GGLDghCoqJ3N68kJIwbAYDVR0jBGUwY4AUE13a4pFoH0GGLDghCoqJ +3N68kJKhNaQzMDExCzAJBgNVBAYTAlVTMQ8wDQYDVQQKDAZTUElGRkUxETAPBgNV +BAMMCHBlcmZ0ZXN0ghRoKNdLarkXIHKL/8UlEihWjHA0ujANBgkqhkiG9w0BAQsF +AAOCAgEAh4rjrwnUI0gEtHgKXCnx5rcT2DBAy/9xAOxdsDOVqheiDWoC0NK9Iet8 +m/01uwY5v+PJ4gmdVGO3V/FOvb3KOAiilZPFmI0jUYRpceUTh7FkRsZPLFeGYJ7K +OPJ6BzX1Y+aHcMK6jaTtHrgxS2gt1cWEB4nSYFQKKVqCmExrZpiFDw0ZY3RTFXSc +LA1jV1/1OD05G6UxdHSNbTzQRlNaLPeE+EL9eyiIF2RxvnQul6qZwed35nh1TqjN +n+nNSH8LSZxdwQLNEn0CYQ2LuUeAr6EdvfFibse2fyTU+HBSyuH4SRXjvFC4WuSr +pPtnGPtKwVVCNhCYiixOMPCTOL68ZWK7XC7AaHSX7pD7DuRZU2WfAJ4gSD/jUJel +ku2Kxl2NXeXrlkzMs6Ud022mLL/sFSd4WLuB/G+fk7wgwdezk4syISHSGQ0+Xech +QqpvPqj/+hi3C2yUxzZojxVSBJkB9+Z8LFrk8qpO0VREO5KD/9444Ric/CTgjQmg +qfGlCd4P+2YYEd8em7VIE5d/yDJ8rGLT45d7AwepWJzXEZBb1DMDcGVbSSISDyCh +ZAmmTw37cxEfp+7nKAGor64lX3DbRiPQV/MGv6MhNPNRyBPH7kkeZuE6ewh8bdfV +8YAhxXpLl11iLC4bZbj2AGlZ+Cr0WEVaX7g7S82cdrmc208zBR8= +-----END CERTIFICATE----- diff --git a/conf/harvester/harvester.conf b/conf/harvester/harvester.conf index d9fb8b95a..1c1a2f12f 100644 --- a/conf/harvester/harvester.conf +++ b/conf/harvester/harvester.conf @@ -7,6 +7,9 @@ harvester { # E.g: localhost:8085, my-upstream-server.com:4556, 192.168.1.125:4000 server_address = "localhost:8085" + # server_trust_bundle_path: Path to the Galadriel Server CA bundle. + server_trust_bundle_path = "./conf/harvester/dummy_root_ca.crt" + # bundle_updates_interval: Sets how often to check for bundle rotation. # Typically this should be less than the CA TTL set in SPIRE. # Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". diff --git a/pkg/common/constants/constants.go b/pkg/common/constants/constants.go new file mode 100644 index 000000000..a4c567583 --- /dev/null +++ b/pkg/common/constants/constants.go @@ -0,0 +1,5 @@ +package constants + +const ( + GaladrielServerName = "galadriel-server" +) diff --git a/pkg/common/x509ca/disk/disk.go b/pkg/common/x509ca/disk/disk.go index e3b46e3ae..dde8c5520 100644 --- a/pkg/common/x509ca/disk/disk.go +++ b/pkg/common/x509ca/disk/disk.go @@ -84,6 +84,13 @@ func (ca *X509CA) Configure(config *Config) error { // IssueX509Certificate issues an X509 certificate using the disk-based private key and ROOT CA certificate. The certificate // is bound to the given public key and subject. func (ca *X509CA) IssueX509Certificate(ctx context.Context, params *x509ca.X509CertificateParams) ([]*x509.Certificate, error) { + if params.PublicKey == nil { + return nil, errors.New("public key is required") + } + if params.TTL == 0 { + return nil, errors.New("TTL is required") + } + template, err := cryptoutil.CreateX509Template(ca.clock, params.PublicKey, params.Subject, params.URIs, params.DNSNames, params.TTL) if err != nil { return nil, fmt.Errorf("failed to create template for Server certificate: %w", err) diff --git a/pkg/harvester/client/server.go b/pkg/harvester/client/server.go index 9a9ef1875..922aaccdd 100644 --- a/pkg/harvester/client/server.go +++ b/pkg/harvester/client/server.go @@ -3,12 +3,17 @@ package client import ( "bytes" "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" + "net" "net/http" + "os" "github.com/HewlettPackard/galadriel/pkg/common" + "github.com/HewlettPackard/galadriel/pkg/common/constants" "github.com/HewlettPackard/galadriel/pkg/common/telemetry" "github.com/sirupsen/logrus" ) @@ -29,23 +34,31 @@ type GaladrielServerClient interface { } type client struct { - c http.Client - address string + c *http.Client + address *net.TCPAddr token string logger logrus.FieldLogger } -func NewGaladrielServerClient(address, token string) (GaladrielServerClient, error) { +// NewGaladrielServerClient creates a new Galadriel Server client, using the given token to authenticate +// and the given trustBundlePath to validate the server certificate. +func NewGaladrielServerClient(address *net.TCPAddr, token string, trustBundlePath string) (GaladrielServerClient, error) { + c, err := createTLSClient(trustBundlePath) + if err != nil { + return nil, fmt.Errorf("failed to create TLS client: %w", err) + } + return &client{ - c: *http.DefaultClient, - address: "http://" + address, + c: c, + address: address, token: token, logger: logrus.WithField(telemetry.SubsystemName, telemetry.GaladrielServerClient), }, nil } func (c *client) Connect(ctx context.Context, token string) error { - url := c.address + onboardPath + url := fmt.Sprintf("https://%s%s", c.address.String(), onboardPath) + req, err := http.NewRequestWithContext(ctx, http.MethodConnect, url, nil) if err != nil { return err @@ -77,7 +90,7 @@ func (c *client) SyncFederatedBundles(ctx context.Context, req *common.SyncBundl } c.logger.Debugf("Sending post federated bundles updates:\n%s", b) - url := c.address + postBundleSyncPath + url := c.address.String() + postBundleSyncPath r, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) if err != nil { return nil, fmt.Errorf("failed to create request: %v", err) @@ -117,7 +130,7 @@ func (c *client) PostBundle(ctx context.Context, req *common.PostBundleRequest) return fmt.Errorf("failed to marshal push bundle request: %v", err) } - url := c.address + postBundlePath + url := c.address.String() + postBundlePath r, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) if err != nil { @@ -147,6 +160,30 @@ func (c *client) PostBundle(ctx context.Context, req *common.PostBundleRequest) return nil } +func createTLSClient(trustBundlePath string) (*http.Client, error) { + caCert, err := os.ReadFile(trustBundlePath) + if err != nil { + return nil, err + } + + caCertPool := x509.NewCertPool() + ok := caCertPool.AppendCertsFromPEM(caCert) + if !ok { + return nil, fmt.Errorf("failed to append CA certificates") + } + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + ServerName: constants.GaladrielServerName, + }, + } + + return &http.Client{ + Transport: tr, + }, nil +} + func readBody(resp *http.Response) (string, error) { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { diff --git a/pkg/harvester/config.go b/pkg/harvester/config.go index 189d21cdd..dcea6be1e 100644 --- a/pkg/harvester/config.go +++ b/pkg/harvester/config.go @@ -16,7 +16,7 @@ type Config struct { LocalAddress net.Addr // Address of Galadriel server - ServerAddress string + ServerAddress *net.TCPAddr // Address of SPIRE Server SpireAddress net.Addr @@ -30,5 +30,8 @@ type Config struct { // Directory to store runtime data DataDir string + // Path to the trust bundle for the Galadriel Server + ServerTrustBundlePath string + Logger logrus.FieldLogger } diff --git a/pkg/harvester/controller/controller.go b/pkg/harvester/controller/controller.go index e4c62ac12..ca282d5f5 100644 --- a/pkg/harvester/controller/controller.go +++ b/pkg/harvester/controller/controller.go @@ -26,23 +26,19 @@ type HarvesterController struct { // Config represents the configurations for the Harvester Controller type Config struct { - ServerAddress string SpireSocketPath net.Addr AccessToken string BundleUpdatesInterval time.Duration Logger logrus.FieldLogger + GaladrielServerClient client.GaladrielServerClient } func NewHarvesterController(ctx context.Context, config *Config) (*HarvesterController, error) { sc := spire.NewLocalSpireServer(ctx, config.SpireSocketPath) - gc, err := client.NewGaladrielServerClient(config.ServerAddress, config.AccessToken) - if err != nil { - return nil, err - } return &HarvesterController{ spire: sc, - server: gc, + server: config.GaladrielServerClient, config: config, logger: logrus.WithField(telemetry.SubsystemName, telemetry.HarvesterController), }, nil diff --git a/pkg/harvester/harvester.go b/pkg/harvester/harvester.go index 09d3a490c..8a175e4b8 100644 --- a/pkg/harvester/harvester.go +++ b/pkg/harvester/harvester.go @@ -3,6 +3,7 @@ package harvester import ( "context" "errors" + "fmt" "github.com/HewlettPackard/galadriel/pkg/common/telemetry" "github.com/HewlettPackard/galadriel/pkg/common/util" @@ -31,18 +32,18 @@ func (h *Harvester) Run(ctx context.Context) error { return errors.New("token is required to connect the Harvester to the Galadriel Server") } - galadrielClient, err := client.NewGaladrielServerClient(h.config.ServerAddress, h.config.JoinToken) + galadrielClient, err := client.NewGaladrielServerClient(h.config.ServerAddress, h.config.JoinToken, h.config.ServerTrustBundlePath) if err != nil { - return err + return fmt.Errorf("failed to create Galadriel Server client: %w", err) } err = galadrielClient.Connect(ctx, h.config.JoinToken) if err != nil { - return err + return fmt.Errorf("failed to connect to Galadriel Server: %w", err) } config := &controller.Config{ - ServerAddress: h.config.ServerAddress, + GaladrielServerClient: galadrielClient, SpireSocketPath: h.config.SpireAddress, AccessToken: h.config.JoinToken, BundleUpdatesInterval: h.config.BundleUpdatesInterval, @@ -60,6 +61,7 @@ func (h *Harvester) Run(ctx context.Context) error { return err } +// TODO: implement this or remove it func (h *Harvester) Stop() { // unload and cleanup stuff } diff --git a/pkg/server/endpoints/config.go b/pkg/server/endpoints/config.go index e448d54d6..9d67cac61 100644 --- a/pkg/server/endpoints/config.go +++ b/pkg/server/endpoints/config.go @@ -2,6 +2,7 @@ package endpoints import ( "github.com/HewlettPackard/galadriel/pkg/server/catalog" + "github.com/HewlettPackard/galadriel/pkg/server/datastore" "net" "github.com/sirupsen/logrus" @@ -15,8 +16,7 @@ type Config struct { // LocalAddress is the local address to bind the listener to. LocalAddress net.Addr - // Postgres connection string - DatastoreConnString string + Datastore datastore.Datastore Logger logrus.FieldLogger diff --git a/pkg/server/endpoints/run.go b/pkg/server/endpoints/run.go index 36bc7eec0..f8db168d6 100644 --- a/pkg/server/endpoints/run.go +++ b/pkg/server/endpoints/run.go @@ -1,19 +1,32 @@ +// TODO: rename this file to endpoints.go package endpoints import ( "context" + "crypto/rsa" + "crypto/tls" + "crypto/x509/pkix" + "errors" "fmt" - "github.com/HewlettPackard/galadriel/pkg/server/catalog" "net" "net/http" + "sync" + "time" + "github.com/HewlettPackard/galadriel/pkg/common/constants" + "github.com/HewlettPackard/galadriel/pkg/common/cryptoutil" "github.com/HewlettPackard/galadriel/pkg/common/util" + "github.com/HewlettPackard/galadriel/pkg/common/x509ca" "github.com/HewlettPackard/galadriel/pkg/server/datastore" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/sirupsen/logrus" ) +const ( + defaultTTL = 1 * time.Hour +) + // Server manages the UDS and TCP endpoints lifecycle type Server interface { // ListenAndServe starts all endpoint servers and blocks until the context @@ -22,11 +35,24 @@ type Server interface { } type Endpoints struct { + // TODO: unexport these fields TCPAddress *net.TCPAddr LocalAddr net.Addr Datastore datastore.Datastore Logger logrus.FieldLogger - catalog catalog.Catalog + + x509CA x509ca.X509CA + certsStore *certificateSource + + hooks struct { + // test hook used to signal that TCP listener is ready + tcpListening chan struct{} + } +} + +type certificateSource struct { + mu sync.RWMutex + cert *tls.Certificate } func New(c *Config) (*Endpoints, error) { @@ -34,17 +60,12 @@ func New(c *Config) (*Endpoints, error) { return nil, err } - ds, err := datastore.NewSQLDatastore(c.Logger, c.DatastoreConnString) - if err != nil { - return nil, err - } - return &Endpoints{ TCPAddress: c.TCPAddress, LocalAddr: c.LocalAddress, - Datastore: ds, + Datastore: c.Datastore, Logger: c.Logger, - catalog: c.Catalog, + x509CA: c.Catalog.GetX509CA(), }, nil } @@ -53,11 +74,11 @@ func (e *Endpoints) ListenAndServe(ctx context.Context) error { e.runTCPServer, e.runUDSServer, ) - if err != nil { - return err + if errors.Is(err, context.Canceled) { + err = nil } - return nil + return err } func (e *Endpoints) runTCPServer(ctx context.Context) error { @@ -71,20 +92,48 @@ func (e *Endpoints) runTCPServer(ctx context.Context) error { return e.validateToken(c, key) })) - e.Logger.Infof("Starting TCP Server on %s", e.TCPAddress.String()) + cert, err := e.getTLSCertificate(ctx) + if err != nil { + return fmt.Errorf("failed to start TCP listener: %w", err) + } + e.certsStore = &certificateSource{cert: cert} + + tlsConfig := &tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + return e.certsStore.getTLSCertificate(), nil + }, + } + + httpServer := http.Server{ + Addr: e.TCPAddress.String(), + Handler: server, // set Echo as handler + TLSConfig: tlsConfig, + } + + e.Logger.Infof("Starting secure Galadriel Server TCP listening on %s", e.TCPAddress.String()) errChan := make(chan error) go func() { - errChan <- server.Start(e.TCPAddress.String()) + e.triggerListeningHook() + // certificate and key are embedded in the TLS config + errChan <- httpServer.ListenAndServeTLS("", "") }() - var err error + go e.startTLSCertificateRotation(ctx, errChan) + select { case err = <-errChan: e.Logger.WithError(err).Error("TCP Server stopped prematurely") return err case <-ctx.Done(): e.Logger.Info("Stopping TCP Server") - server.Close() + err = httpServer.Close() + if err != nil { + e.Logger.WithError(err).Error("Error closing HTTP TCP Server") + } + err = server.Close() + if err != nil { + e.Logger.WithError(err).Error("Error closing Echo Server") + } <-errChan e.Logger.Info("TCP Server stopped") return nil @@ -134,3 +183,73 @@ func (e *Endpoints) addTCPHandlers(server *echo.Echo) { server.POST("/bundle", e.postBundleHandler) server.POST("/bundle/sync", e.syncFederatedBundleHandler) } + +func (t *certificateSource) setTLSCertificate(cert *tls.Certificate) { + t.mu.Lock() + defer t.mu.Unlock() + t.cert = cert +} + +func (t *certificateSource) getTLSCertificate() *tls.Certificate { + t.mu.RLock() + defer t.mu.RUnlock() + return t.cert +} + +func (e *Endpoints) startTLSCertificateRotation(ctx context.Context, errChan chan error) { + e.Logger.Info("Starting TLS certificate rotator") + + // Start a ticker that rotates the certificate every default interval + certRotationInterval := defaultTTL / 2 + ticker := time.NewTicker(certRotationInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + e.Logger.Info("Rotating Server TLS certificate") + cert, err := e.getTLSCertificate(ctx) + if err != nil { + errChan <- fmt.Errorf("failed to rotate Server TLS certificate: %w", err) + } + e.certsStore.setTLSCertificate(cert) + case <-ctx.Done(): + e.Logger.Info("Stopped Server TLS certificate rotator") + return + } + } +} + +func (e *Endpoints) getTLSCertificate(ctx context.Context) (*tls.Certificate, error) { + privateKey, err := cryptoutil.GenerateSigner(cryptoutil.RSA2048) + if err != nil { + return nil, fmt.Errorf("failed to create private key: %w", err) + } + + params := &x509ca.X509CertificateParams{ + Subject: pkix.Name{ + CommonName: constants.GaladrielServerName, + }, + TTL: defaultTTL, + PublicKey: privateKey.Public(), + DNSNames: []string{constants.GaladrielServerName}, + } + cert, err := e.x509CA.IssueX509Certificate(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to issue TLS certificate: %w", err) + } + + certPEM := cryptoutil.EncodeCertificate(cert[0]) + keyPEM := cryptoutil.EncodeRSAPrivateKey(privateKey.(*rsa.PrivateKey)) + + certificate, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, err + } + return &certificate, nil +} + +func (e *Endpoints) triggerListeningHook() { + if e.hooks.tcpListening != nil { + e.hooks.tcpListening <- struct{}{} + } +} diff --git a/pkg/server/endpoints/run_test.go b/pkg/server/endpoints/run_test.go index e593e2998..6b9c24ae0 100644 --- a/pkg/server/endpoints/run_test.go +++ b/pkg/server/endpoints/run_test.go @@ -1 +1,95 @@ -package endpoints_test +package endpoints + +import ( + "context" + "net" + "path/filepath" + "testing" + "time" + + "github.com/HewlettPackard/galadriel/pkg/common/x509ca" + "github.com/HewlettPackard/galadriel/pkg/common/x509ca/disk" + "github.com/HewlettPackard/galadriel/test/certtest" + "github.com/jmhodges/clock" + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeCatalog struct { + x509ca x509ca.X509CA +} + +func (c fakeCatalog) GetX509CA() x509ca.X509CA { + return c.x509ca +} + +func TestListenAndServe(t *testing.T) { + config := newEndpointTestConfig(t) + + endpoints, err := New(config) + require.NoError(t, err) + + endpoints.hooks.tcpListening = make(chan struct{}) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + errCh := make(chan error) + go func() { + errCh <- endpoints.ListenAndServe(ctx) + }() + defer func() { + cancel() + assert.NoError(t, <-errCh) + }() + + waitForListening(t, endpoints, errCh) +} + +func newEndpointTestConfig(t *testing.T) *Config { + // used to generate a TCP address with a random port + listener, err := net.ListenTCP("tcp", &net.TCPAddr{}) + require.NoError(t, err) + err = listener.Close() + require.NoError(t, err) + + tempDir := t.TempDir() + tcpAddr := listener.Addr().(*net.TCPAddr) + localAddr := &net.UnixAddr{Net: "unix", Name: filepath.Join(tempDir, "sockets")} + logger, _ := test.NewNullLogger() + + clk := clock.NewFake() + clk.Set(time.Now()) + + certsFolder := certtest.CreateTestCACertificates(t, clk) + + ca, err := disk.New() + require.NoError(t, err) + c := &disk.Config{ + CertFilePath: certsFolder + "/root-ca.crt", + KeyFilePath: certsFolder + "/root-ca.key", + } + err = ca.Configure(c) + require.NoError(t, err) + + cat := fakeCatalog{ + x509ca: ca, + } + + config := &Config{ + TCPAddress: tcpAddr, + LocalAddress: localAddr, + Logger: logger, + Catalog: cat, + } + + return config +} + +func waitForListening(t *testing.T, e *Endpoints, errCh chan error) { + select { + case <-e.hooks.tcpListening: + case err := <-errCh: + t.Fatalf("Failed to start Endpoints: %v", err) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index b62b54ed0..346241c8d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -6,6 +6,7 @@ import ( "github.com/HewlettPackard/galadriel/pkg/common/telemetry" "github.com/HewlettPackard/galadriel/pkg/common/util" "github.com/HewlettPackard/galadriel/pkg/server/catalog" + "github.com/HewlettPackard/galadriel/pkg/server/datastore" "github.com/HewlettPackard/galadriel/pkg/server/endpoints" ) @@ -34,7 +35,13 @@ func (s *Server) run(ctx context.Context) error { return err } - endpointsServer, err := s.newEndpointsServer(cat) + // TODO: consider moving the datastore to the catalog? + ds, err := datastore.NewSQLDatastore(s.config.Logger, s.config.DBConnString) + if err != nil { + return err + } + + endpointsServer, err := s.newEndpointsServer(cat, ds) if err != nil { return err } @@ -46,13 +53,13 @@ func (s *Server) run(ctx context.Context) error { return err } -func (s *Server) newEndpointsServer(catalog catalog.Catalog) (endpoints.Server, error) { +func (s *Server) newEndpointsServer(catalog catalog.Catalog, ds datastore.Datastore) (endpoints.Server, error) { config := &endpoints.Config{ - TCPAddress: s.config.TCPAddress, - LocalAddress: s.config.LocalAddress, - DatastoreConnString: s.config.DBConnString, - Logger: s.config.Logger.WithField(telemetry.SubsystemName, telemetry.Endpoints), - Catalog: catalog, + TCPAddress: s.config.TCPAddress, + LocalAddress: s.config.LocalAddress, + Logger: s.config.Logger.WithField(telemetry.SubsystemName, telemetry.Endpoints), + Datastore: ds, + Catalog: catalog, } return endpoints.New(config)