From cafb880a5add7dcf8e99a71170aada7b0ccb5e43 Mon Sep 17 00:00:00 2001 From: dmachard <5562930+dmachard@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:13:42 +0100 Subject: [PATCH] fix --- .../transform_newdomaintracker.md | 12 ++- pkgconfig/transformers.go | 1 + .../{newdomain.go => newdomaintracker.go} | 102 ++++++++++++++---- ...omain_test.go => newdomaintracker_test.go} | 37 +++++++ transformers/transformers.go | 21 ++-- workers/dnsmessage.go | 1 + 6 files changed, 143 insertions(+), 31 deletions(-) rename transformers/{newdomain.go => newdomaintracker.go} (60%) rename transformers/{newdomain_test.go => newdomaintracker_test.go} (67%) diff --git a/docs/transformers/transform_newdomaintracker.md b/docs/transformers/transform_newdomaintracker.md index 93be1bca..b31ffdf2 100644 --- a/docs/transformers/transform_newdomaintracker.md +++ b/docs/transformers/transform_newdomaintracker.md @@ -6,7 +6,6 @@ The **New Domain Tracker** transformer identifies domains that are newly observe - **Configurable Time Window**: Define how long a domain is considered new. - **LRU-based Memory Management**: Ensures efficient memory usage with a finite cache size. - - **Persistence**: Optionally save the domain cache to disk for continuity after restarts. - **Whitelist Support**: Exclude specific domains or patterns from detection. @@ -35,6 +34,7 @@ transforms: ttl: 3600 cache-size: 100000 white-domains-file: "" + persistence-file: "" ``` ## Cache @@ -50,8 +50,6 @@ Example of configuration to load a whitelist of domains to ignore. ```yaml transforms: new-domain-tracker: - ttl: 3600 - cache-size: 100000 white-domains-file: /tmp/whitelist_domain.txt ``` @@ -64,5 +62,11 @@ github.com ## Persistence -To ensure continuity across application restarts, you can enable the persistence feature by specifying a file path (persistence). The transformer will save the domain cache to this file and reload it on startup. +To ensure continuity across application restarts, you can enable the persistence feature by specifying a file path (persistence). +The transformer will save the domain cache to this file and reload it on startup. +```yaml +transforms: + new-domain-tracker: + persistence-file: /tmp/nod-state.json +``` \ No newline at end of file diff --git a/pkgconfig/transformers.go b/pkgconfig/transformers.go index 593f63ce..a6c83dc3 100644 --- a/pkgconfig/transformers.go +++ b/pkgconfig/transformers.go @@ -100,6 +100,7 @@ type ConfigTransformers struct { TTL int `yaml:"ttl" default:"3600"` CacheSize int `yaml:"cache-size" default:"100000"` WhiteDomainsFile string `yaml:"white-domains-file" default:""` + PersistenceFile string `yaml:"persistence-file" default:""` } `yaml:"new-domain-tracker"` } diff --git a/transformers/newdomain.go b/transformers/newdomaintracker.go similarity index 60% rename from transformers/newdomain.go rename to transformers/newdomaintracker.go index 51f17159..6c4d9d03 100644 --- a/transformers/newdomain.go +++ b/transformers/newdomaintracker.go @@ -2,6 +2,8 @@ package transformers import ( "bufio" + "encoding/json" + "errors" "fmt" "os" "regexp" @@ -14,17 +16,16 @@ import ( lru "github.com/hashicorp/golang-lru" ) -// NewDomainTracker transformer to detect newly observed domains type NewDomainTracker struct { - ttl time.Duration // Time window to consider a domain as "new" - cache *lru.Cache // LRU Cache to store observed domains - whitelist map[string]*regexp.Regexp // Whitelisted domains - logInfo func(msg string, v ...interface{}) - logError func(msg string, v ...interface{}) + ttl time.Duration // Time window to consider a domain as "new" + cache *lru.Cache // LRU Cache to store observed domains + whitelist map[string]*regexp.Regexp // Whitelisted domains + persistencePath string + logInfo func(msg string, v ...interface{}) + logError func(msg string, v ...interface{}) } -// NewNewDomainTracker initializes the NewDomainTracker transformer -func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*regexp.Regexp, logInfo, logError func(msg string, v ...interface{})) (*NewDomainTracker, error) { +func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*regexp.Regexp, persistencePath string, logInfo, logError func(msg string, v ...interface{})) (*NewDomainTracker, error) { cache, err := lru.New(maxSize) if err != nil { return nil, err @@ -34,16 +35,24 @@ func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*r return nil, fmt.Errorf("invalid TTL value: %v", ttl) } - return &NewDomainTracker{ - ttl: ttl, - cache: cache, - whitelist: whitelist, - logInfo: logInfo, - logError: logError, - }, nil + tracker := &NewDomainTracker{ + ttl: ttl, + cache: cache, + whitelist: whitelist, + persistencePath: persistencePath, + logInfo: logInfo, + logError: logError, + } + // Load cache state from disk if persistence is enabled + if persistencePath != "" { + if err := tracker.loadCacheFromDisk(); err != nil { + return nil, fmt.Errorf("failed to load cache state: %v", err) + } + } + + return tracker, nil } -// isWhitelisted checks if a domain or its subdomain is in the whitelist func (ndt *NewDomainTracker) isWhitelisted(domain string) bool { for _, d := range ndt.whitelist { if d.MatchString(domain) { @@ -53,7 +62,6 @@ func (ndt *NewDomainTracker) isWhitelisted(domain string) bool { return false } -// IsNewDomain checks if the domain is newly observed func (ndt *NewDomainTracker) IsNewDomain(domain string) bool { // Check if the domain is whitelisted if ndt.isWhitelisted(domain) { @@ -64,9 +72,7 @@ func (ndt *NewDomainTracker) IsNewDomain(domain string) bool { // Check if the domain exists in the cache if lastSeen, exists := ndt.cache.Get(domain); exists { - fmt.Println("exists") if now.Sub(lastSeen.(time.Time)) < ndt.ttl { - fmt.Println(now.Sub(lastSeen.(time.Time)), ndt.ttl) // Domain was recently seen, not new return false } @@ -77,6 +83,48 @@ func (ndt *NewDomainTracker) IsNewDomain(domain string) bool { return true } +func (ndt *NewDomainTracker) SaveCacheToDisk() error { + state := make(map[string]time.Time) + for _, key := range ndt.cache.Keys() { + if value, ok := ndt.cache.Peek(key); ok { + state[key.(string)] = value.(time.Time) + } + } + + data, err := json.Marshal(state) + if err != nil { + return err + } + + return os.WriteFile(ndt.persistencePath, data, 0644) +} + +// loadCacheFromDisk loads the cache state from a file +func (ndt *NewDomainTracker) loadCacheFromDisk() error { + if ndt.persistencePath == "" { + return errors.New("persistence filepath not set") + } + + data, err := os.ReadFile(ndt.persistencePath) + if err != nil { + if os.IsNotExist(err) { + return nil // File does not exist, no previous state to load + } + return err + } + + state := make(map[string]time.Time) + if err := json.Unmarshal(data, &state); err != nil { + return err + } + + for domain, lastSeen := range state { + ndt.cache.Add(domain, lastSeen) + } + + return nil +} + // NewDomainTransform is the Transformer for DNS messages type NewDomainTrackerTransform struct { GenericTransformer @@ -110,7 +158,7 @@ func (t *NewDomainTrackerTransform) GetTransforms() ([]Subtransform, error) { // Initialize the domain tracker ttl := time.Duration(t.config.NewDomainTracker.TTL) * time.Second maxSize := t.config.NewDomainTracker.CacheSize - tracker, err := NewNewDomainTracker(ttl, maxSize, t.listDomainsRegex, t.LogInfo, t.LogError) + tracker, err := NewNewDomainTracker(ttl, maxSize, t.listDomainsRegex, t.config.NewDomainTracker.PersistenceFile, t.LogInfo, t.LogError) if err != nil { return nil, err } @@ -146,9 +194,23 @@ func (t *NewDomainTrackerTransform) LoadWhiteDomainsList() error { // Process processes DNS messages and detects newly observed domains func (t *NewDomainTrackerTransform) trackNewDomain(dm *dnsutils.DNSMessage) (int, error) { + // Log a warning if the cache is full (before adding the new domain) + if t.domainTracker.cache.Len() == t.config.NewDomainTracker.CacheSize { + return ReturnError, fmt.Errorf("LRU cache is full. Consider increasing cache-size to avoid frequent evictions") + } + // Check if the domain is newly observed if t.domainTracker.IsNewDomain(dm.DNS.Qname) { return ReturnKeep, nil } return ReturnDrop, nil } + +func (t *NewDomainTrackerTransform) Reset() { + if len(t.domainTracker.persistencePath) != 0 { + if err := t.domainTracker.SaveCacheToDisk(); err != nil { + t.LogError("failed to save cache state: %v", err) + } + t.LogInfo("cache content saved on disk with success") + } +} diff --git a/transformers/newdomain_test.go b/transformers/newdomaintracker_test.go similarity index 67% rename from transformers/newdomain_test.go rename to transformers/newdomaintracker_test.go index fddf67e3..5505b005 100644 --- a/transformers/newdomain_test.go +++ b/transformers/newdomaintracker_test.go @@ -74,3 +74,40 @@ func TestNewDomainTracker_Whitelist(t *testing.T) { t.Errorf("2. this domain should be new!!") } } + +// func TestNewDomainTracker_LRUCacheFull(t *testing.T) { +// // config +// config := pkgconfig.GetFakeConfigTransformers() +// config.NewDomainTracker.Enable = true +// config.NewDomainTracker.TTL = 2 +// config.NewDomainTracker.CacheSize = 1 + +// outChans := []chan dnsutils.DNSMessage{} + +// // init subproccesor +// tracker := NewNewDomainTrackerTransform(config, logger.New(false), "test", 0, outChans) + +// // init transforms +// _, err := tracker.GetTransforms() +// if err != nil { +// t.Error("fail to init transform", err) +// } + +// // first send +// dm := dnsutils.GetFakeDNSMessage() +// if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep { +// t.Errorf("This domain should be new!") +// } + +// if result, _ := tracker.trackNewDomain(&dm); result != ReturnError { +// t.Errorf("LRU Full") +// } + +// // wait ttl for expiration +// time.Sleep(3 * time.Second) + +// // recheck +// if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep { +// t.Errorf("recheck, this domain should be new!!") +// } +// } diff --git a/transformers/transformers.go b/transformers/transformers.go index ef422cf5..808bcc67 100644 --- a/transformers/transformers.go +++ b/transformers/transformers.go @@ -9,8 +9,9 @@ import ( ) var ( - ReturnKeep = 1 - ReturnDrop = 2 + ReturnError = 0 + ReturnKeep = 1 + ReturnDrop = 2 ) type Subtransform struct { @@ -63,8 +64,9 @@ type Transforms struct { name string instance int - availableTransforms []TransformEntry - activeTransforms []func(dm *dnsutils.DNSMessage) (int, error) + availableTransforms []TransformEntry + activeTransforms []TransformEntry + activeProcessTransforms []func(dm *dnsutils.DNSMessage) (int, error) } func NewTransforms(config *pkgconfig.ConfigTransformers, logger *logger.Logger, name string, nextWorkers []chan dnsutils.DNSMessage, instance int) Transforms { @@ -102,6 +104,7 @@ func (p *Transforms) ReloadConfig(config *pkgconfig.ConfigTransformers) { func (p *Transforms) Prepare() error { // clean the slice + p.activeProcessTransforms = p.activeProcessTransforms[:0] p.activeTransforms = p.activeTransforms[:0] tranformsList := []string{} @@ -111,8 +114,12 @@ func (p *Transforms) Prepare() error { p.LogError("error on init subtransforms: %v", err) continue } + if len(subtransforms) > 0 { + p.activeTransforms = append(p.activeTransforms, transform) + } for _, subtransform := range subtransforms { - p.activeTransforms = append(p.activeTransforms, subtransform.processFunc) + p.activeProcessTransforms = append(p.activeProcessTransforms, subtransform.processFunc) + tranformsList = append(tranformsList, subtransform.name) } } @@ -124,7 +131,7 @@ func (p *Transforms) Prepare() error { } func (p *Transforms) Reset() { - for _, transform := range p.availableTransforms { + for _, transform := range p.activeTransforms { transform.Reset() } } @@ -139,7 +146,7 @@ func (p *Transforms) LogError(msg string, v ...interface{}) { } func (p *Transforms) ProcessMessage(dm *dnsutils.DNSMessage) (int, error) { - for _, transform := range p.activeTransforms { + for _, transform := range p.activeProcessTransforms { if result, err := transform(dm); err != nil { return ReturnKeep, fmt.Errorf("error on transform processing: %v", err.Error()) } else if result == ReturnDrop { diff --git a/workers/dnsmessage.go b/workers/dnsmessage.go index d338e052..2641573f 100644 --- a/workers/dnsmessage.go +++ b/workers/dnsmessage.go @@ -182,6 +182,7 @@ func (w *DNSMessage) StartCollect() { for { select { case <-w.OnStop(): + subprocessors.Reset() return // save the new config