diff --git a/internal/features/features.go b/internal/features/features.go index 3fad5d2b9..63ac8ffb5 100644 --- a/internal/features/features.go +++ b/internal/features/features.go @@ -38,6 +38,10 @@ const ( // // Ref: https://github.com/helm/helm/security/advisories/GHSA-pwcw-6f5g-gxf8 AllowDNSLookups = "AllowDNSLookups" + + // OOMWatch enables the OOM watcher, which will gracefully shut down the controller + // when the memory usage exceeds the configured limit. This is disabled by default. + OOMWatch = "OOMWatch" ) var features = map[string]bool{ @@ -50,6 +54,9 @@ var features = map[string]bool{ // AllowDNSLookups // opt-in from v0.31 AllowDNSLookups: false, + // OOMWatch + // opt-in from v0.31 + OOMWatch: false, } // FeatureGates contains a list of all supported feature gates and diff --git a/internal/oomwatch/watch.go b/internal/oomwatch/watch.go new file mode 100644 index 000000000..f60ba6c42 --- /dev/null +++ b/internal/oomwatch/watch.go @@ -0,0 +1,172 @@ +/* +Copyright 2023 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package oomwatch provides a way to detect near OOM conditions. +package oomwatch + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-logr/logr" +) + +const ( + // DefaultCgroupPath is the default path to the cgroup directory. + DefaultCgroupPath = "/sys/fs/cgroup/" + // MemoryMaxFile is the cgroup memory.max filename. + MemoryMaxFile = "memory.max" + // MemoryCurrentFile is the cgroup memory.current filename. + MemoryCurrentFile = "memory.current" +) + +// Watcher can be used to detect near OOM conditions. +type Watcher struct { + // memoryMax is the maximum amount of memory that can be used by the system. + memoryMax uint64 + // memoryCurrentPath is the cgroup memory.current filepath. + memoryCurrentPath string + // memoryUsagePercentThreshold is the threshold at which the system is + // considered to be near OOM. + memoryUsagePercentThreshold uint8 + // interval is the interval at which to check for OOM. + interval time.Duration + // logger is the logger to use. + logger logr.Logger + + // ctx is the context that is canceled when OOM is detected. + ctx context.Context + // cancel is the function that cancels the context. + cancel context.CancelFunc + // once is used to ensure that Watch is only called once. + once sync.Once +} + +// New returns a new Watcher. +func New(memoryMaxPath, memoryCurrentPath string, memoryUsagePercentThreshold uint8, interval time.Duration, logger logr.Logger) (*Watcher, error) { + if memoryUsagePercentThreshold < 1 || memoryUsagePercentThreshold > 100 { + return nil, fmt.Errorf("memory usage percent threshold must be between 1 and 100, got %d", memoryUsagePercentThreshold) + } + + if minInterval := 50 * time.Millisecond; interval < minInterval { + return nil, fmt.Errorf("interval must be at least %s, got %s", minInterval, interval) + } + + if _, err := os.Lstat(memoryCurrentPath); err != nil { + return nil, fmt.Errorf("failed to stat memory.current %q: %w", memoryCurrentPath, err) + } + + memoryMax, err := readUintFromFile(memoryMaxPath) + if err != nil { + return nil, fmt.Errorf("failed to read memory.max %q: %w", memoryMaxPath, err) + } + + return &Watcher{ + memoryMax: memoryMax, + memoryCurrentPath: memoryCurrentPath, + memoryUsagePercentThreshold: memoryUsagePercentThreshold, + interval: interval, + logger: logger, + }, nil +} + +// NewDefault returns a new Watcher with default path values. +func NewDefault(memoryUsagePercentThreshold uint8, interval time.Duration, logger logr.Logger) (*Watcher, error) { + return New( + filepath.Join(DefaultCgroupPath, MemoryMaxFile), + filepath.Join(DefaultCgroupPath, MemoryCurrentFile), + memoryUsagePercentThreshold, + interval, + logger, + ) +} + +// Watch returns a context that is canceled when the system reaches the +// configured memory usage threshold. Calling Watch multiple times will return +// the same context. +func (w *Watcher) Watch(ctx context.Context) context.Context { + w.once.Do(func() { + w.ctx, w.cancel = context.WithCancel(ctx) + go w.watchForNearOOM(ctx) + }) + return w.ctx +} + +// watchForNearOOM polls the memory.current file on the configured interval +// and cancels the context within Watcher when the system is near OOM. +// It is expected that this function is called in a goroutine. Canceling +// provided context will cause the goroutine to exit. +func (w *Watcher) watchForNearOOM(ctx context.Context) { + t := time.NewTicker(w.interval) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + w.logger.Info("Shutdown signal received, stopping watch for near OOM") + return + case <-t.C: + current, err := readUintFromFile(w.memoryCurrentPath) + if err != nil { + w.logger.Error(err, "Failed to read current memory usage, skipping check") + continue + } + + currentPercentage := float64(current) / float64(w.memoryMax) * 100 + if currentPercentage >= float64(w.memoryUsagePercentThreshold) { + w.logger.Info(fmt.Sprintf("Memory usage is near OOM (%s/%s), shutting down", + formatSize(current), formatSize(w.memoryMax))) + w.cancel() + return + } + w.logger.V(2).Info(fmt.Sprintf("Current memory usage %s/%s (%.2f%% out of %d%%)", + formatSize(current), formatSize(w.memoryMax), currentPercentage, w.memoryUsagePercentThreshold)) + } + } +} + +// readUintFromFile reads an uint64 from the file at the given path. +func readUintFromFile(path string) (uint64, error) { + b, err := os.ReadFile(path) + if err != nil { + return 0, err + } + return strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64) +} + +// formatSize formats the given size in bytes to a human-readable format. +func formatSize(b uint64) string { + if b == 0 { + return "-" + } + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := uint64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", + float64(b)/float64(div), "KMGTPE"[exp]) +} diff --git a/internal/oomwatch/watch_test.go b/internal/oomwatch/watch_test.go new file mode 100644 index 000000000..d53ab1156 --- /dev/null +++ b/internal/oomwatch/watch_test.go @@ -0,0 +1,250 @@ +/* +Copyright 2023 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oomwatch + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/go-logr/logr" + . "github.com/onsi/gomega" +) + +func TestNew(t *testing.T) { + t.Run("success", func(t *testing.T) { + g := NewWithT(t) + + mockMemoryMax := filepath.Join(t.TempDir(), MemoryMaxFile) + g.Expect(os.WriteFile(mockMemoryMax, []byte("1000000000"), 0o640)).To(Succeed()) + + mockMemoryCurrent := filepath.Join(t.TempDir(), MemoryCurrentFile) + _, err := os.Create(mockMemoryCurrent) + g.Expect(err).ToNot(HaveOccurred()) + + w, err := New(mockMemoryMax, mockMemoryCurrent, 1, time.Second, logr.Discard()) + g.Expect(err).ToNot(HaveOccurred()) + + g.Expect(w).To(BeEquivalentTo(&Watcher{ + memoryMax: uint64(1000000000), + memoryCurrentPath: mockMemoryCurrent, + memoryUsagePercentThreshold: 1, + interval: time.Second, + logger: logr.Discard(), + })) + }) + + t.Run("validation", func(t *testing.T) { + t.Run("memory usage percentage threshold", func(t *testing.T) { + t.Run("less than 1", func(t *testing.T) { + g := NewWithT(t) + + _, err := New("", "", 0, 0, logr.Discard()) + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError("memory usage percent threshold must be between 1 and 100, got 0")) + }) + t.Run("greater than 100", func(t *testing.T) { + g := NewWithT(t) + + _, err := New("", "", 101, 0, logr.Discard()) + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError("memory usage percent threshold must be between 1 and 100, got 101")) + }) + }) + + t.Run("interval", func(t *testing.T) { + t.Run("less than 50ms", func(t *testing.T) { + g := NewWithT(t) + + _, err := New("", "", 1, 49*time.Millisecond, logr.Discard()) + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError("interval must be at least 50ms, got 49ms")) + }) + }) + + t.Run("memory current path", func(t *testing.T) { + t.Run("does not exist", func(t *testing.T) { + g := NewWithT(t) + + _, err := New("", "", 1, 50*time.Second, logr.Discard()) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("failed to stat memory.current \"\": lstat : no such file or directory")) + }) + }) + + t.Run("memory max path", func(t *testing.T) { + t.Run("does not exist", func(t *testing.T) { + g := NewWithT(t) + + mockMemoryCurrent := filepath.Join(t.TempDir(), MemoryMaxFile) + _, err := os.Create(mockMemoryCurrent) + g.Expect(err).NotTo(HaveOccurred()) + + _, err = New("", mockMemoryCurrent, 1, 50*time.Second, logr.Discard()) + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(ContainSubstring("failed to read memory.max \"\": open : no such file or directory")) + }) + }) + }) +} + +func TestWatcher_Watch(t *testing.T) { + t.Run("returns same context", func(t *testing.T) { + g := NewWithT(t) + + mockMemoryMax := filepath.Join(t.TempDir(), MemoryMaxFile) + g.Expect(os.WriteFile(mockMemoryMax, []byte("1000000000"), 0o640)).To(Succeed()) + + mockMemoryCurrent := filepath.Join(t.TempDir(), MemoryCurrentFile) + _, err := os.Create(mockMemoryCurrent) + g.Expect(err).ToNot(HaveOccurred()) + + w, err := New(mockMemoryMax, mockMemoryCurrent, 1, time.Second, logr.Discard()) + g.Expect(err).ToNot(HaveOccurred()) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + g.Expect(w.Watch(ctx)).To(Equal(w.Watch(ctx))) + }) + + t.Run("cancels context when memory usage is above threshold", func(t *testing.T) { + g := NewWithT(t) + + mockMemoryCurrent := filepath.Join(t.TempDir(), MemoryCurrentFile) + g.Expect(os.WriteFile(mockMemoryCurrent, []byte("1000000000"), 0o640)).To(Succeed()) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + w := &Watcher{ + memoryMax: uint64(1000000000), + memoryCurrentPath: mockMemoryCurrent, + memoryUsagePercentThreshold: 95, + interval: 10 * time.Millisecond, + logger: logr.Discard(), + ctx: ctx, + cancel: cancel, + } + + go func() { + <-w.ctx.Done() + g.Expect(w.ctx.Err()).To(MatchError(context.Canceled)) + }() + }) +} + +func TestWatcher_watchForNearOOM(t *testing.T) { + t.Run("does not cancel context when memory usage is below threshold", func(t *testing.T) { + g := NewWithT(t) + + mockMemoryCurrent := filepath.Join(t.TempDir(), MemoryCurrentFile) + g.Expect(os.WriteFile(mockMemoryCurrent, []byte("940000000"), 0o640)).To(Succeed()) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + w := &Watcher{ + memoryMax: uint64(1000000000), + memoryCurrentPath: mockMemoryCurrent, + memoryUsagePercentThreshold: 95, + interval: 500 * time.Millisecond, + logger: logr.Discard(), + ctx: ctx, + cancel: cancel, + } + + innerCtx, innerCancel := context.WithCancel(context.Background()) + go w.watchForNearOOM(innerCtx) + + select { + case <-ctx.Done(): + t.Fatal("context should not have been cancelled") + case <-time.After(1 * time.Second): + // This also tests if the inner context stops the watcher. + innerCancel() + } + }) + + t.Run("cancels context when memory usage is above threshold", func(t *testing.T) { + g := NewWithT(t) + + mockMemoryCurrent := filepath.Join(t.TempDir(), MemoryCurrentFile) + g.Expect(os.WriteFile(mockMemoryCurrent, []byte("0"), 0o640)).To(Succeed()) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + w := &Watcher{ + memoryMax: uint64(1000000000), + memoryCurrentPath: mockMemoryCurrent, + memoryUsagePercentThreshold: 95, + interval: 500 * time.Millisecond, + logger: logr.Discard(), + ctx: ctx, + cancel: cancel, + } + + go w.watchForNearOOM(context.TODO()) + + select { + case <-ctx.Done(): + case <-time.After(500 * time.Millisecond): + g.Expect(os.WriteFile(mockMemoryCurrent, []byte("950000001"), 0o640)).To(Succeed()) + case <-time.After(2 * time.Second): + t.Fatal("context was not cancelled") + } + }) + + t.Run("continues to attempt to read memory.current", func(t *testing.T) { + g := NewWithT(t) + + mockMemoryCurrent := filepath.Join(t.TempDir(), MemoryCurrentFile) + g.Expect(os.WriteFile(mockMemoryCurrent, []byte("0"), 0o000)).To(Succeed()) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + w := &Watcher{ + memoryMax: uint64(1000000000), + memoryCurrentPath: mockMemoryCurrent, + memoryUsagePercentThreshold: 95, + interval: 500 * time.Millisecond, + logger: logr.Discard(), + ctx: ctx, + cancel: cancel, + } + + go w.watchForNearOOM(context.TODO()) + + var readable bool + select { + case <-ctx.Done(): + if !readable { + t.Fatal("context was cancelled before memory.current was readable") + } + case <-time.After(1 * time.Second): + g.Expect(os.Chmod(mockMemoryCurrent, 0o640)).To(Succeed()) + g.Expect(os.WriteFile(mockMemoryCurrent, []byte("950000001"), 0o640)).To(Succeed()) + readable = true + case <-time.After(2 * time.Second): + t.Fatal("context was not cancelled") + } + }) +} diff --git a/main.go b/main.go index 04f9fa811..70b13c5ec 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ package main import ( "fmt" + "github.com/fluxcd/helm-controller/internal/oomwatch" "os" "time" @@ -84,18 +85,33 @@ func main() { aclOptions acl.Options leaderElectionOptions leaderelection.Options rateLimiterOptions helper.RateLimiterOptions + oomWatchInterval time.Duration + oomWatchMemoryThreshold uint8 ) - flag.StringVar(&metricsAddr, "metrics-addr", ":8080", "The address the metric endpoint binds to.") - flag.StringVar(&eventsAddr, "events-addr", "", "The address of the events receiver.") - flag.StringVar(&healthAddr, "health-addr", ":9440", "The address the health endpoint binds to.") - flag.IntVar(&concurrent, "concurrent", 4, "The number of concurrent HelmRelease reconciles.") - flag.DurationVar(&requeueDependency, "requeue-dependency", 30*time.Second, "The interval at which failing dependencies are reevaluated.") - flag.DurationVar(&gracefulShutdownTimeout, "graceful-shutdown-timeout", 600*time.Second, "The duration given to the reconciler to finish before forcibly stopping.") + flag.StringVar(&metricsAddr, "metrics-addr", ":8080", + "The address the metric endpoint binds to.") + flag.StringVar(&eventsAddr, "events-addr", "", + "The address of the events receiver.") + flag.StringVar(&healthAddr, "health-addr", ":9440", + "The address the health endpoint binds to.") + flag.IntVar(&concurrent, "concurrent", 4, + "The number of concurrent HelmRelease reconciles.") + flag.DurationVar(&requeueDependency, "requeue-dependency", 30*time.Second, + "The interval at which failing dependencies are reevaluated.") + flag.DurationVar(&gracefulShutdownTimeout, "graceful-shutdown-timeout", 600*time.Second, + "The duration given to the reconciler to finish before forcibly stopping.") flag.BoolVar(&watchAllNamespaces, "watch-all-namespaces", true, "Watch for custom resources in all namespaces, if set to false it will only watch the runtime namespace.") - flag.IntVar(&httpRetry, "http-retry", 9, "The maximum number of retries when failing to fetch artifacts over HTTP.") - flag.StringVar(&intkube.DefaultServiceAccountName, "default-service-account", "", "Default service account used for impersonation.") + flag.IntVar(&httpRetry, "http-retry", 9, + "The maximum number of retries when failing to fetch artifacts over HTTP.") + flag.StringVar(&intkube.DefaultServiceAccountName, "default-service-account", "", + "Default service account used for impersonation.") + flag.Uint8Var(&oomWatchMemoryThreshold, "oom-watch-memory-threshold", 95, + "The memory threshold in percentage at which the OOM watcher will trigger a graceful shutdown. Requires feature gate 'OOMWatch' to be enabled.") + flag.DurationVar(&oomWatchInterval, "oom-watch-interval", 500*time.Millisecond, + "The interval at which the OOM watcher will check for memory usage. Requires feature gate 'OOMWatch' to be enabled.") + clientOptions.BindFlags(flag.CommandLine) logOptions.BindFlags(flag.CommandLine) aclOptions.BindFlags(flag.CommandLine) @@ -103,6 +119,7 @@ func main() { rateLimiterOptions.BindFlags(flag.CommandLine) kubeConfigOpts.BindFlags(flag.CommandLine) featureGates.BindFlags(flag.CommandLine) + flag.Parse() ctrl.SetLogger(logger.NewLogger(logOptions)) @@ -122,7 +139,7 @@ func main() { watchNamespace = os.Getenv("RUNTIME_NAMESPACE") } - disableCacheFor := []ctrlclient.Object{} + var disableCacheFor []ctrlclient.Object shouldCache, err := features.Enabled(features.CacheSecretsAndConfigMaps) if err != nil { setupLog.Error(err, "unable to check feature gate CacheSecretsAndConfigMaps") @@ -190,8 +207,19 @@ func main() { } // +kubebuilder:scaffold:builder + ctx := ctrl.SetupSignalHandler() + if ok, _ := features.Enabled(features.OOMWatch); ok { + setupLog.Info("setting up OOM watcher") + ow, err := oomwatch.NewDefault(oomWatchMemoryThreshold, oomWatchInterval, ctrl.Log.WithName("OOMwatch")) + if err != nil { + setupLog.Error(err, "unable to setup OOM watcher") + os.Exit(1) + } + ctx = ow.Watch(ctx) + } + setupLog.Info("starting manager") - if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil { + if err := mgr.Start(ctx); err != nil { setupLog.Error(err, "problem running manager") os.Exit(1) }