diff --git a/transformers/newdomaintracker.go b/transformers/newdomaintracker.go index 6c4d9d03..6e2e80f7 100644 --- a/transformers/newdomaintracker.go +++ b/transformers/newdomaintracker.go @@ -13,28 +13,26 @@ import ( "github.com/dmachard/go-dnscollector/dnsutils" "github.com/dmachard/go-dnscollector/pkgconfig" "github.com/dmachard/go-logger" - lru "github.com/hashicorp/golang-lru" + "github.com/hashicorp/golang-lru/v2/expirable" ) type NewDomainTracker struct { - ttl time.Duration // Time window to consider a domain as "new" - cache *lru.Cache // LRU Cache to store observed domains - whitelist map[string]*regexp.Regexp // Whitelisted domains + ttl time.Duration // Time window to consider a domain as "new" + cache *expirable.LRU[string, struct{}] // Expirable LRU Cache + whitelist map[string]*regexp.Regexp // Whitelisted domains persistencePath string logInfo func(msg string, v ...interface{}) logError func(msg string, v ...interface{}) } func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*regexp.Regexp, persistencePath string, logInfo, logError func(msg string, v ...interface{})) (*NewDomainTracker, error) { - cache, err := lru.New(maxSize) - if err != nil { - return nil, err - } if ttl <= 0 { return nil, fmt.Errorf("invalid TTL value: %v", ttl) } + cache := expirable.NewLRU[string, struct{}](maxSize, nil, ttl) + tracker := &NewDomainTracker{ ttl: ttl, cache: cache, @@ -46,7 +44,7 @@ func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*r // Load cache state from disk if persistence is enabled if persistencePath != "" { if err := tracker.loadCacheFromDisk(); err != nil { - return nil, fmt.Errorf("failed to load cache state: %v", err) + return nil, fmt.Errorf("failed to load cache state: %w", err) } } @@ -68,30 +66,20 @@ func (ndt *NewDomainTracker) IsNewDomain(domain string) bool { return false } - now := time.Now() - // Check if the domain exists in the cache - if lastSeen, exists := ndt.cache.Get(domain); exists { - if now.Sub(lastSeen.(time.Time)) < ndt.ttl { - // Domain was recently seen, not new - return false - } + if _, exists := ndt.cache.Get(domain); exists { + // Domain was recently seen, not new + return false } - // Otherwise, mark the domain as new and update the cache - ndt.cache.Add(domain, now) + // Otherwise, mark the domain as new + ndt.cache.Add(domain, struct{}{}) return true } func (ndt *NewDomainTracker) SaveCacheToDisk() error { - state := make(map[string]time.Time) - for _, key := range ndt.cache.Keys() { - if value, ok := ndt.cache.Peek(key); ok { - state[key.(string)] = value.(time.Time) - } - } - - data, err := json.Marshal(state) + keys := ndt.cache.Keys() + data, err := json.Marshal(keys) if err != nil { return err } @@ -113,13 +101,13 @@ func (ndt *NewDomainTracker) loadCacheFromDisk() error { return err } - state := make(map[string]time.Time) - if err := json.Unmarshal(data, &state); err != nil { + var keys []string + if err := json.Unmarshal(data, &keys); err != nil { return err } - for domain, lastSeen := range state { - ndt.cache.Add(domain, lastSeen) + for _, key := range keys { + ndt.cache.Add(key, struct{}{}) } return nil diff --git a/transformers/newdomaintracker_test.go b/transformers/newdomaintracker_test.go index 5505b005..f63ade2e 100644 --- a/transformers/newdomaintracker_test.go +++ b/transformers/newdomaintracker_test.go @@ -75,39 +75,41 @@ func TestNewDomainTracker_Whitelist(t *testing.T) { } } -// func TestNewDomainTracker_LRUCacheFull(t *testing.T) { -// // config -// config := pkgconfig.GetFakeConfigTransformers() -// config.NewDomainTracker.Enable = true -// config.NewDomainTracker.TTL = 2 -// config.NewDomainTracker.CacheSize = 1 - -// outChans := []chan dnsutils.DNSMessage{} - -// // init subproccesor -// tracker := NewNewDomainTrackerTransform(config, logger.New(false), "test", 0, outChans) - -// // init transforms -// _, err := tracker.GetTransforms() -// if err != nil { -// t.Error("fail to init transform", err) -// } - -// // first send -// dm := dnsutils.GetFakeDNSMessage() -// if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep { -// t.Errorf("This domain should be new!") -// } - -// if result, _ := tracker.trackNewDomain(&dm); result != ReturnError { -// t.Errorf("LRU Full") -// } - -// // wait ttl for expiration -// time.Sleep(3 * time.Second) - -// // recheck -// if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep { -// t.Errorf("recheck, this domain should be new!!") -// } -// } +func TestNewDomainTracker_LRUCacheFull(t *testing.T) { + // config + config := pkgconfig.GetFakeConfigTransformers() + config.NewDomainTracker.Enable = true + config.NewDomainTracker.TTL = 2 + config.NewDomainTracker.CacheSize = 1 + + outChans := []chan dnsutils.DNSMessage{} + + // init subproccesor + tracker := NewNewDomainTrackerTransform(config, logger.New(false), "test", 0, outChans) + + // init transforms + _, err := tracker.GetTransforms() + if err != nil { + t.Error("fail to init transform", err) + } + + // Send the first domain + dm := dnsutils.GetFakeDNSMessage() + if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep { + t.Errorf("This domain should be new!") + } + + // Send the same domain again (should return an error because cache is full) + result, _ := tracker.trackNewDomain(&dm) + if result != ReturnError { + t.Errorf("Cache full check failed, expected ReturnError") + } + + // Wait for TTL expiration + time.Sleep(4 * time.Second) + + // Retry the domain after TTL expiration (should be considered new again) + if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep { + t.Errorf("recheck, this domain should be new!!") + } +}