From 1588cc830af57d3604ca4475d17004c387c9a04d Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Thu, 5 Dec 2024 11:29:58 -0600 Subject: [PATCH] add support for SRV discovery for permissions-api host (#98) To better support failover to other regions without adding load balancer hops and latency, permissions hosts can now support SRV record discovery to discover additional hosts which can serve requests. SRV records are looked up using the host configured for the permissions client. The SRV service looked up is `permissions-api` with protocol `tcp`. An example SRV lookup request would be for `_permissions-api._tcp.iam.example.com`. Where `iam.example.com` is the host configured for `permissions.host`. For best backwards compatibility, these SRV records are optional and will fallback to using the value provided in `permissions.host`. Additionally, to support retrying on failure, the permissions client was updated to support retrying auth checks if the response was not successful. This ensures a seamless transition when a host has failed between health checks. 
Signed-off-by: Mike Mason --- chart/iam-runtime-infratographer/README.md | 13 + chart/iam-runtime-infratographer/values.yaml | 31 ++ config.example.yaml | 15 + go.mod | 2 + go.sum | 8 + internal/permissions/client.go | 23 +- internal/permissions/config.go | 158 ++++++ internal/permissions/logger.go | 32 ++ internal/selecthost/doc.go | 5 + internal/selecthost/helpers.go | 99 ++++ internal/selecthost/helpers_test.go | 185 +++++++ internal/selecthost/host.go | 345 ++++++++++++ internal/selecthost/host_test.go | 402 ++++++++++++++ internal/selecthost/http.go | 78 +++ internal/selecthost/http_test.go | 124 +++++ internal/selecthost/options.go | 149 +++++ internal/selecthost/selector.go | 501 +++++++++++++++++ internal/selecthost/selector_test.go | 540 +++++++++++++++++++ 18 files changed, 2703 insertions(+), 7 deletions(-) create mode 100644 internal/permissions/logger.go create mode 100644 internal/selecthost/doc.go create mode 100644 internal/selecthost/helpers.go create mode 100644 internal/selecthost/helpers_test.go create mode 100644 internal/selecthost/host.go create mode 100644 internal/selecthost/host_test.go create mode 100644 internal/selecthost/http.go create mode 100644 internal/selecthost/http_test.go create mode 100644 internal/selecthost/options.go create mode 100644 internal/selecthost/selector.go create mode 100644 internal/selecthost/selector_test.go diff --git a/chart/iam-runtime-infratographer/README.md b/chart/iam-runtime-infratographer/README.md index 494dbc1a..73af5252 100644 --- a/chart/iam-runtime-infratographer/README.md +++ b/chart/iam-runtime-infratographer/README.md @@ -71,6 +71,19 @@ iam-runtime-infratographer: | config.events.nats.url | string | `""` | url NATS server url to use. | | config.jwt.issuer | string | `""` | issuer Issuer to use for JWT validation. | | config.jwt.jwksURI | string | `""` | jwksURI JWKS URI to use for JWT validation. 
| +| config.permissions.discovery.check.concurrency | int | `5` | concurrency is the number of hosts to concurrently check. | +| config.permissions.discovery.check.count | int | `5` | count is the number of checks to run on each host to check for connection latency. | +| config.permissions.discovery.check.delay | string | `"200ms"` | delay is the delay between requests for a host. | +| config.permissions.discovery.check.interval | string | `"1m"` | interval is how frequent to check for healthiness on hosts. | +| config.permissions.discovery.check.path | string | `"/readyz"` | path is the uri path to fetch to check if host is healthy. | +| config.permissions.discovery.check.scheme | string | `""` | scheme sets the uri scheme. Default is http unless discovered port is 443 in which https will be used. | +| config.permissions.discovery.check.timeout | string | `"2s"` | timeout sets the maximum amount of time a request can wait before canceling the request. | +| config.permissions.discovery.disable | bool | `false` | disable SRV discovery. | +| config.permissions.discovery.fallback | string | `""` | fallback sets the fallback address if no hosts are found or all hosts are unhealthy. The default fallback host is the permissions.host value. | +| config.permissions.discovery.interval | string | `"15m"` | interval to check for new SRV records. | +| config.permissions.discovery.optional | bool | `true` | optional allows SRV records to be optional. If no SRV records are found or all endpoints are unhealthy, the fallback host is used. | +| config.permissions.discovery.prefer | string | `""` | prefer sets the preferred SRV record. (skips priority, weight and duration ordering) | +| config.permissions.discovery.quick | bool | `false` | quick doesn't wait for discovery and health checks to complete before selecting a host. | | config.permissions.host | string | `""` | host permissions-api host to use. | | config.tracing.enabled | bool | `false` | enabled initializes otel tracing. 
| | config.tracing.insecure | bool | `false` | insecure if TLS should be disabled. | diff --git a/chart/iam-runtime-infratographer/values.yaml b/chart/iam-runtime-infratographer/values.yaml index 32e3cbd6..b7f32f56 100644 --- a/chart/iam-runtime-infratographer/values.yaml +++ b/chart/iam-runtime-infratographer/values.yaml @@ -16,6 +16,37 @@ config: permissions: # -- host permissions-api host to use. host: "" + + discovery: + # -- disable SRV discovery. + disable: false + # -- interval to check for new SRV records. + interval: 15m + # -- quick doesn't wait for discovery and health checks to complete before selecting a host. + quick: false + # -- optional allows SRV records to be optional. + # If no SRV records are found or all endpoints are unhealthy, the fallback host is used. + optional: true + # -- prefer sets the preferred SRV record. (skips priority, weight and duration ordering) + prefer: "" + # -- fallback sets the fallback address if no hosts are found or all hosts are unhealthy. + # The default fallback host is the permissions.host value. + fallback: "" + check: + # -- scheme sets the uri scheme. Default is http unless discovered port is 443 in which https will be used. + scheme: "" + # -- path is the uri path to fetch to check if host is healthy. + path: /readyz + # -- count is the number of checks to run on each host to check for connection latency. + count: 5 + # -- interval is how frequent to check for healthiness on hosts. + interval: 1m + # -- delay is the delay between requests for a host. + delay: 200ms + # -- timeout sets the maximum amount of time a request can wait before canceling the request. + timeout: 2s + # -- concurrency is the number of hosts to concurrently check. + concurrency: 5 events: # -- enabled enables NATS event-based functions. 
enabled: false diff --git a/config.example.yaml b/config.example.yaml index 4b90b99b..34527d04 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -2,6 +2,21 @@ server: socketpath: /tmp/runtime.sock permissions: host: permissions-api.enterprise.dev + discovery: + disable: false + interval: 15m + quick: false + optional: true + prefer: "" + fallback: "" + check: + scheme: "" + path: /readyz + count: 5 + interval: 1m + delay: 200ms + timeout: 2s + concurrency: 5 jwt: jwksuri: https://identity-api.enterprise.dev/jwks.json issuer: https://identity-api.enterprise.dev/ diff --git a/go.mod b/go.mod index d96dcb74..fd9c8ec6 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/MicahParks/keyfunc/v3 v3.3.3 github.com/go-jose/go-jose/v4 v4.0.4 github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/hashicorp/go-retryablehttp v0.7.7 github.com/labstack/echo/v4 v4.12.0 github.com/metal-toolbox/iam-runtime v0.4.1 github.com/spf13/cobra v1.8.1 @@ -40,6 +41,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jaevor/go-nanoid v1.4.0 // indirect diff --git a/go.sum b/go.sum index 5800ce44..d906e137 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -36,6 +38,12 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/internal/permissions/client.go b/internal/permissions/client.go index 03063769..e1f7f419 100644 --- a/internal/permissions/client.go +++ 
b/internal/permissions/client.go @@ -11,7 +11,8 @@ import ( "net/url" "time" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "github.com/hashicorp/go-retryablehttp" + "go.infratographer.com/iam-runtime-infratographer/internal/selecthost" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -54,7 +55,7 @@ type Client interface { type client struct { apiURL string - httpClient *http.Client + httpClient *retryablehttp.Client tracer trace.Tracer logger *zap.SugaredLogger } @@ -67,17 +68,25 @@ func NewClient(config Config, logger *zap.SugaredLogger) (Client, error) { return nil, err } - tracer := otel.GetTracerProvider().Tracer(tracerName) + transport, err := config.initTransport(http.DefaultTransport, selecthost.Logger(logger)) + if err != nil { + return nil, err + } + + httpClient := retryablehttp.NewClient() - httpClient := &http.Client{ + httpClient.RetryWaitMin = 100 * time.Millisecond + httpClient.RetryWaitMax = 2 * time.Second + httpClient.Logger = &retryableLogger{logger} + httpClient.HTTPClient = &http.Client{ Timeout: clientTimeout, - Transport: otelhttp.NewTransport(http.DefaultTransport), + Transport: transport, } out := &client{ apiURL: apiURLString, httpClient: httpClient, - tracer: tracer, + tracer: otel.GetTracerProvider().Tracer(tracerName), logger: logger, } @@ -118,7 +127,7 @@ func (c *client) CheckAccess(ctx context.Context, subjToken string, actions []Re } // Build the request to send up to permissions-api. 
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL, &reqBody) + req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, c.apiURL, &reqBody) if err != nil { span.SetStatus(codes.Error, err.Error()) c.logger.Errorw("failed to create permissions-api request", "error", err) diff --git a/internal/permissions/config.go b/internal/permissions/config.go index 0d2ebb8e..c6bf5927 100644 --- a/internal/permissions/config.go +++ b/internal/permissions/config.go @@ -1,13 +1,171 @@ package permissions import ( + "net/http" + "time" + "github.com/spf13/pflag" + "go.infratographer.com/iam-runtime-infratographer/internal/selecthost" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) // Config represents a permissions-api client configuration. type Config struct { // Host represents a permissions-api host to hit. Host string + + // Discovery defines the host discovery configuration. + Discovery DiscoveryConfig +} + +func (c Config) initTransport(base http.RoundTripper, opts ...selecthost.Option) (http.RoundTripper, error) { + base = otelhttp.NewTransport(base) + + if c.Discovery.Disable { + return base, nil + } + + cOpts := []selecthost.Option{ + selecthost.Fallback(c.Host), + } + + discovery := c.Discovery + + if discovery.Interval > 0 { + cOpts = append(cOpts, selecthost.DiscoveryInterval(discovery.Interval)) + } + + if discovery.Quick != nil && *discovery.Quick { + cOpts = append(cOpts, selecthost.Quick()) + } + + if discovery.Optional == nil || *discovery.Optional { + cOpts = append(cOpts, selecthost.Optional()) + } + + if discovery.Prefer != "" { + cOpts = append(cOpts, selecthost.Prefer(discovery.Prefer)) + } + + if discovery.Fallback != "" { + cOpts = append(cOpts, selecthost.Fallback(discovery.Fallback)) + } + + check := discovery.Check + + if check.Scheme != "" { + cOpts = append(cOpts, selecthost.CheckScheme(check.Scheme)) + } + + if check.Path != "" { + cOpts = append(cOpts, selecthost.CheckPath(check.Path)) + } 
else { + cOpts = append(cOpts, selecthost.CheckPath("/readyz")) + } + + if check.Count > 0 { + cOpts = append(cOpts, selecthost.CheckCount(check.Count)) + } + + if check.Interval > 0 { + cOpts = append(cOpts, selecthost.CheckInterval(check.Interval)) + } + + if check.Delay > 0 { + cOpts = append(cOpts, selecthost.CheckDelay(check.Delay)) + } + + if check.Timeout > 0 { + cOpts = append(cOpts, selecthost.CheckTimeout(check.Timeout)) + } + + if check.Concurrency > 0 { + cOpts = append(cOpts, selecthost.CheckConcurrency(check.Concurrency)) + } + + selector, err := selecthost.NewSelector(c.Host, "permissions-api", "tcp", append(cOpts, opts...)...) + if err != nil { + return nil, err + } + + selector.Start() + + return selecthost.NewTransport(selector, base), nil +} + +// DiscoveryConfig represents the host discovery configuration. +type DiscoveryConfig struct { + // Disable disables host discovery. + // + // Default: false + Disable bool + + // Interval sets the frequency at which SRV records are rediscovered. + // + // Default: 15m + Interval time.Duration + + // Quick ensures a quick startup, allowing for a more optimal host to be chosen after discovery has occurred. + // When Quick is enabled, the default fallback address or default host is immediately returned. + // Once the discovery process has completed, a discovered host will be selected. + // + // Default: false + Quick *bool + + // Optional uses the fallback address or default host without throwing errors. + // The discovery process continues to run in the background, in the chance that SRV records are added at a later point. + // + // Default: true + Optional *bool + + // Check customizes the target health checking process. + Check CheckConfig + + // Prefer specifies a preferred host. + // If the host is not discovered or has an error, it will not be used. + Prefer string + + // Fallback specifies a fallback host if no hosts are discovered or all hosts are currently failing. 
+ // + // Default: [Config] Host + Fallback string +} + +// CheckConfig defines the configuration for host checks. +type CheckConfig struct { + // Scheme sets the check URI scheme. + // Default is http unless discovered host port is 443 in which scheme is th en https + Scheme string + + // Path sets the request path for checks. + // + // Default: /readyz + Path string + + // Count defines the number of checks to run on each endpoint. + // + // Default: 5 + Count int + + // Interval specifies how frequently to run checks. + // + // Default: 1m + Interval time.Duration + + // Delay specifies how long to wait between subsequent checks for the same host. + // + // Default: 200ms + Delay time.Duration + + // Timeout defines the maximum time an individual check request can take. + // + // Default: 2s + Timeout time.Duration + + // Concurrency defines the number of hosts which may be checked simultaneously. + // + // Default: 5 + Concurrency int } // AddFlags sets the command line flags for the permissions-api client. diff --git a/internal/permissions/logger.go b/internal/permissions/logger.go new file mode 100644 index 00000000..a4d433e1 --- /dev/null +++ b/internal/permissions/logger.go @@ -0,0 +1,32 @@ +package permissions + +import ( + "github.com/hashicorp/go-retryablehttp" + "go.uber.org/zap" +) + +var _ retryablehttp.LeveledLogger = (*retryableLogger)(nil) + +type retryableLogger struct { + logger *zap.SugaredLogger +} + +// Error implements retryablehttp.LeveledLogger +func (l *retryableLogger) Error(msg string, keysAndValues ...any) { + l.logger.Errorw(msg, keysAndValues...) +} + +// Info implements retryablehttp.LeveledLogger +func (l *retryableLogger) Info(msg string, keysAndValues ...any) { + l.logger.Infow(msg, keysAndValues...) +} + +// Debug implements retryablehttp.LeveledLogger +func (l *retryableLogger) Debug(msg string, keysAndValues ...any) { + l.logger.Debugw(msg, keysAndValues...) 
+} + +// Warn implements retryablehttp.LeveledLogger +func (l *retryableLogger) Warn(msg string, keysAndValues ...any) { + l.logger.Warnw(msg, keysAndValues...) +} diff --git a/internal/selecthost/doc.go b/internal/selecthost/doc.go new file mode 100644 index 00000000..518cdebb --- /dev/null +++ b/internal/selecthost/doc.go @@ -0,0 +1,5 @@ +// Package selecthost handles host discovery via DNS SRV records, keeps track of healthy +// and selects the most optimal host for use. +// +// An HTTP [Transport] is provided which simplifies using this package with any http client. +package selecthost diff --git a/internal/selecthost/helpers.go b/internal/selecthost/helpers.go new file mode 100644 index 00000000..deafdfbf --- /dev/null +++ b/internal/selecthost/helpers.go @@ -0,0 +1,99 @@ +package selecthost + +import ( + "fmt" + "net" + "strconv" + "strings" + "time" +) + +var ( + // ErrHostRemoved is set on a host if it no longer exists in the discovered hosts. + ErrHostRemoved = fmt.Errorf("%w: host removed", ErrSelectHost) +) + +// JoinHostPort combines host and port into a network address of the form "host:port". +// If host contains a colon, as found in literal IPv6 addresses, then JoinHostPort returns "[host]:port". +// If port is not defined, port is left out. +func JoinHostPort(host string, port string) string { + if strings.ContainsRune(host, ':') { + host = "[" + host + "]" + } + + if port != "" { + host += ":" + port + } + + return host +} + +func hostID(host string, port, priority, weight uint16) string { + parts := []string{ + strings.TrimRight(host, "."), + strconv.FormatUint(uint64(port), 10), + strconv.FormatUint(uint64(priority), 10), + strconv.FormatUint(uint64(weight), 10), + } + + return strings.Join(parts, ":") +} + +// ParseHost parses the provided host allowing for a host without a port and returns a new [Host]. 
+func ParseHost(selector *Selector, host string) (Host, error) { + var port string + + shost, sport, err := net.SplitHostPort(host) + if err != nil { + if addrErr, ok := err.(*net.AddrError); !ok || addrErr.Err != "missing port in address" { + return nil, err + } + } else { + host = shost + port = sport + } + + return newHost(selector, host, port, net.SRV{}), nil +} + +func diffHosts(s *Selector, srvs []*net.SRV) ([]Host, []Host, Hosts) { + var ( + trackedHosts = make(map[string]Host, len(s.hosts)) + srvTargets = make(map[string]*net.SRV, len(srvs)) + + addedHosts = make([]Host, 0) + matchedHosts = make(Hosts, len(srvs)) + removedHosts = make([]Host, 0) + ) + + for _, host := range s.hosts { + trackedHosts[host.ID()] = host + } + + for i, srv := range srvs { + srvKey := hostID(srv.Target, srv.Port, srv.Priority, srv.Weight) + + srvTargets[srvKey] = srv + + host := trackedHosts[srvKey] + if host != nil { + matchedHosts[i] = host + } else { + matchedHosts[i] = newHost(s, srv.Target, strconv.FormatUint(uint64(srv.Port), 10), *srv) + addedHosts = append(addedHosts, matchedHosts[i]) + } + } + + for _, host := range s.hosts { + if _, ok := srvTargets[host.ID()]; !ok { + host.setError(ErrHostRemoved) + removedHosts = append(removedHosts, host) + } + } + + return addedHosts, removedHosts, matchedHosts +} + +func toMilliseconds(duration time.Duration) float64 { + return float64(duration) / float64(time.Millisecond) +} diff --git a/internal/selecthost/helpers_test.go b/internal/selecthost/helpers_test.go new file mode 100644 index 00000000..ab0c8b06 --- /dev/null +++ b/internal/selecthost/helpers_test.go @@ -0,0 +1,185 @@ +package selecthost + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJoinHostPort(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + host string + port string + expect string + }{ + {"dns host only", "example.com", "", "example.com"}, + 
{"dns with port", "example.com", "8080", "example.com:8080"}, + {"ipv4 host only", "1.2.3.4", "", "1.2.3.4"}, + {"ipv4 with port", "1.2.3.4", "8080", "1.2.3.4:8080"}, + {"ipv6 host only", "::1", "", "[::1]"}, + {"ipv6 withi port", "::1", "8080", "[::1]:8080"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := JoinHostPort(tc.host, tc.port) + + assert.Equal(t, tc.expect, result, "unexpected joined host and port") + }) + } +} + +func TestHostID(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + host string + port uint16 + expect string + }{ + {"host only", "example.com", 0, "example.com:0:0:0"}, + {"with port", "example.com", 8080, "example.com:8080:0:0"}, + {"with trailing period", "example.com.", 8080, "example.com:8080:0:0"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := hostID(tc.host, tc.port, 0, 0) + + assert.Equal(t, tc.expect, result, "unexpected host id") + }) + } +} + +func TestParseHost(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + host string + expectHost string + expectPort string + expectError string + }{ + {"dns host only", "example.com", "example.com", "", ""}, + {"dns with port", "example.com:8080", "example.com", "8080", ""}, + {"invalid", "[1:2", "", "", "missing ']' in address"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + host, err := ParseHost(nil, tc.host) + + if tc.expectError != "" { + require.Error(t, err, "error expected") + + assert.ErrorContains(t, err, tc.expectError, "unexpected error returned") + + return + } + + require.NoError(t, err, "no error expected") + require.NotNil(t, host, "host expected to be returned") + + assert.Equal(t, tc.expectHost, host.Host(), "unexpected host") + assert.Equal(t, tc.expectPort, host.Port(), "unexpected host port") + }) + } +} + +// nolint: govet +func TestDiffHosts(t *testing.T) { + 
t.Parallel() + + testCases := []struct { + name string + startingHosts Hosts + srvs []*net.SRV + expectAdded []string + expectRemoved []string + expectMatched []string + }{ + { + "empty", + nil, + nil, + []string{}, []string{}, []string{}, + }, + { + "all added", + nil, + []*net.SRV{ + {"one.example.com", 0, 10, 20}, + {"two.example.com", 80, 10, 20}, + }, + []string{"one.example.com:0:10:20", "two.example.com:80:10:20"}, + []string{}, + []string{"one.example.com:0:10:20", "two.example.com:80:10:20"}, + }, + { + "all removed", + Hosts{ + &host{id: "one.example.com:0:10:20"}, + &host{id: "two.example.com:80:10:20"}, + }, + []*net.SRV{}, + []string{}, + []string{"one.example.com:0:10:20", "two.example.com:80:10:20"}, + []string{}, + }, + { + "some removed, some added", + Hosts{ + &host{id: "one.example.com:0:10:20"}, + &host{id: "two.example.com:80:10:20"}, + }, + []*net.SRV{ + {"one.example.com", 0, 10, 20}, + {"three.example.com", 8080, 10, 20}, + }, + []string{"three.example.com:8080:10:20"}, + []string{"two.example.com:80:10:20"}, + []string{"one.example.com:0:10:20", "three.example.com:8080:10:20"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s := &Selector{hosts: tc.startingHosts} + + added, removed, matched := diffHosts(s, tc.srvs) + + addedIDs := getHostIDs(added) + removedIDs := getHostIDs(removed) + matchedIDs := getHostIDs(matched) + + assert.Equal(t, tc.expectAdded, addedIDs, "unexpected added hosts") + assert.Equal(t, tc.expectRemoved, removedIDs, "unexpected removed hosts") + assert.Equal(t, tc.expectMatched, matchedIDs, "unexpected matched hosts") + }) + } +} + +func getHostIDs(hosts []Host) []string { + ids := make([]string, len(hosts)) + + for i, host := range hosts { + ids[i] = host.ID() + } + + return ids +} diff --git a/internal/selecthost/host.go b/internal/selecthost/host.go new file mode 100644 index 00000000..432221fa --- /dev/null +++ b/internal/selecthost/host.go @@ -0,0 +1,345 @@ 
+package selecthost + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-cleanhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +var ( + // ErrUnexpectedStatusCode is returned when a check request doesn't have 2xx status code. + ErrUnexpectedStatusCode = fmt.Errorf("%w: unexpected status code", ErrSelectHost) +) + +// httpClient sets the client timeout to 10 seconds. +var httpClient = &http.Client{ + Transport: otelhttp.NewTransport(cleanhttp.DefaultPooledTransport()), + Timeout: 10 * time.Second, +} + +// Host is an individual host entry. +type Host interface { + ID() string + Record() net.SRV + Host() string + Port() string + + Before(h2 Host) bool + + check(ctx context.Context) + LastCheck() Results + AverageDuration() time.Duration + Err() error + setError(err error) +} + +// Hosts is a collection of [Host]s. +type Hosts []Host + +// Len implement sort.Interface. +func (h Hosts) Len() int { return len(h) } + +// Less implements sort.Interface. +func (h Hosts) Less(i, j int) bool { + return h[i].Before(h[j]) +} + +// Swap implement sort.Interface. +func (h Hosts) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func newHost(selector *Selector, target string, port string, srv net.SRV) Host { + target = strings.TrimRight(target, ".") + iport, _ := strconv.ParseUint(port, 10, 16) // no error check, if it's an invalid port we'll use the default 0 + + return &host{ + selector: selector, + id: hostID(target, uint16(iport), srv.Priority, srv.Weight), + record: srv, + host: target, + port: port, + } +} + +type host struct { + selector *Selector + + mu sync.RWMutex + + id string + record net.SRV + host string + port string + + lastCheck Results + err error +} + +// ID returns the host ID. 
+func (h *host) ID() string { + return h.id +} + +// Record returns the SRV records. +func (h *host) Record() net.SRV { + return h.record +} + +// Host returns the host. +func (h *host) Host() string { + return h.host +} + +// Port returns the port. +func (h *host) Port() string { + return h.port +} + +// LastCheck returns the latest check results. +func (h *host) LastCheck() Results { + h.mu.RLock() + defer h.mu.RUnlock() + + return h.lastCheck +} + +// AverageDuration returns the latest check's average duration. +func (h *host) AverageDuration() time.Duration { + return h.LastCheck().Average() +} + +// Err returns the error recorded on the host. +func (h *host) Err() error { + h.mu.RLock() + defer h.mu.RUnlock() + + return h.err +} + +func (h *host) setError(err error) { + h.mu.Lock() + defer h.mu.Unlock() + + h.err = err +} + +// Before compares if the left host is before the provided host. +func (h *host) Before(h2 Host) bool { + switch { + case h.Err() == nil && h2.Err() != nil: + return true + case h.Err() != nil && h2.Err() == nil: + return false + case h.Record().Priority < h2.Record().Priority: + return true + case h.Record().Priority > h2.Record().Priority: + return false + case h.Record().Weight < h2.Record().Weight: + return true + case h.Record().Weight > h2.Record().Weight: + return false + case h.AverageDuration() < h2.AverageDuration(): + return true + } + + return false +} + +func (h *host) check(ctx context.Context) { + if ctx == nil { + ctx = context.Background() + } + + ctx, span := tracer.Start(ctx, "host.check", trace.WithAttributes( + attribute.String("check.host", h.id), + attribute.Int("check.total", h.selector.checkCount), + attribute.Float64("check.delay_ms", float64(h.selector.checkDelay)/float64(time.Millisecond)), + )) + + defer span.End() + + results := Results{ + Time: time.Now(), + Host: h, + } + + cancel := func() {} + + logger := h.selector.logger.With( + "check.total", h.selector.checkCount, + "check.host", h.id, + ) + + 
logger.Debugf("Starting host checks for '%s'", h.id) + +checkLoop: + for i := 0; i < h.selector.checkCount; i++ { + if i != 0 { + cancel() + + select { + case <-ctx.Done(): + break checkLoop + case <-h.selector.runCh: + break checkLoop + default: + } + + time.Sleep(h.selector.checkDelay) + } + + cctx, ccancel := context.WithTimeoutCause(ctx, h.selector.checkTimeout, ErrHostCheckTimedout) + + cancel = ccancel + + // If the manager is stopped, cancel the context, also exit if the context gets canceled. + go func() { + select { + case <-cctx.Done(): + case <-h.selector.runCh: + ccancel() + } + }() + + duration, err := h.run(cctx, logger.With("check.run", i)) + if err != nil { + results.Errors = append(results.Errors, err) + } + + results.Checks++ + results.TotalDuration += duration + } + + cancel() + + span.SetAttributes(attribute.Float64("check.average_ms", toMilliseconds(results.Average()))) + + logger = logger.With("check.average_ms", toMilliseconds(results.Average())) + + if len(results.Errors) != 0 { + errs := errors.Join(results.Errors...) + span.RecordError(errs) + span.SetStatus(codes.Error, errs.Error()) + + logger.Errorw("Host checks completed with errors", "errors", results.Errors) + } else { + logger.Debugf("Host checks completed without errors") + } + + h.mu.Lock() + defer h.mu.Unlock() + + h.lastCheck = results + + if len(results.Errors) != 0 { + h.err = results.Errors[len(results.Errors)-1] + } else { + h.err = nil + } +} + +func (h *host) buildURI() *url.URL { + scheme := "http" + + if h.port == "443" { + scheme = "https" + } + + if h.selector.checkScheme != "" { + scheme = h.selector.checkScheme + } + + port := h.port + + // Strip port for default scheme ports. 
+ if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") { + port = "" + } + + return &url.URL{ + Scheme: scheme, + Host: JoinHostPort(h.host, port), + Path: h.selector.checkPath, + } +} + +func (h *host) run(ctx context.Context, logger *zap.SugaredLogger) (time.Duration, error) { + uri := h.buildURI() + + logger = logger.With( + "check.uri", uri.String(), + ) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil) + if err != nil { + logger.Errorw("Failed to create check request", "error", err) + + return 0, err + } + + start := time.Now() + + resp, err := httpClient.Do(req) + if err != nil { + duration := time.Since(start) + + logger.Errorw("Failed to execute check request", "error", err, "check.duration_ms", toMilliseconds(duration)) + + return duration, err + } + + defer resp.Body.Close() + + // Consume body so connection can be reused. + // If an error occurs reading the body, ignore. + body, _ := io.ReadAll(resp.Body) + + duration := time.Since(start) + + logger = logger.With( + "check.duration_ms", toMilliseconds(duration), + "check.response.status_code", resp.StatusCode, + ) + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + logger.Errorw("Check completed with an unexpected status code", "error", err) + + return duration, fmt.Errorf("%w: %d: %s", ErrUnexpectedStatusCode, resp.StatusCode, string(body)) + } + + logger.Debug("Check completed successfully") + + return duration, nil +} + +// Results holds host check results. +type Results struct { + Time time.Time + Host Host + Checks uint + TotalDuration time.Duration + Errors []error +} + +// Average returns the average duration of all checks run on host. 
+func (r Results) Average() time.Duration { + if r.Checks == 0 { + return 0 + } + + return r.TotalDuration / time.Duration(r.Checks) +} diff --git a/internal/selecthost/host_test.go b/internal/selecthost/host_test.go new file mode 100644 index 00000000..838c6e8f --- /dev/null +++ b/internal/selecthost/host_test.go @@ -0,0 +1,402 @@ +package selecthost + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +// nolint: govet +func TestNewHost(t *testing.T) { + t.Parallel() + + record := net.SRV{"srv target", 10, 20, 30} + + hostI := newHost(&Selector{}, "target", "port", record) + + require.NotNil(t, hostI, "expected host") + + host, ok := hostI.(*host) + + require.True(t, ok, "expected new host to be of *host type") + assert.Equal(t, "target:0:20:30", host.id, "unexpected id") + assert.Equal(t, "target", host.host, "unexpected host") + assert.Equal(t, "port", host.port, "unexpected port") + assert.Equal(t, record, host.record, "unexpected record") +} + +func TestHostBefore(t *testing.T) { + t.Parallel() + + hostWithoutErr := testHost(nil, "zero", "80", 10, 10, nil, 10) + hostLowWeight := testHost(nil, "one", "80", 10, 10, nil, 10) + hostHighWeight := testHost(nil, "two", "80", 20, 10, nil, 10) + hostLowPrio := testHost(nil, "three", "80", 10, 10, nil, 10) + hostHighPrio := testHost(nil, "four", "80", 10, 20, nil, 10) + hostLowAvg := testHost(nil, "five", "80", 10, 10, nil, 10) + hostHighAvg := testHost(nil, "six", "80", 10, 10, nil, 20) + + hostWithErr := testHost(nil, "zero-err", "80", 10, 10, net.ErrClosed, 10) + + testCases := []struct { + name string + left Host + right Host + expectBefore bool + }{ + {"same", hostLowWeight, hostLowWeight, false}, + {"error not before", hostWithErr, hostWithoutErr, false}, + {"no error before error", hostWithoutErr, hostWithErr, true}, + {"both error not 
before", hostWithErr, hostWithErr, false}, + {"low priority first", hostLowPrio, hostHighPrio, true}, + {"high priority not before", hostHighPrio, hostLowPrio, false}, + {"equal priority not before", hostLowPrio, hostLowPrio, false}, + {"low weight first", hostLowWeight, hostHighWeight, true}, + {"high weight not before", hostHighWeight, hostLowWeight, false}, + {"equal weight not before", hostLowWeight, hostLowWeight, false}, + {"low avg first", hostLowAvg, hostHighAvg, true}, + {"high avg not before", hostHighAvg, hostLowAvg, false}, + {"equal avg not before", hostLowAvg, hostLowAvg, false}, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := tc.left.Before(tc.right) + + assert.Equal(t, tc.expectBefore, result, "unexpected before result") + }) + } +} + +func TestHostBuildURI(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + host *host + expect string + }{ + { + "default http", + &host{ + selector: &Selector{}, + host: "1.2.3.4", + }, + "http://1.2.3.4", + }, + { + "default http default port", + &host{ + selector: &Selector{}, + host: "1.2.3.4", + port: "80", + }, + "http://1.2.3.4", + }, + { + "default http non standard port", + &host{ + selector: &Selector{}, + host: "1.2.3.4", + port: "8080", + }, + "http://1.2.3.4:8080", + }, + { + "default https default port", + &host{ + selector: &Selector{}, + host: "1.2.3.4", + port: "443", + }, + "https://1.2.3.4", + }, + { + "scheme http no port", + &host{ + selector: &Selector{checkScheme: "http"}, + host: "1.2.3.4", + port: "", + }, + "http://1.2.3.4", + }, + { + "scheme http with port", + &host{ + selector: &Selector{checkScheme: "http"}, + host: "1.2.3.4", + port: "80", + }, + "http://1.2.3.4", + }, + { + "scheme http alt port", + &host{ + selector: &Selector{checkScheme: "http"}, + host: "1.2.3.4", + port: "8080", + }, + "http://1.2.3.4:8080", + }, + { + "scheme https no port", + &host{ + selector: &Selector{checkScheme: 
"https"}, + host: "1.2.3.4", + }, + "https://1.2.3.4", + }, + { + "scheme https with port", + &host{ + selector: &Selector{checkScheme: "https"}, + host: "1.2.3.4", + port: "443", + }, + "https://1.2.3.4", + }, + { + "scheme https alt port", + &host{ + selector: &Selector{checkScheme: "https"}, + host: "1.2.3.4", + port: "8443", + }, + "https://1.2.3.4:8443", + }, + { + "with path", + &host{ + selector: &Selector{checkPath: "some/endpoint"}, + host: "1.2.3.4", + }, + "http://1.2.3.4/some/endpoint", + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := tc.host.buildURI() + + assert.Equal(t, tc.expect, result.String(), "unexpected URI built") + }) + } +} + +func TestHostCheck(t *testing.T) { + t.Parallel() + + checkDelay := time.Millisecond + + testCases := []struct { + name string + responseDelay time.Duration + cancelDelay time.Duration + expectTotalDuration time.Duration + expectErrors int + }{ + { + "success", + 0, + 0, + 5 * checkDelay, + 0, + }, + { + "canceled", + 10 * time.Millisecond, + 20 * time.Millisecond, + 20 * time.Millisecond, + 1, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(tc.responseDelay) + })) + + defer srv.Close() + + h, p, err := net.SplitHostPort(srv.Listener.Addr().String()) + require.NoError(t, err, "no error expected splitting test server address") + + host := &host{ + selector: &Selector{ + logger: zap.NewNop().Sugar(), + checkDelay: checkDelay, + checkCount: 5, + checkTimeout: 100 * time.Millisecond, + runCh: make(chan struct{}), + }, + host: h, + port: p, + } + + ctx, cancel := context.WithCancel(context.Background()) + + defer cancel() + + if tc.cancelDelay != 0 { + go func() { + time.Sleep(tc.cancelDelay) + + cancel() + }() + } + + start := time.Now() + + host.check(ctx) + + totalDuration := 
time.Since(start) + + if tc.expectErrors != 0 { + require.Error(t, host.err, "expected error") + + assert.Equalf(t, tc.expectErrors, len(host.lastCheck.Errors), "expected %d errors, got %d | errors: %s", tc.expectErrors, len(host.lastCheck.Errors), errors.Join(host.lastCheck.Errors...)) + } else { + require.NoError(t, host.err, "no error expected") + } + + diff := totalDuration - tc.expectTotalDuration + assert.Truef(t, diff > 0 && diff < 10*time.Millisecond, "total duration unexpected. got: %s want: %s", totalDuration, tc.expectTotalDuration) + }) + } +} + +func TestHostRun(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + responseDelay time.Duration + cancelDelay time.Duration + responseStatusCode int + expectError error + }{ + { + "success", + 0, + 0, + http.StatusOK, + nil, + }, + { + "invalid response", + 0, + 0, + http.StatusInternalServerError, + ErrUnexpectedStatusCode, + }, + { + "canceled", + 500 * time.Millisecond, + 100 * time.Millisecond, + 0, + context.Canceled, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(tc.responseDelay) + + if tc.responseStatusCode != 0 { + w.WriteHeader(tc.responseStatusCode) + } + })) + + defer srv.Close() + + h, p, err := net.SplitHostPort(srv.Listener.Addr().String()) + require.NoError(t, err, "no error expected splitting test server address") + + host := &host{ + selector: &Selector{}, + host: h, + port: p, + } + + ctx, cancel := context.WithCancel(context.Background()) + + defer cancel() + + if tc.cancelDelay != 0 { + go func() { + time.Sleep(tc.cancelDelay) + + cancel() + }() + } + + duration, err := host.run(ctx, zap.NewNop().Sugar()) + + if tc.expectError != nil { + require.Error(t, err, "expected error") + + assert.ErrorIs(t, err, tc.expectError, "unexpected error returned") + + return + } + + require.NoError(t, err, "no error 
expected") + + if tc.responseDelay != 0 { + diff := duration - tc.responseDelay + assert.True(t, diff > 0 && diff < 10*time.Millisecond, "expected duration to be near response delay") + } + }) + } +} + +func testHost(selector *Selector, target, port string, weight, priority uint16, err error, average time.Duration) Host { + uport, _ := strconv.ParseUint(port, 10, 16) + + record := net.SRV{ + Target: target, + Port: uint16(uport), + Weight: weight, + Priority: priority, + } + + hostI := newHost(selector, target, port, record) + + host := hostI.(*host) + + host.err = err + + host.lastCheck = Results{ + Checks: 1, + TotalDuration: average, + } + + return host +} diff --git a/internal/selecthost/http.go b/internal/selecthost/http.go new file mode 100644 index 00000000..002c8da9 --- /dev/null +++ b/internal/selecthost/http.go @@ -0,0 +1,78 @@ +package selecthost + +import ( + "fmt" + "net" + "net/http" +) + +var _ http.RoundTripper = (*Transport)(nil) + +// Transport implements http.RoundTripper handles switching the request host to a discovered host. +type Transport struct { + Selector *Selector + Base http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. +// If the request host matches the selector service's target, the host is replaced with the selected host address. +// If the request does not match, the base transport is called for the request instead. +// +// When the selected host is used, if the result from the base transport returns an error, +// the selected host is marked as having that error and a new host is immediately selected. +// The request however is not retried, instead the requestor must retry when appropriate. +// +// Hosts marked with an error will get cleared upon the next successful host check cycle. 
+func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) { + basert := t.Base + if basert == nil { + basert = http.DefaultTransport + } + + host, err := t.Selector.GetHost(r.Context()) + if err != nil { + return nil, err + } + + if r.URL.Host != t.Selector.Target() { + return basert.RoundTrip(r) + } + + port := host.Port() + + // If no port was defined on the selected host, use the same port as the request, if one exists. + if port == "" { + _, port, _ = net.SplitHostPort(r.URL.Host) + } + + // Remove default ports from host + if (r.URL.Scheme == "http" && port == "80") || (r.URL.Scheme == "https" && port == "443") { + port = "" + } + + addr := JoinHostPort(host.Host(), port) + + r = r.Clone(r.Context()) + + r.URL.Host = addr + r.Host = addr + + resp, err := basert.RoundTrip(r) + if err != nil { + host.setError(err) + t.Selector.selectHost(r.Context()) + + return resp, fmt.Errorf("selected host: '%s': %w", addr, err) + } + + return resp, nil +} + +// NewTransport initialized a new Transport with the provided selector and base transport. +// If base is nil, the default http transport is used. 
+func NewTransport(selector *Selector, base http.RoundTripper) http.RoundTripper { + return &Transport{ + Selector: selector, + Base: base, + } +} diff --git a/internal/selecthost/http_test.go b/internal/selecthost/http_test.go new file mode 100644 index 00000000..7476e1eb --- /dev/null +++ b/internal/selecthost/http_test.go @@ -0,0 +1,124 @@ +package selecthost + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +var errTestBase = errors.New("base error") + +type testTransport struct { + req *http.Request + err error +} + +func (t *testTransport) RoundTrip(r *http.Request) (*http.Response, error) { + t.req = r + + resp := &http.Response{ + Request: r, + Body: io.NopCloser(&bytes.Buffer{}), + } + + return resp, t.err +} + +func TestTransportRoundTrip(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + requestURL string + baseError bool + expectRequestURL string + expectError bool + expectSelected string + }{ + { + "success", + "http://host.example.com/test-path", + false, + "http://host1.example.com/test-path", + false, + "host1.example.com", + }, + { + "failed", + "http://host.example.com/test-path", + true, + "", + true, + "fallback.example.com", + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, tc.requestURL, nil) + require.NoError(t, err, "no error expected creating request") + + selector := &Selector{ + logger: zap.NewNop().Sugar(), + target: "host.example.com", + runCh: make(chan struct{}), + startWait: make(chan struct{}), + } + + selector.startOnce.Do(func() {}) + + defer close(selector.runCh) + + selected := &host{ + selector: selector, + host: "host1.example.com", + } + fallback := &host{ + selector: selector, + host: "fallback.example.com", + } + + 
selector.selected = selected + selector.fallback = fallback + + baseTransport := &testTransport{} + + if tc.baseError { + baseTransport.err = errTestBase + } + + transport := &Transport{ + Selector: selector, + Base: baseTransport, + } + + resp, err := transport.RoundTrip(req) + + if tc.expectError { + require.Error(t, err, "expected error to be returned") + + assert.Equal(t, tc.expectSelected, selector.selected.Host(), "unexpected host selected") + + return + } + + require.NoError(t, err, "no error expected to be returned") + + defer resp.Body.Close() + + assert.Equal(t, tc.expectRequestURL, resp.Request.URL.String(), "unexpected url requested") + assert.Equal(t, tc.expectSelected, selector.selected.Host(), "unexpected host selected") + }) + } +} diff --git a/internal/selecthost/options.go b/internal/selecthost/options.go new file mode 100644 index 00000000..4a0758b9 --- /dev/null +++ b/internal/selecthost/options.go @@ -0,0 +1,149 @@ +package selecthost + +import ( + "time" + + "go.uber.org/zap" +) + +// Option defines a selector option. +type Option func(s *Selector) error + +// Logger sets the logger. +func Logger(logger *zap.SugaredLogger) Option { + return func(s *Selector) error { + if logger != nil { + s.logger = logger + } + + return nil + } +} + +// DiscoveryInterval specifies the interval at which SRV records will be rediscovered. +// Default: 15m +func DiscoveryInterval(interval time.Duration) Option { + return func(s *Selector) error { + s.discoveryInterval = interval + + return nil + } +} + +// Quick will select the fallback address immediately on startup instead of waiting +// for the discovery process to complete. +func Quick() Option { + return func(s *Selector) error { + s.quick = true + + return nil + } +} + +// Optional if no SRV record is found, the target (or fallback address) is used instead. +// The discovery process continues to run in the chance that SRV records are added at a later point. 
+func Optional() Option { + return func(s *Selector) error { + s.optional = true + + return nil + } +} + +// CheckScheme sets the uri scheme. +// Default is http unless discovered port is 443, https will be used then. +func CheckScheme(scheme string) Option { + return func(s *Selector) error { + s.checkScheme = scheme + + return nil + } +} + +// CheckPath sets the request path for checks. +func CheckPath(path string) Option { + return func(s *Selector) error { + s.checkPath = path + + return nil + } +} + +// CheckCount defines how many checks to run on an endpoint. +// Default: 5 +func CheckCount(count int) Option { + return func(s *Selector) error { + s.checkCount = count + + return nil + } +} + +// CheckInterval specifies how frequently to run host checks. +// Default: 1m. +func CheckInterval(interval time.Duration) Option { + return func(s *Selector) error { + s.checkInterval = interval + + return nil + } +} + +// CheckDelay specifies how long to wait between subsequent checks for the same host. +// Default: 200ms +func CheckDelay(delay time.Duration) Option { + return func(s *Selector) error { + s.checkDelay = delay + + return nil + } +} + +// CheckTimeout defines the maximum time an individual check request can take. +// Default: 2s +func CheckTimeout(timeout time.Duration) Option { + return func(s *Selector) error { + s.checkTimeout = timeout + + return nil + } +} + +// CheckConcurrency defines the number of hosts which may be checked simultaneously. +// Default: 5 +func CheckConcurrency(count int) Option { + return func(s *Selector) error { + s.checkConcurrency = count + + return nil + } +} + +// Prefer specifies a preferred host. +// If the host is not discovered or has an error it will not be used. 
+func Prefer(host string) Option { + return func(s *Selector) error { + h, err := ParseHost(s, host) + if err != nil { + return err + } + + s.prefer = h + + return nil + } +} + +// Fallback specifies a fallback host if no hosts are discovered or all hosts are currently failing. +func Fallback(host string) Option { + return func(s *Selector) error { + h, err := ParseHost(s, host) + if err != nil { + return err + } + + s.fallback = h + + return nil + } +} diff --git a/internal/selecthost/selector.go b/internal/selecthost/selector.go new file mode 100644 index 00000000..4abf0ed3 --- /dev/null +++ b/internal/selecthost/selector.go @@ -0,0 +1,501 @@ +package selecthost + +import ( + "context" + "errors" + "fmt" + "net" + "sort" + "strings" + "sync" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +var ( + // ErrSelectHost is the root error for all SelectHost errors + ErrSelectHost = errors.New("SelectHost Error") + // ErrSelectorStopped is returned when waiting for a host but the selector has been stopped. + ErrSelectorStopped = fmt.Errorf("%w: selector stopped", ErrSelectHost) + // ErrHostNotFound is returned when a host is not able to be determined. + ErrHostNotFound = fmt.Errorf("%w: no host found", ErrSelectHost) + // ErrWaitTimedout is returned when waiting for a host to be discovered. + ErrWaitTimedout = fmt.Errorf("%w: timed out waiting for host", ErrSelectHost) + // ErrDiscoveryTimeout is returned when the discovery process takes longer than configured timeout. + ErrDiscoveryTimeout = fmt.Errorf("%w: discovery process timed out: %w", ErrSelectHost, context.DeadlineExceeded) + // ErrHostCheckTimedout is returned when the host check process takes longer than configured timeout. 
+ ErrHostCheckTimedout = fmt.Errorf("%w: host check timed out: %w", ErrSelectHost, context.DeadlineExceeded) + + tracerName = "go.infratographer.com/iam-runtime-infratographer/internal/selecthost" + tracer = otel.GetTracerProvider().Tracer(tracerName) +) + +// Selector handles discovering SRV records, periodically polling and selecting +// the fastest responding endpoint. +type Selector struct { + logger *zap.SugaredLogger + + service string + protocol string + target string + + resolver interface { + LookupSRV(ctx context.Context, service, protocol, target string) (string, []*net.SRV, error) + } + discoveryInterval time.Duration + discoveryTimeout time.Duration + + checkScheme string + checkPath string + checkCount int + checkInterval time.Duration + checkDelay time.Duration + checkTimeout time.Duration + checkConcurrency int + + initTimeout time.Duration + + mu sync.RWMutex + startOnce sync.Once + quick bool + optional bool + + stickyUntil time.Time + selected Host + prefer Host + fallback Host + + hosts Hosts + + checkOnce sync.Once + optionalFailureOnce sync.Once + + runCh chan struct{} + startWait chan struct{} +} + +// Target returns the SRV target. +func (s *Selector) Target() string { + return s.target +} + +func (s *Selector) getHost() Host { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.selected +} + +// GetHost returns the active host. +// If the selector has not been started before, the selector is initialized. +// This method will block until a host is selected, initialization timeout is reached or the context is canceled. 
// GetHost returns the active host.
// If the selector has not been started before, the selector is initialized.
// This method will block until a host is selected, initialization timeout is reached or the context is canceled.
func (s *Selector) GetHost(ctx context.Context) (Host, error) {
	// Lazily initialize discovery; start is guarded by startOnce so
	// concurrent callers only trigger it once.
	s.start(ctx)

	// Fast path: a host has already been selected.
	if host := s.getHost(); host != nil {
		return host, nil
	}

	// Bound the wait for initial discovery/selection by initTimeout.
	ctx, cancel := context.WithTimeout(ctx, s.initTimeout)
	defer cancel()

	select {
	case <-s.runCh:
		// runCh closes when Stop is called.
		return nil, ErrSelectorStopped
	case <-ctx.Done():
		return nil, ErrWaitTimedout
	case <-s.startWait:
		// startWait closes after the first discovery + check cycle; a host
		// may still be unset if no record was usable and no fallback exists.
		host := s.getHost()
		if host == nil {
			return nil, ErrHostNotFound
		}

		return host, nil
	}
}

// start runs one-time initialization: in quick mode it immediately selects
// the fallback host and discovers records in the background; otherwise it
// discovers records synchronously. Either way the periodic discovery loop
// is then started. Safe to call repeatedly; only the first call acts.
func (s *Selector) start(ctx context.Context) {
	s.startOnce.Do(func() {
		ctx, span := tracer.Start(ctx, "selector.start", trace.WithAttributes(
			attribute.String("selector.target", s.target),
			attribute.String("selector.service", s.service),
			attribute.String("selector.protocol", s.protocol),
			attribute.Bool("selector.quick", s.quick),
		))
		defer span.End()

		s.logger.Infof("Initializing host selector for '%s'", s.target)

		if s.quick && s.fallback != nil {
			// Quick mode: serve the fallback right away so callers don't
			// block on the first discovery pass.
			s.mu.Lock()
			s.selected = s.fallback
			s.mu.Unlock()

			span.AddEvent("Quick host selector enabled, selecting fallback address '" + s.fallback.ID() + "'")

			s.logger.Infof("Quick host selector enabled, selected fallback address '%s'", s.fallback.ID())

			// Start a new context keeping the current span so canceled contexts don't propagate.
			ctx = trace.ContextWithSpan(context.Background(), span)

			go s.discoverRecords(ctx)
		} else {
			s.discoverRecords(ctx)
		}

		go s.discovery()
	})
}

// Start initializes the selector discovery and checking handlers.
func (s *Selector) Start() {
	s.start(context.Background())
}

// Stop cleans up the service.
// Closing runCh signals every background loop and waiter to exit.
func (s *Selector) Stop() {
	close(s.runCh)
}
// selectHost updates the selected host based on the latest host list.
// Selection is made in the following order.
//
// 1. Select to the current host if found, without errors and within sticky period.
// 2. Select the preferred host if found and without errors.
// 3. Select the first host without errors.
// 4. Select the fallback host if configured.
// 5. No change to selected host.
//
// If the last step is reached and no host had previously been selected, the selected host is nil.
func (s *Selector) selectHost(ctx context.Context) {
	_, span := tracer.Start(ctx, "selectHost", trace.WithAttributes(
		attribute.Bool("host.changed", false),
	))
	defer span.End()

	s.mu.Lock()
	defer s.mu.Unlock()

	current := s.selected

	var (
		selected Host
		first    Host
		prefer   Host
	)

	// Single pass over the (sorted) host list: record the first healthy
	// host, and whether the currently-selected and preferred hosts are
	// still present and healthy.
	for _, host := range s.hosts {
		if host.Err() != nil {
			continue
		}

		if first == nil {
			first = host
		}

		if selected == nil && s.selected != nil && s.selected.ID() == host.ID() {
			selected = host
		}

		if prefer == nil && s.prefer != nil && s.prefer.ID() == host.ID() {
			prefer = host
		}

		// All three candidates found; no need to scan further.
		if first != nil && selected != nil && prefer != nil {
			break
		}
	}

	sticky := time.Now().Before(s.stickyUntil)

	// Apply the documented selection priority; an empty first case keeps
	// the current host while it is healthy and within the sticky window.
	switch {
	case selected != nil && sticky:
	case prefer != nil:
		selected = prefer
	case first != nil:
		selected = first
	case s.fallback != nil:
		selected = s.fallback
	}

	if current != nil {
		span.SetAttributes(
			attribute.String("host.current.id", current.ID()),
			attribute.Float64("host.current.avg_duration_ms", toMilliseconds(current.AverageDuration())),
			attribute.Bool("host.current.sticky", sticky),
		)

		if err := current.Err(); err != nil {
			span.SetAttributes(
				attribute.String("host.current.error", err.Error()),
			)
		}
	}

	if selected != nil {
		span.SetAttributes(
			attribute.String("host.selected.id", selected.ID()),
			attribute.Float64("host.selected.avg_duration_ms", toMilliseconds(selected.AverageDuration())),
		)

		if err := selected.Err(); err != nil {
			span.SetAttributes(
				attribute.String("host.selected.error", err.Error()),
			)
		}
	}

	if current != selected && selected != nil {
		s.selected = selected

		span.SetAttributes(attribute.Bool("host.changed", true))

		// ensure host doesn't change for 5 check intervals (as long as it's still healthy)
		s.stickyUntil = time.Now().Add(5 * s.checkInterval)

		if current == nil {
			span.AddEvent("selected host: " + selected.ID())

			s.logger.Infow("Host Selected",
				"selected.host", selected.ID(),
				"selected.check_duration_ms", toMilliseconds(selected.AverageDuration()),
			)
		} else {
			span.AddEvent("host changed: " + current.ID() + " -> " + selected.ID())

			s.logger.Warnw("Host Changed",
				"previous.host", current.ID(),
				"previous.check_duration_ms", toMilliseconds(current.AverageDuration()),
				"selected.host", selected.ID(),
				"selected.check_duration_ms", toMilliseconds(selected.AverageDuration()),
			)
		}
	} else if selected != nil && current == selected && !sticky {
		// If host remains the same but stickiness has expired, reset the sticky counter.
		s.stickyUntil = time.Now().Add(5 * s.checkInterval)
	}

	if selected == nil {
		currentID := ""
		if current != nil {
			currentID = current.ID()
		}

		span.SetStatus(codes.Error, "unable to select host")

		s.logger.Errorw("Unable to select host", "selected.host", currentID)
	}
}

// discovery is the background SRV re-discovery loop.
// It re-resolves records every discoveryInterval until Stop closes runCh.
func (s *Selector) discovery() {
	ticker := time.NewTicker(s.discoveryInterval)
	defer ticker.Stop()

	for {
		select {
		case <-s.runCh:
			return
		case <-ticker.C:
		}

		s.discoverRecords(context.Background())
	}
}
s.protocol, + ) + + logger.Debugf("Looking for srv records for service '%s' with protocol '%s' for target '%s'", s.service, s.protocol, target) + + cname, srvs, err := s.resolver.LookupSRV(ctx, s.service, s.protocol, target) + + span.SetAttributes(attribute.String("discover.resolved.cname", cname)) + + duration := time.Since(start) + + logger = logger.With( + "discover.resolved.cname", cname, + "discover.runtime_ms", toMilliseconds(duration), + ) + + if err != nil { + span.RecordError(err) + + if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound && s.optional { + s.optionalFailureOnce.Do(func() { + span.AddEvent("no srv records found, however records are optional") + + logger.Warnw("No SRV records found, using target/fallback.") + }) + } else { + span.SetStatus(codes.Error, "Failed to lookup SRV records: "+err.Error()) + + logger.Errorw("Failed to lookup SRV records", "error", err) + } + } + + s.mu.Lock() + + added, removed, matched := diffHosts(s, srvs) + + s.hosts = matched + + s.mu.Unlock() + + for _, host := range added { + span.AddEvent("discovered " + host.ID()) + + logger.Infof("Discovered host '%s'", host.ID()) + } + + for _, host := range removed { + span.AddEvent("removed " + host.ID()) + + logger.Warnf("Host removed '%s'", host.ID()) + } + + s.checkOnce.Do(func() { + span.AddEvent("initializing host checks") + + s.checkHosts(origCtx) + + close(s.startWait) + + go s.watchHosts() + }) +} + +func (s *Selector) watchHosts() { + ticker := time.NewTicker(s.checkInterval) + defer ticker.Stop() + + for { + select { + case <-s.runCh: + return + case <-ticker.C: + } + + s.checkHosts(context.Background()) + } +} + +func (s *Selector) checkHosts(ctx context.Context) { + ctx, span := tracer.Start(ctx, "checkHosts") + defer span.End() + + s.mu.RLock() + + span.SetAttributes( + attribute.Int("check.hosts.count", len(s.hosts)), + attribute.Int("check.concurrency", s.checkConcurrency), + ) + + if len(s.hosts) == 0 { + s.mu.RUnlock() + + s.selectHost(ctx) + + 
return + } + + hostCh := make(chan Host, len(s.hosts)) + + for _, host := range s.hosts { + hostCh <- host + } + + s.mu.RUnlock() + + close(hostCh) + + var wg sync.WaitGroup + + for i := 0; i < s.checkConcurrency; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for host := range hostCh { + host.check(ctx) + } + }() + } + + wg.Wait() + + s.mu.Lock() + sort.Sort(s.hosts) + s.mu.Unlock() + + s.selectHost(ctx) +} + +// NewSelector creates a new selector service handler. +// The target provided is automatically registered as the default fallback address. +func NewSelector(target, service, protocol string, options ...Option) (*Selector, error) { + sel := &Selector{ + logger: zap.NewNop().Sugar(), + + service: service, + protocol: protocol, + target: target, + + resolver: net.DefaultResolver, + discoveryInterval: 15 * time.Minute, + discoveryTimeout: 2 * time.Second, + + checkCount: 5, + checkInterval: time.Minute, + checkDelay: 200 * time.Millisecond, + checkTimeout: 2 * time.Second, + checkConcurrency: 5, + + initTimeout: 10 * time.Second, + + runCh: make(chan struct{}), + startWait: make(chan struct{}), + } + + if err := Fallback(target)(sel); err != nil { + return nil, err + } + + for _, opt := range options { + if err := opt(sel); err != nil { + return nil, err + } + } + + return sel, nil +} diff --git a/internal/selecthost/selector_test.go b/internal/selecthost/selector_test.go new file mode 100644 index 00000000..409382b7 --- /dev/null +++ b/internal/selecthost/selector_test.go @@ -0,0 +1,540 @@ +package selecthost + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestSelectorGetHost(t *testing.T) { + t.Parallel() + + t.Run("immediate", func(t *testing.T) { + t.Parallel() + + selector := &Selector{ + selected: &host{ + host: "host1.example.com", + }, + } + + 
selector.startOnce.Do(func() {}) + + host, err := selector.GetHost(context.Background()) + + require.NoError(t, err, "no error expected to be returned") + require.NotNil(t, host, "host expected to be returned") + assert.Equal(t, "host1.example.com", host.Host(), "unexpected host returned") + }) + + t.Run("init timeout", func(t *testing.T) { + t.Parallel() + + selector := &Selector{ + initTimeout: 10 * time.Millisecond, + runCh: make(chan struct{}), + startWait: make(chan struct{}), + } + + defer close(selector.runCh) + defer close(selector.startWait) + + selector.startOnce.Do(func() {}) + + host, err := selector.GetHost(context.Background()) + + require.Error(t, err, "error expected to be returned") + require.Nil(t, host, "no host expected to be returned") + + assert.ErrorIs(t, err, ErrWaitTimedout, "unexpected error returned") + }) + + t.Run("selector shutdown", func(t *testing.T) { + t.Parallel() + + selector := &Selector{ + initTimeout: 10 * time.Millisecond, + runCh: make(chan struct{}), + startWait: make(chan struct{}), + } + + defer close(selector.startWait) + + selector.startOnce.Do(func() {}) + + go func() { + time.Sleep(2 * time.Millisecond) + + close(selector.runCh) + }() + + host, err := selector.GetHost(context.Background()) + + require.Error(t, err, "error expected to be returned") + require.Nil(t, host, "no host expected to be returned") + + assert.ErrorIs(t, err, ErrSelectorStopped, "unexpected error returned") + }) + + t.Run("start finish without host", func(t *testing.T) { + t.Parallel() + + selector := &Selector{ + initTimeout: 10 * time.Millisecond, + runCh: make(chan struct{}), + startWait: make(chan struct{}), + } + + defer close(selector.runCh) + + selector.startOnce.Do(func() {}) + + go func() { + time.Sleep(2 * time.Millisecond) + + close(selector.startWait) + }() + + host, err := selector.GetHost(context.Background()) + + require.Error(t, err, "error expected to be returned") + require.Nil(t, host, "no host expected to be returned") + + 
assert.ErrorIs(t, err, ErrHostNotFound, "unexpected error returned") + }) + + t.Run("start finish with host", func(t *testing.T) { + t.Parallel() + + selector := &Selector{ + initTimeout: 10 * time.Millisecond, + runCh: make(chan struct{}), + startWait: make(chan struct{}), + } + + defer close(selector.runCh) + + selector.startOnce.Do(func() {}) + + go func() { + time.Sleep(2 * time.Millisecond) + + selector.mu.Lock() + defer selector.mu.Unlock() + + selector.selected = &host{ + host: "host1.example.com", + } + + close(selector.startWait) + }() + + host, err := selector.GetHost(context.Background()) + + require.NoError(t, err, "no error expected to be returned") + require.NotNil(t, host, "host expected to be returned") + + assert.Equal(t, "host1.example.com", host.Host(), "unexpected host returned") + }) +} + +type testResolver struct { + cname string + records []*net.SRV + err error + + requestedService string + requestedProtocol string + requestedTarget string +} + +func (r *testResolver) LookupSRV(_ context.Context, service, protocol, target string) (string, []*net.SRV, error) { + r.requestedService = service + r.requestedProtocol = protocol + r.requestedTarget = target + + return r.cname, r.records, r.err +} + +func TestSelectorStart(t *testing.T) { + t.Parallel() + + t.Run("quick", func(t *testing.T) { + t.Parallel() + + selector := &Selector{ + logger: zap.NewNop().Sugar(), + quick: true, + fallback: &host{ + host: "fallback.example.com", + }, + resolver: &testResolver{}, + discoveryInterval: time.Second, + runCh: make(chan struct{}), + } + + close(selector.runCh) + + selector.checkOnce.Do(func() {}) + + selector.start(context.Background()) + + host := selector.getHost() + + require.NotNil(t, host, "expected host to not be nil") + assert.Equal(t, "fallback.example.com", host.Host(), "unexpected host returned") + }) + + t.Run("not quick", func(t *testing.T) { + t.Parallel() + + selector := &Selector{ + logger: zap.NewNop().Sugar(), + quick: false, + fallback: 
&host{ + host: "fallback.example.com", + }, + resolver: &testResolver{}, + discoveryInterval: time.Second, + runCh: make(chan struct{}), + } + + close(selector.runCh) + + selector.checkOnce.Do(func() {}) + + selector.start(context.Background()) + + host := selector.getHost() + + require.Nil(t, host, "expected host to be nil") + }) +} + +// nolint: govet +func TestDiscoverRecords(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + target string + records []*net.SRV + expectTarget string + expectHosts []string + }{ + { + "with port", + "iam.example.com:1234", + []*net.SRV{ + {"host1.example.com", 80, 0, 0}, + }, + "iam.example.com", + []string{ + "host1.example.com:80:0:0", + }, + }, + { + "without port", + "iam.example.com", + []*net.SRV{ + {"host1.example.com", 80, 0, 0}, + }, + "iam.example.com", + []string{ + "host1.example.com:80:0:0", + }, + }, + { + "no records", + "iam.example.com", + []*net.SRV{}, + "iam.example.com", + []string{}, + }, + { + "priority changes", + "iam.example.com", + []*net.SRV{ + {"new1.example.com", 80, 10, 10}, + {"old1.example.com", 80, 20, 10}, + }, + "iam.example.com", + []string{ + "new1.example.com:80:10:10", + "old1.example.com:80:20:10", + }, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + resolver := &testResolver{ + records: tc.records, + } + + if len(tc.records) == 0 { + resolver.err = &net.DNSError{ + IsNotFound: true, + } + } + + selector := &Selector{ + logger: zap.NewNop().Sugar(), + resolver: resolver, + target: tc.target, + } + + selector.fallback = newHost(selector, "fallback.example.com", "80", net.SRV{}) + selector.hosts = Hosts{ + newHost(selector, "old1.example.com", "80", net.SRV{"old1.example.com", 80, 10, 10}), + } + + selector.checkOnce.Do(func() {}) + + selector.discoverRecords(context.Background()) + + hosts := make([]string, len(selector.hosts)) + for i, host := range selector.hosts { + hosts[i] = host.ID() + } + + 
assert.Equal(t, tc.expectTarget, resolver.requestedTarget, "unexpected target queried") + assert.Equal(t, tc.expectHosts, hosts, "unexpected hosts returned") + }) + } +} + +func TestSelectorSelectHost(t *testing.T) { + t.Parallel() + + var hostErr = fmt.Errorf("%w: host error", errTestBase) + + type testHost struct { + host string + avg time.Duration + err error + } + + testCases := []struct { + name string + hosts []testHost + current string + sticky bool + prefer string + fallback string + expectSelected string + }{ + { + "first healthy selection", + []testHost{ + {"host1.example.com", 10, hostErr}, + {"host2.example.com", 20, nil}, + }, + "", + false, + "", + "", + "host2.example.com", + }, + { + "current sticky", + []testHost{ + {"host1.example.com", 10, nil}, + {"host2.example.com", 20, nil}, + }, + "host2.example.com", + true, + "", + "", + "host2.example.com", + }, + { + "current not sticky", + []testHost{ + {"host1.example.com", 10, nil}, + {"host2.example.com", 20, nil}, + }, + "host2.example.com", + false, + "", + "", + "host1.example.com", + }, + { + "use preferred", + []testHost{ + {"host1.example.com", 10, nil}, + {"host2.example.com", 20, nil}, + }, + "host1.example.com", + false, + "host2.example.com", + "", + "host2.example.com", + }, + { + "fallback", + []testHost{ + {"host1.example.com", 10, hostErr}, + {"host2.example.com", 20, hostErr}, + }, + "host2.example.com", + false, + "", + "host3.example.com", + "host3.example.com", + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + hosts := make(Hosts, len(tc.hosts)) + + for i, h := range tc.hosts { + hosts[i] = &host{ + id: h.host, + host: h.host, + err: h.err, + lastCheck: Results{ + Checks: 1, + TotalDuration: h.avg, + }, + } + } + + var selected, prefer, fallback Host + + if tc.current != "" { + selected = &host{ + id: tc.current, + host: tc.current, + } + } + + if tc.prefer != "" { + prefer = &host{ + id: tc.prefer, + host: tc.prefer, + 
} + } + + if tc.fallback != "" { + fallback = &host{ + id: tc.fallback, + host: tc.fallback, + } + } + + selector := &Selector{ + logger: zap.NewNop().Sugar(), + + hosts: hosts, + selected: selected, + prefer: prefer, + fallback: fallback, + } + + if tc.sticky { + selector.stickyUntil = time.Now().Add(time.Second) + } + + selector.selectHost(context.Background()) + + if tc.expectSelected == "" { + require.Nil(t, selector.selected, "expected no host to be selected") + + return + } + + require.NotNil(t, selector.selected, "expected a host to be selected") + + assert.Equal(t, tc.expectSelected, selector.selected.Host(), "unexpected host selected") + }) + } +} + +func TestSelectorCheckHosts(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + withHosts int + }{ + { + "no hosts", + 0, + }, + { + "multiple hosts", + 2, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + selector := &Selector{ + logger: zap.NewNop().Sugar(), + checkConcurrency: 2, + checkCount: 5, + checkTimeout: time.Second, + runCh: make(chan struct{}), + } + + defer close(selector.runCh) + + var checkCount atomic.Uint32 + + hosts := make(Hosts, tc.withHosts) + + for i := 0; i < tc.withHosts; i++ { + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + checkCount.Add(1) + })) + defer srv.Close() + + h, p, err := net.SplitHostPort(srv.Listener.Addr().String()) + require.NoError(t, err, "no error expected splitting test server address") + + hosts[i] = &host{ + selector: selector, + id: h + ":" + p, + host: h, + port: p, + } + } + + selector.hosts = hosts + + selector.checkHosts(context.Background()) + + assert.Equal(t, 5*tc.withHosts, int(checkCount.Load()), "unexpected number of requests fetched") + + if tc.withHosts != 0 { + assert.NotNil(t, selector.selected, "expected a host to be selected") + } else { + assert.Nil(t, selector.selected, "no host expected to be selected") + } + }) 
+ } +}