diff --git a/pkg/cache/decorator.go b/pkg/cache/decorator.go new file mode 100644 index 000000000000..5740dec7eb0b --- /dev/null +++ b/pkg/cache/decorator.go @@ -0,0 +1,64 @@ +package cache + +// WithMetrics is a decorator that adds metrics collection to any Cache implementation. +type WithMetrics[T any] struct { + wrapped Cache[T] + metrics BaseMetricsCollector + cacheName string +} + +// NewCacheWithMetrics creates a new WithMetrics decorator that wraps the provided Cache +// and collects metrics using the provided BaseMetricsCollector. +// The cacheName parameter is used to identify the cache in the collected metrics. +func NewCacheWithMetrics[T any](wrapped Cache[T], metrics BaseMetricsCollector, cacheName string) *WithMetrics[T] { + return &WithMetrics[T]{ + wrapped: wrapped, + metrics: metrics, + cacheName: cacheName, + } +} + +// Set sets the value for the given key in the cache. It also records a set metric +// for the cache using the provided metrics collector and cache name. +func (c *WithMetrics[T]) Set(key string, val T) { + c.metrics.RecordSet(c.cacheName) + c.wrapped.Set(key, val) +} + +// Get retrieves the value for the given key from the underlying cache. It also records +// a hit or miss metric for the cache using the provided metrics collector and cache name. +func (c *WithMetrics[T]) Get(key string) (T, bool) { + val, found := c.wrapped.Get(key) + if found { + c.metrics.RecordHit(c.cacheName) + } else { + c.metrics.RecordMiss(c.cacheName) + } + return val, found +} + +// Exists checks if the given key exists in the cache. It records a hit or miss metric +// for the cache using the provided metrics collector and cache name. +func (c *WithMetrics[T]) Exists(key string) bool { + found := c.wrapped.Exists(key) + if found { + c.metrics.RecordHit(c.cacheName) + } else { + c.metrics.RecordMiss(c.cacheName) + } + return found +} + +// Delete removes the value for the given key from the cache. It also records a delete metric +// for the cache using the provided metrics collector and cache name. +func (c *WithMetrics[T]) Delete(key string) { + c.wrapped.Delete(key) + c.metrics.RecordDelete(c.cacheName) +} + +// Clear removes all entries from the cache. It also records a clear metric +// for the cache using the provided metrics collector and cache name. +func (c *WithMetrics[T]) Clear() { + c.wrapped.Clear() + c.metrics.RecordClear(c.cacheName) +} diff --git a/pkg/cache/lru/lru.go b/pkg/cache/lru/lru.go new file mode 100644 index 000000000000..f8e117a8a769 --- /dev/null +++ b/pkg/cache/lru/lru.go @@ -0,0 +1,120 @@ +// Package lru provides a generic, size-limited, LRU (Least Recently Used) cache with optional +// metrics collection and reporting. It wraps the golang-lru/v2 caching library, adding support for custom +// metrics tracking cache hits, misses, evictions, and other cache operations. +// +// This package supports configuring key aspects of cache behavior, including maximum cache size, +// and custom metrics collection. +package lru + +import ( + "fmt" + + lru "github.com/hashicorp/golang-lru/v2" + + "github.com/trufflesecurity/trufflehog/v3/pkg/cache" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" +) + +// collector is an interface that extends cache.BaseMetricsCollector +// and adds methods for recording cache evictions. +type collector interface { + cache.BaseMetricsCollector + + RecordEviction(cacheName string) +} + +// Cache is a generic LRU-sized cache that stores key-value pairs with a maximum size limit. +// It wraps the lru.Cache library and adds support for custom metrics collection. +type Cache[T any] struct { + cache *lru.Cache[string, T] + + cacheName string + capacity int + metrics collector +} + +// Option defines a functional option for configuring the Cache. +type Option[T any] func(*Cache[T]) + +// WithMetricsCollector is a functional option to set a custom metrics collector. +// It sets the metrics field of the Cache. +func WithMetricsCollector[T any](collector collector) Option[T] { + return func(lc *Cache[T]) { lc.metrics = collector } +} + +// WithCapacity is a functional option to set the maximum number of items the cache can hold. +// If the capacity is not set, the default value (128_000) is used. +func WithCapacity[T any](capacity int) Option[T] { + return func(lc *Cache[T]) { lc.capacity = capacity } +} + +// NewCache creates a new Cache with optional configuration parameters. +// It takes a cache name and a variadic list of options. +func NewCache[T any](cacheName string, opts ...Option[T]) (*Cache[T], error) { + // Default values for cache configuration. + const defaultSize = 128_000 + + sizedLRU := &Cache[T]{ + metrics: NewSizedLRUMetricsCollector(common.MetricsNamespace, common.MetricsSubsystem), + cacheName: cacheName, + } + + for _, opt := range opts { + opt(sizedLRU) + } + + // Provide a evict callback function to record evictions. + onEvicted := func(string, T) { + sizedLRU.metrics.RecordEviction(sizedLRU.cacheName) + } + + lcache, err := lru.NewWithEvict[string, T](defaultSize, onEvicted) + if err != nil { + return nil, fmt.Errorf("failed to create lrusized cache: %w", err) + } + + sizedLRU.cache = lcache + + return sizedLRU, nil +} + +// Set adds a key-value pair to the cache. +func (lc *Cache[T]) Set(key string, val T) { + lc.cache.Add(key, val) + lc.metrics.RecordSet(lc.cacheName) +} + +// Get retrieves a value from the cache by key. +func (lc *Cache[T]) Get(key string) (T, bool) { + value, found := lc.cache.Get(key) + if found { + lc.metrics.RecordHit(lc.cacheName) + return value, true + } + lc.metrics.RecordMiss(lc.cacheName) + var zero T + return zero, false +} + +// Exists checks if a key exists in the cache. +func (lc *Cache[T]) Exists(key string) bool { + _, found := lc.cache.Get(key) + if found { + lc.metrics.RecordHit(lc.cacheName) + } else { + lc.metrics.RecordMiss(lc.cacheName) + } + return found +} + +// Delete removes a key from the cache. +func (lc *Cache[T]) Delete(key string) { + lc.cache.Remove(key) + lc.metrics.RecordDelete(lc.cacheName) +} + +// Clear removes all keys from the cache. +func (lc *Cache[T]) Clear() { + lc.cache.Purge() + lc.metrics.RecordClear(lc.cacheName) +} diff --git a/pkg/cache/lru/lru_test.go b/pkg/cache/lru/lru_test.go new file mode 100644 index 000000000000..1bb08e04f84a --- /dev/null +++ b/pkg/cache/lru/lru_test.go @@ -0,0 +1,160 @@ +package lru + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockCollector struct{ mock.Mock } + +func (m *mockCollector) RecordHits(cacheName string, hits uint64) { m.Called(cacheName, hits) } + +func (m *mockCollector) RecordMisses(cacheName string, misses uint64) { m.Called(cacheName, misses) } + +func (m *mockCollector) RecordEviction(cacheName string) { m.Called(cacheName) } + +func (m *mockCollector) RecordSet(cacheName string) { m.Called(cacheName) } + +func (m *mockCollector) RecordHit(cacheName string) { m.Called(cacheName) } + +func (m *mockCollector) RecordMiss(cacheName string) { m.Called(cacheName) } + +func (m *mockCollector) RecordDelete(cacheName string) { m.Called(cacheName) } + +func (m *mockCollector) RecordClear(cacheName string) { m.Called(cacheName) } + +// setupCache initializes the metrics and cache. +// If withCollector is true, it sets up a cache with a custom metrics collector. +// Otherwise, it sets up a cache without a custom metrics collector. +func setupCache[T any](t *testing.T, withCollector bool) (*Cache[T], *mockCollector) { + t.Helper() + + var collector *mockCollector + var c *Cache[T] + var err error + + if withCollector { + collector = new(mockCollector) + c, err = NewCache[T]("test_cache", WithMetricsCollector[T](collector)) + } else { + c, err = NewCache[T]("test_cache") + } + + assert.NoError(t, err, "Failed to create cache") + assert.NotNil(t, c, "Cache should not be nil") + + return c, collector +} + +func TestNewLRUCache(t *testing.T) { + t.Run("default configuration", func(t *testing.T) { + c, _ := setupCache[int](t, false) + assert.Equal(t, "test_cache", c.cacheName) + assert.NotNil(t, c.metrics, "Cache metrics should not be nil") + }) + + t.Run("with custom max cost", func(t *testing.T) { + c, _ := setupCache[int](t, false) + assert.NotNil(t, c) + }) + + t.Run("with metrics collector", func(t *testing.T) { + c, collector := setupCache[int](t, true) + assert.NotNil(t, c) + assert.Equal(t, "test_cache", c.cacheName) + assert.Equal(t, collector, c.metrics, "Cache metrics should match the collector") + }) +} + +func TestCacheSet(t *testing.T) { + c, collector := setupCache[string](t, true) + + collector.On("RecordSet", "test_cache").Once() + c.Set("key", "value") + + collector.AssertCalled(t, "RecordSet", "test_cache") +} + +func TestCacheGet(t *testing.T) { + c, collector := setupCache[string](t, true) + + collector.On("RecordSet", "test_cache").Once() + collector.On("RecordHit", "test_cache").Once() + collector.On("RecordMiss", "test_cache").Once() + + c.Set("key", "value") + collector.AssertCalled(t, "RecordSet", "test_cache") + + value, found := c.Get("key") + assert.True(t, found, "Expected to find the key") + assert.Equal(t, "value", value, "Expected value to match") + collector.AssertCalled(t, "RecordHit", "test_cache") + + _, found = c.Get("non_existent") + assert.False(t, found, "Expected not to find the key") + collector.AssertCalled(t, "RecordMiss", "test_cache") +} + +func TestCacheExists(t *testing.T) { + c, collector := setupCache[string](t, true) + + collector.On("RecordSet", "test_cache").Once() + collector.On("RecordHit", "test_cache").Twice() + collector.On("RecordMiss", "test_cache").Once() + + c.Set("key", "value") + collector.AssertCalled(t, "RecordSet", "test_cache") + + exists := c.Exists("key") + assert.True(t, exists, "Expected the key to exist") + collector.AssertCalled(t, "RecordHit", "test_cache") + + exists = c.Exists("non_existent") + assert.False(t, exists, "Expected the key not to exist") + collector.AssertCalled(t, "RecordMiss", "test_cache") +} + +func TestCacheDelete(t *testing.T) { + c, collector := setupCache[string](t, true) + + collector.On("RecordSet", "test_cache").Once() + collector.On("RecordDelete", "test_cache").Once() + collector.On("RecordMiss", "test_cache").Once() + collector.On("RecordEviction", "test_cache").Once() + + c.Set("key", "value") + collector.AssertCalled(t, "RecordSet", "test_cache") + + c.Delete("key") + collector.AssertCalled(t, "RecordDelete", "test_cache") + collector.AssertCalled(t, "RecordEviction", "test_cache") + + _, found := c.Get("key") + assert.False(t, found, "Expected not to find the deleted key") + collector.AssertCalled(t, "RecordMiss", "test_cache") +} + +func TestCacheClear(t *testing.T) { + c, collector := setupCache[string](t, true) + + collector.On("RecordSet", "test_cache").Twice() + collector.On("RecordClear", "test_cache").Once() + collector.On("RecordMiss", "test_cache").Twice() + collector.On("RecordEviction", "test_cache").Twice() + + c.Set("key1", "value1") + c.Set("key2", "value2") + collector.AssertNumberOfCalls(t, "RecordSet", 2) + + c.Clear() + collector.AssertCalled(t, "RecordClear", "test_cache") + collector.AssertNumberOfCalls(t, "RecordEviction", 2) + + _, found1 := c.Get("key1") + _, found2 := c.Get("key2") + assert.False(t, found1, "Expected not to find key1 after clear") + assert.False(t, found2, "Expected not to find key2 after clear") + collector.AssertNumberOfCalls(t, "RecordMiss", 2) +} diff --git a/pkg/cache/lru/metrics.go b/pkg/cache/lru/metrics.go new file mode 100644 index 000000000000..3085a4e3a9a6 --- /dev/null +++ b/pkg/cache/lru/metrics.go @@ -0,0 +1,41 @@ +package lru + +import ( + "github.com/prometheus/client_golang/prometheus" + + "github.com/trufflesecurity/trufflehog/v3/pkg/cache" +) + +// MetricsCollector should implement the collector interface. +var _ collector = (*MetricsCollector)(nil) + +// MetricsCollector extends the BaseMetricsCollector with Sized LRU specific metrics. +// It provides methods to record cache evictions. +type MetricsCollector struct { + // BaseMetricsCollector is embedded to provide the base metrics functionality. + cache.BaseMetricsCollector + + totalEvicts *prometheus.CounterVec +} + +// NewSizedLRUMetricsCollector initializes a new MetricsCollector with the provided namespace and subsystem. +func NewSizedLRUMetricsCollector(namespace, subsystem string) *MetricsCollector { + base := cache.GetMetricsCollector() + + totalEvicts := prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "evictions_total", + Help: "Total number of cache evictions.", + }, []string{"cache_name"}) + + return &MetricsCollector{ + BaseMetricsCollector: base, + totalEvicts: totalEvicts, + } +} + +// RecordEviction increments the total number of cache evictions for the specified cache. +func (c *MetricsCollector) RecordEviction(cacheName string) { + c.totalEvicts.WithLabelValues(cacheName).Inc() +} diff --git a/pkg/cache/metrics.go b/pkg/cache/metrics.go new file mode 100644 index 000000000000..36a6ba86730e --- /dev/null +++ b/pkg/cache/metrics.go @@ -0,0 +1,108 @@ +package cache + +import ( + "sync" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/trufflesecurity/trufflehog/v3/pkg/common" +) + +// BaseMetricsCollector defines the interface for recording cache metrics. +// Each method corresponds to a specific cache-related operation. +type BaseMetricsCollector interface { + RecordHit(cacheName string) + RecordMiss(cacheName string) + RecordSet(cacheName string) + RecordDelete(cacheName string) + RecordClear(cacheName string) +} + +// MetricsCollector encapsulates all Prometheus metrics with labels. +// It holds Prometheus counters for cache operations, which help track +// the performance and usage of the cache. +type MetricsCollector struct { + // Base metrics. + hits *prometheus.CounterVec + misses *prometheus.CounterVec + sets *prometheus.CounterVec + deletes *prometheus.CounterVec + clears *prometheus.CounterVec +} + +func init() { + // Initialize the singleton MetricsCollector. + // Set up Prometheus counters for cache operations (hits, misses, sets, deletes, clears). + collectorOnce.Do(func() { + collector = &MetricsCollector{ + hits: promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: common.MetricsNamespace, + Subsystem: common.MetricsSubsystem, + Name: "hits_total", + Help: "Total number of cache hits.", + }, []string{"cache_name"}), + + misses: promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: common.MetricsNamespace, + Subsystem: common.MetricsSubsystem, + Name: "misses_total", + Help: "Total number of cache misses.", + }, []string{"cache_name"}), + + sets: promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: common.MetricsNamespace, + Subsystem: common.MetricsSubsystem, + Name: "sets_total", + Help: "Total number of cache set operations.", + }, []string{"cache_name"}), + + deletes: promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: common.MetricsNamespace, + Subsystem: common.MetricsSubsystem, + Name: "deletes_total", + Help: "Total number of cache delete operations.", + }, []string{"cache_name"}), + + clears: promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: common.MetricsNamespace, + Subsystem: common.MetricsSubsystem, + Name: "clears_total", + Help: "Total number of cache clear operations.", + }, []string{"cache_name"}), + } + }) +} + +var ( + collectorOnce sync.Once // Ensures that the collector is initialized only once. + collector *MetricsCollector +) + +// GetMetricsCollector returns the singleton MetricsCollector instance. +// It panics if InitializeMetrics has not been called to ensure metrics are properly initialized. +// Must be called after InitializeMetrics to avoid runtime issues. +// If you do it before, BAD THINGS WILL HAPPEN. +func GetMetricsCollector() *MetricsCollector { + if collector == nil { + panic("MetricsCollector not initialized. Call InitializeMetrics first.") + } + return collector +} + +// Implement BaseMetricsCollector interface methods. + +// RecordHit increments the counter for cache hits, tracking how often cache lookups succeed. +func (m *MetricsCollector) RecordHit(cacheName string) { m.hits.WithLabelValues(cacheName).Inc() } + +// RecordMiss increments the counter for cache misses, tracking how often cache lookups fail. +func (m *MetricsCollector) RecordMiss(cacheName string) { m.misses.WithLabelValues(cacheName).Inc() } + +// RecordSet increments the counter for cache set operations, tracking how often items are added/updated. +func (m *MetricsCollector) RecordSet(cacheName string) { m.sets.WithLabelValues(cacheName).Inc() } + +// RecordDelete increments the counter for cache delete operations, tracking how often items are removed. +func (m *MetricsCollector) RecordDelete(cacheName string) { m.deletes.WithLabelValues(cacheName).Inc() } + +// RecordClear increments the counter for cache clear operations, tracking how often the cache is completely cleared. +func (m *MetricsCollector) RecordClear(cacheName string) { m.clears.WithLabelValues(cacheName).Inc() }