diff --git a/cmd/node-termination-handler.go b/cmd/node-termination-handler.go index d96620c9..318bf368 100644 --- a/cmd/node-termination-handler.go +++ b/cmd/node-termination-handler.go @@ -14,6 +14,7 @@ package main import ( + "context" "fmt" "os" "os/signal" @@ -55,6 +56,10 @@ func main() { // Zerolog uses json formatting by default, so change that to a human-readable format instead log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: timeFormat, NoColor: true}) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, syscall.SIGTERM) defer signal.Stop(signalChan) @@ -93,6 +98,12 @@ func main() { log.Fatal().Err(err).Msg("Unable to instantiate observability metrics,") } + probes, err := observability.InitProbes(nthConfig.EnableProbes, nthConfig.ProbesPort, nthConfig.ProbesEndpoint) + if err != nil { + nthConfig.Print() + log.Fatal().Err(err).Msg("Unable to instantiate probes service,") + } + imds := ec2metadata.New(nthConfig.MetadataURL, nthConfig.MetadataTries) interruptionEventStore := interruptioneventstore.New(nthConfig) @@ -218,6 +229,9 @@ func main() { } log.Log().Msg("AWS Node Termination Handler is shutting down") wg.Wait() + if err = probes.Shutdown(ctx); err != nil { + log.Err(err).Msg("Failed to stop probes server") + } log.Debug().Msg("all event processors finished") } diff --git a/pkg/config/config.go b/pkg/config/config.go index bd2d400f..d55cb74d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -78,6 +78,13 @@ const ( // https://github.com/prometheus/prometheus/wiki/Default-port-allocations prometheusPortDefault = 9092 prometheusPortConfigKey = "PROMETHEUS_SERVER_PORT" + // probes + enableProbesDefault = false + enableProbesConfigKey = "ENABLE_PROBES_SERVER" + probesPortDefault = 8080 + probesPortConfigKey = "PROBES_SERVER_PORT" + probesEndpointDefault = "healthz" + probesEndpointConfigKey = "PROBES_SERVER_ENDPOINT" 
region = "" awsRegionConfigKey = "AWS_REGION" awsEndpointConfigKey = "AWS_ENDPOINT" @@ -115,6 +122,9 @@ type Config struct { UptimeFromFile string EnablePrometheus bool PrometheusPort int + EnableProbes bool + ProbesPort int + ProbesEndpoint string AWSRegion string AWSEndpoint string QueueURL string @@ -162,6 +172,9 @@ func ParseCliArgs() (config Config, err error) { flag.StringVar(&config.UptimeFromFile, "uptime-from-file", getEnv(uptimeFromFileConfigKey, uptimeFromFileDefault), "If specified, read system uptime from the file path (useful for testing).") flag.BoolVar(&config.EnablePrometheus, "enable-prometheus-server", getBoolEnv(enablePrometheusConfigKey, enablePrometheusDefault), "If true, a http server is used for exposing prometheus metrics in /metrics endpoint.") flag.IntVar(&config.PrometheusPort, "prometheus-server-port", getIntEnv(prometheusPortConfigKey, prometheusPortDefault), "The port for running the prometheus http server.") + flag.BoolVar(&config.EnableProbes, "enable-probes-server", getBoolEnv(enableProbesConfigKey, enableProbesDefault), "If true, a http server is used for exposing probes in /probes endpoint.") + flag.IntVar(&config.ProbesPort, "probes-server-port", getIntEnv(probesPortConfigKey, probesPortDefault), "The port for running the probes http server.") + flag.StringVar(&config.ProbesEndpoint, "probes-server-endpoint", getEnv(probesEndpointConfigKey, probesEndpointDefault), "If specified, use this endpoint to make liveness probe") flag.StringVar(&config.AWSRegion, "aws-region", getEnv(awsRegionConfigKey, ""), "If specified, use the AWS region for AWS API calls") flag.StringVar(&config.AWSEndpoint, "aws-endpoint", getEnv(awsEndpointConfigKey, ""), "[testing] If specified, use the AWS endpoint to make API calls") flag.StringVar(&config.QueueURL, "queue-url", getEnv(queueURLConfigKey, ""), "Listens for messages on the specified SQS queue URL") diff --git a/pkg/observability/probes.go b/pkg/observability/probes.go new file mode 100644 index 
00000000..2d7d3a04
--- /dev/null
+++ b/pkg/observability/probes.go
@@ -0,0 +1,77 @@
+package observability
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/rs/zerolog/log"
+)
+
+// Probes wraps the HTTP server that answers liveness probes.
+// The zero value (server == nil) represents a probes service that was not started.
+type Probes struct {
+	server *http.Server
+}
+
+// InitProbes registers the liveness handler and starts serving it over HTTP.
+//
+// When enabled is false nothing is registered or started and the zero-value
+// Probes is returned. Otherwise endpoint is stripped of surrounding whitespace
+// and leading slashes, LivenessHandler is registered for "/<endpoint>", and a
+// server listening on port (all interfaces) is started in a background
+// goroutine. The returned error is always nil at present: listen failures
+// surface asynchronously and are only logged.
+func InitProbes(enabled bool, port int, endpoint string) (Probes, error) {
+	if !enabled {
+		return Probes{}, nil
+	}
+
+	// Drop leading slashes so "/healthz" and "healthz" are both registered as
+	// "/healthz" instead of "//healthz", which requests to /healthz never match.
+	route := strings.TrimLeft(strings.TrimSpace(endpoint), "/")
+	log.Info().Msgf("Starting to serve handler /%s, port %d", route, port)
+	// The handler goes on http.DefaultServeMux, which the server below uses
+	// because its Handler field is left nil.
+	http.HandleFunc(fmt.Sprintf("/%s", route), LivenessHandler)
+
+	probes := Probes{
+		server: &http.Server{
+			Addr:         net.JoinHostPort("", strconv.Itoa(port)),
+			ReadTimeout:  1 * time.Second,
+			WriteTimeout: 1 * time.Second,
+		},
+	}
+
+	go func() {
+		// ErrServerClosed is the expected result of Shutdown, not a failure.
+		if err := probes.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+			log.Err(err).Msg("Failed to listen and serve http server")
+		}
+	}()
+
+	return probes, nil
+}
+
+// LivenessHandler answers every request with status 200 and the body "OK".
+func LivenessHandler(w http.ResponseWriter, r *http.Request) {
+	w.WriteHeader(http.StatusOK)
+	if _, err := w.Write([]byte(http.StatusText(http.StatusOK))); err != nil {
+		log.Err(err).Msg("Failed to write liveness probe response")
+	}
+}
+
+// Shutdown gracefully stops the probes server; if ctx is done first, the
+// context's error is returned. It is a no-op returning nil when the server
+// was never started, e.g. because probes are disabled.
+func (p Probes) Shutdown(ctx context.Context) error {
+	if p.server == nil {
+		return nil
+	}
+
+	return p.server.Shutdown(ctx)
+}
diff --git a/pkg/observability/probes_test.go b/pkg/observability/probes_test.go
new file mode 100644
index 00000000..22dafc7b
--- /dev/null
+++ b/pkg/observability/probes_test.go
@@ -0,0 +1,35 @@
+package observability
+
+import (
+	"context"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+// TestLivenessHandler checks that the handler replies 200 with the body "OK".
+func TestLivenessHandler(t *testing.T) {
+	req := httptest.NewRequest("GET", "/healthz", nil)
+	rr := httptest.NewRecorder()
+	handler := http.HandlerFunc(LivenessHandler)
+
+	handler.ServeHTTP(rr, req)
+
+	if status := rr.Code; status != http.StatusOK {
+		t.Errorf("handler returned wrong status code: got %v want %v",
+			status, http.StatusOK)
+	}
+
+	if body := rr.Body.String(); body != http.StatusText(http.StatusOK) {
+		t.Errorf("handler returned wrong body: got %v want %v",
+			body, http.StatusText(http.StatusOK))
+	}
+}
+
+// TestShutdownDisabledProbes checks that shutting down probes that were never
+// started (the zero value returned when disabled) is a harmless no-op.
+func TestShutdownDisabledProbes(t *testing.T) {
+	if err := (Probes{}).Shutdown(context.Background()); err != nil {
+		t.Errorf("Shutdown on disabled probes returned error: %v", err)
+	}
+}