Skip to content

Commit

Permalink
use expirable lru cache
Browse files Browse the repository at this point in the history
  • Loading branch information
dmachard committed Dec 6, 2024
1 parent cafb880 commit 99fd81a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 66 deletions.
48 changes: 18 additions & 30 deletions transformers/newdomaintracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,26 @@ import (
"github.com/dmachard/go-dnscollector/dnsutils"
"github.com/dmachard/go-dnscollector/pkgconfig"
"github.com/dmachard/go-logger"
lru "github.com/hashicorp/golang-lru"
"github.com/hashicorp/golang-lru/v2/expirable"
)

// NewDomainTracker holds the state used to decide whether a domain has been
// observed within the configured time window.
// NOTE(review): this listing is a rendered diff, so ttl, cache and whitelist
// appear twice below: first in the previous lru.Cache form, then in the new
// expirable.LRU form.
type NewDomainTracker struct {
ttl time.Duration // Time window to consider a domain as "new"
cache *lru.Cache // LRU Cache to store observed domains
whitelist map[string]*regexp.Regexp // Whitelisted domains
ttl time.Duration // Time window to consider a domain as "new"
cache *expirable.LRU[string, struct{}] // Expirable LRU Cache
whitelist map[string]*regexp.Regexp // Whitelisted domains
persistencePath string // Path used for cache persistence; the constructor only loads state from disk when this is non-empty
logInfo func(msg string, v ...interface{}) // Presumably the informational logging callback; confirm against callers
logError func(msg string, v ...interface{}) // Presumably the error logging callback; confirm against callers
}

func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*regexp.Regexp, persistencePath string, logInfo, logError func(msg string, v ...interface{})) (*NewDomainTracker, error) {
cache, err := lru.New(maxSize)
if err != nil {
return nil, err
}

if ttl <= 0 {
return nil, fmt.Errorf("invalid TTL value: %v", ttl)
}

cache := expirable.NewLRU[string, struct{}](maxSize, nil, ttl)

tracker := &NewDomainTracker{
ttl: ttl,
cache: cache,
Expand All @@ -46,7 +44,7 @@ func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*r
// Load cache state from disk if persistence is enabled
if persistencePath != "" {
if err := tracker.loadCacheFromDisk(); err != nil {
return nil, fmt.Errorf("failed to load cache state: %v", err)
return nil, fmt.Errorf("failed to load cache state: %w", err)
}
}

Expand All @@ -68,30 +66,20 @@ func (ndt *NewDomainTracker) IsNewDomain(domain string) bool {
return false
}

now := time.Now()

// Check if the domain exists in the cache
if lastSeen, exists := ndt.cache.Get(domain); exists {
if now.Sub(lastSeen.(time.Time)) < ndt.ttl {
// Domain was recently seen, not new
return false
}
if _, exists := ndt.cache.Get(domain); exists {
// Domain was recently seen, not new
return false
}

// Otherwise, mark the domain as new and update the cache
ndt.cache.Add(domain, now)
// Otherwise, mark the domain as new
ndt.cache.Add(domain, struct{}{})
return true
}

func (ndt *NewDomainTracker) SaveCacheToDisk() error {
state := make(map[string]time.Time)
for _, key := range ndt.cache.Keys() {
if value, ok := ndt.cache.Peek(key); ok {
state[key.(string)] = value.(time.Time)
}
}

data, err := json.Marshal(state)
keys := ndt.cache.Keys()
data, err := json.Marshal(keys)
if err != nil {
return err
}
Expand All @@ -113,13 +101,13 @@ func (ndt *NewDomainTracker) loadCacheFromDisk() error {
return err
}

state := make(map[string]time.Time)
if err := json.Unmarshal(data, &state); err != nil {
var keys []string
if err := json.Unmarshal(data, &keys); err != nil {
return err
}

for domain, lastSeen := range state {
ndt.cache.Add(domain, lastSeen)
for _, key := range keys {
ndt.cache.Add(key, struct{}{})
}

return nil
Expand Down
74 changes: 38 additions & 36 deletions transformers/newdomaintracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,39 +75,41 @@ func TestNewDomainTracker_Whitelist(t *testing.T) {
}
}

// func TestNewDomainTracker_LRUCacheFull(t *testing.T) {
// // config
// config := pkgconfig.GetFakeConfigTransformers()
// config.NewDomainTracker.Enable = true
// config.NewDomainTracker.TTL = 2
// config.NewDomainTracker.CacheSize = 1

// outChans := []chan dnsutils.DNSMessage{}

// // init subprocessor
// tracker := NewNewDomainTrackerTransform(config, logger.New(false), "test", 0, outChans)

// // init transforms
// _, err := tracker.GetTransforms()
// if err != nil {
// t.Error("fail to init transform", err)
// }

// // first send
// dm := dnsutils.GetFakeDNSMessage()
// if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep {
// t.Errorf("This domain should be new!")
// }

// if result, _ := tracker.trackNewDomain(&dm); result != ReturnError {
// t.Errorf("LRU Full")
// }

// // wait ttl for expiration
// time.Sleep(3 * time.Second)

// // recheck
// if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep {
// t.Errorf("recheck, this domain should be new!!")
// }
// }
// TestNewDomainTracker_LRUCacheFull drives the new-domain tracker transform
// with a cache size of 1 and a TTL of 2 (presumably seconds, given the
// 4-second wait below; confirm against the config definition).
//
// It checks three things in order:
//  1. the first sighting of the fake message's domain yields ReturnKeep;
//  2. an immediate second sighting of the same domain yields ReturnError;
//  3. after waiting longer than the TTL, the domain yields ReturnKeep again.
func TestNewDomainTracker_LRUCacheFull(t *testing.T) {
	// config
	config := pkgconfig.GetFakeConfigTransformers()
	config.NewDomainTracker.Enable = true
	config.NewDomainTracker.TTL = 2
	config.NewDomainTracker.CacheSize = 1

	outChans := []chan dnsutils.DNSMessage{}

	// init subprocessor
	tracker := NewNewDomainTrackerTransform(config, logger.New(false), "test", 0, outChans)

	// init transforms; abort immediately on failure because every later step
	// relies on the transform being initialised.
	if _, err := tracker.GetTransforms(); err != nil {
		t.Fatalf("fail to init transform: %v", err)
	}

	// Send the first domain
	dm := dnsutils.GetFakeDNSMessage()
	if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep {
		t.Errorf("This domain should be new!")
	}

	// Send the same domain again while it is still cached; the transform is
	// expected to report ReturnError for it.
	if result, _ := tracker.trackNewDomain(&dm); result != ReturnError {
		t.Errorf("Cache full check failed, expected ReturnError")
	}

	// Wait for TTL expiration
	time.Sleep(4 * time.Second)

	// Retry the domain after TTL expiration (should be considered new again)
	if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep {
		t.Errorf("recheck, this domain should be new!!")
	}
}

0 comments on commit 99fd81a

Please sign in to comment.