Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dmachard committed Dec 6, 2024
1 parent 2b24b62 commit cafb880
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 31 deletions.
12 changes: 8 additions & 4 deletions docs/transformers/transform_newdomaintracker.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ The **New Domain Tracker** transformer identifies domains that are newly observe

- **Configurable Time Window**: Define how long a domain is considered new.
- **LRU-based Memory Management**: Ensures efficient memory usage with a finite cache size.

- **Persistence**: Optionally save the domain cache to disk for continuity after restarts.
- **Whitelist Support**: Exclude specific domains or patterns from detection.

Expand Down Expand Up @@ -35,6 +34,7 @@ transforms:
ttl: 3600
cache-size: 100000
white-domains-file: ""
persistence-file: ""
```
## Cache
Expand All @@ -50,8 +50,6 @@ Example of configuration to load a whitelist of domains to ignore.
```yaml
transforms:
new-domain-tracker:
ttl: 3600
cache-size: 100000
white-domains-file: /tmp/whitelist_domain.txt
```
Expand All @@ -64,5 +62,11 @@ github.com
## Persistence
To ensure continuity across application restarts, you can enable the persistence feature by specifying a file path (persistence). The transformer will save the domain cache to this file and reload it on startup.
To ensure continuity across application restarts, you can enable the persistence feature by setting the `persistence-file` option to a file path.
The transformer saves the domain cache to this file and reloads it on startup.
```yaml
transforms:
new-domain-tracker:
persistence-file: /tmp/nod-state.json
```
1 change: 1 addition & 0 deletions pkgconfig/transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ type ConfigTransformers struct {
TTL int `yaml:"ttl" default:"3600"`
CacheSize int `yaml:"cache-size" default:"100000"`
WhiteDomainsFile string `yaml:"white-domains-file" default:""`
PersistenceFile string `yaml:"persistence-file" default:""`
} `yaml:"new-domain-tracker"`
}

Expand Down
102 changes: 82 additions & 20 deletions transformers/newdomain.go → transformers/newdomaintracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package transformers

import (
"bufio"
"encoding/json"
"errors"
"fmt"
"os"
"regexp"
Expand All @@ -14,17 +16,16 @@ import (
lru "github.com/hashicorp/golang-lru"
)

// NewDomainTracker transformer to detect newly observed domains
type NewDomainTracker struct {
ttl time.Duration // Time window to consider a domain as "new"
cache *lru.Cache // LRU Cache to store observed domains
whitelist map[string]*regexp.Regexp // Whitelisted domains
logInfo func(msg string, v ...interface{})
logError func(msg string, v ...interface{})
ttl time.Duration // Time window to consider a domain as "new"
cache *lru.Cache // LRU Cache to store observed domains
whitelist map[string]*regexp.Regexp // Whitelisted domains
persistencePath string
logInfo func(msg string, v ...interface{})
logError func(msg string, v ...interface{})
}

// NewNewDomainTracker initializes the NewDomainTracker transformer
func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*regexp.Regexp, logInfo, logError func(msg string, v ...interface{})) (*NewDomainTracker, error) {
func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*regexp.Regexp, persistencePath string, logInfo, logError func(msg string, v ...interface{})) (*NewDomainTracker, error) {
cache, err := lru.New(maxSize)
if err != nil {
return nil, err
Expand All @@ -34,16 +35,24 @@ func NewNewDomainTracker(ttl time.Duration, maxSize int, whitelist map[string]*r
return nil, fmt.Errorf("invalid TTL value: %v", ttl)
}

return &NewDomainTracker{
ttl: ttl,
cache: cache,
whitelist: whitelist,
logInfo: logInfo,
logError: logError,
}, nil
tracker := &NewDomainTracker{
ttl: ttl,
cache: cache,
whitelist: whitelist,
persistencePath: persistencePath,
logInfo: logInfo,
logError: logError,
}
// Load cache state from disk if persistence is enabled
if persistencePath != "" {
if err := tracker.loadCacheFromDisk(); err != nil {
return nil, fmt.Errorf("failed to load cache state: %v", err)

Check failure on line 49 in transformers/newdomaintracker.go

View workflow job for this annotation

GitHub Actions / linter

non-wrapping format verb for fmt.Errorf. Use `%w` to format errors (errorlint)
}
}

return tracker, nil
}

// isWhitelisted checks if a domain or its subdomain is in the whitelist
func (ndt *NewDomainTracker) isWhitelisted(domain string) bool {
for _, d := range ndt.whitelist {
if d.MatchString(domain) {
Expand All @@ -53,7 +62,6 @@ func (ndt *NewDomainTracker) isWhitelisted(domain string) bool {
return false
}

// IsNewDomain checks if the domain is newly observed
func (ndt *NewDomainTracker) IsNewDomain(domain string) bool {
// Check if the domain is whitelisted
if ndt.isWhitelisted(domain) {
Expand All @@ -64,9 +72,7 @@ func (ndt *NewDomainTracker) IsNewDomain(domain string) bool {

// Check if the domain exists in the cache
if lastSeen, exists := ndt.cache.Get(domain); exists {
fmt.Println("exists")
if now.Sub(lastSeen.(time.Time)) < ndt.ttl {
fmt.Println(now.Sub(lastSeen.(time.Time)), ndt.ttl)
// Domain was recently seen, not new
return false
}
Expand All @@ -77,6 +83,48 @@ func (ndt *NewDomainTracker) IsNewDomain(domain string) bool {
return true
}

// SaveCacheToDisk writes every cached domain and its last-seen timestamp to
// the persistence file as a single JSON object (domain -> timestamp).
//
// It returns an error when no persistence path is configured, when a cache
// entry does not hold the expected string key / time.Time value, or when
// encoding or writing the file fails.
func (ndt *NewDomainTracker) SaveCacheToDisk() error {
	// Same guard as loadCacheFromDisk: without a path there is nowhere to write.
	if ndt.persistencePath == "" {
		return errors.New("persistence filepath not set")
	}

	keys := ndt.cache.Keys()
	state := make(map[string]time.Time, len(keys))
	for _, key := range keys {
		// Peek is used instead of Get so that saving does not reorder the LRU.
		value, ok := ndt.cache.Peek(key)
		if !ok {
			// Entry was evicted between Keys() and Peek(); nothing to save.
			continue
		}

		// Checked assertions: a foreign entry must not panic the shutdown path.
		domain, isString := key.(string)
		lastSeen, isTime := value.(time.Time)
		if !isString || !isTime {
			return fmt.Errorf("unexpected cache entry type: key %T, value %T", key, value)
		}
		state[domain] = lastSeen
	}

	data, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("encoding cache state: %w", err)
	}

	if err := os.WriteFile(ndt.persistencePath, data, 0644); err != nil {
		return fmt.Errorf("writing cache state to %q: %w", ndt.persistencePath, err)
	}
	return nil
}

// loadCacheFromDisk restores the cache from the JSON persistence file.
//
// A missing or empty file is not an error: it means there is no previous
// state to restore. It returns an error when no persistence path is
// configured, when the file cannot be read, or when its content is not a
// valid JSON object of domain -> timestamp.
func (ndt *NewDomainTracker) loadCacheFromDisk() error {
	if ndt.persistencePath == "" {
		return errors.New("persistence filepath not set")
	}

	data, err := os.ReadFile(ndt.persistencePath)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil // File does not exist, no previous state to load
		}
		return fmt.Errorf("reading cache state from %q: %w", ndt.persistencePath, err)
	}

	// A zero-byte file (e.g. created ahead of time) holds no state; decoding
	// it would fail and prevent the tracker from starting.
	if len(data) == 0 {
		return nil
	}

	state := make(map[string]time.Time)
	if err := json.Unmarshal(data, &state); err != nil {
		return fmt.Errorf("decoding cache state from %q: %w", ndt.persistencePath, err)
	}

	for domain, lastSeen := range state {
		ndt.cache.Add(domain, lastSeen)
	}

	return nil
}

// NewDomainTransform is the Transformer for DNS messages
type NewDomainTrackerTransform struct {
GenericTransformer
Expand Down Expand Up @@ -110,7 +158,7 @@ func (t *NewDomainTrackerTransform) GetTransforms() ([]Subtransform, error) {
// Initialize the domain tracker
ttl := time.Duration(t.config.NewDomainTracker.TTL) * time.Second
maxSize := t.config.NewDomainTracker.CacheSize
tracker, err := NewNewDomainTracker(ttl, maxSize, t.listDomainsRegex, t.LogInfo, t.LogError)
tracker, err := NewNewDomainTracker(ttl, maxSize, t.listDomainsRegex, t.config.NewDomainTracker.PersistenceFile, t.LogInfo, t.LogError)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -146,9 +194,23 @@ func (t *NewDomainTrackerTransform) LoadWhiteDomainsList() error {

// trackNewDomain decides whether a DNS message is kept or dropped based on
// its query name.
//
// When the cache already holds CacheSize entries it returns ReturnError with
// an error and does not look at the domain. Otherwise it returns ReturnKeep
// for a newly observed domain and ReturnDrop for any other.
func (t *NewDomainTrackerTransform) trackNewDomain(dm *dnsutils.DNSMessage) (int, error) {
	tracker := t.domainTracker

	// Capacity check happens before the lookup, so a full cache short-circuits.
	if capacity := t.config.NewDomainTracker.CacheSize; tracker.cache.Len() == capacity {
		return ReturnError, fmt.Errorf("LRU cache is full. Consider increasing cache-size to avoid frequent evictions")
	}

	verdict := ReturnDrop
	if tracker.IsNewDomain(dm.DNS.Qname) {
		verdict = ReturnKeep
	}
	return verdict, nil
}

// Reset persists the domain cache to disk when a persistence file is
// configured. A failure is logged and the success message is skipped, so the
// log never reports a save that did not happen.
func (t *NewDomainTrackerTransform) Reset() {
	if len(t.domainTracker.persistencePath) == 0 {
		return
	}

	if err := t.domainTracker.SaveCacheToDisk(); err != nil {
		t.LogError("failed to save cache state: %v", err)
		return
	}
	t.LogInfo("cache content saved on disk with success")
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,40 @@ func TestNewDomainTracker_Whitelist(t *testing.T) {
t.Errorf("2. this domain should be new!!")
}
}

// func TestNewDomainTracker_LRUCacheFull(t *testing.T) {
// // config
// config := pkgconfig.GetFakeConfigTransformers()
// config.NewDomainTracker.Enable = true
// config.NewDomainTracker.TTL = 2
// config.NewDomainTracker.CacheSize = 1

// outChans := []chan dnsutils.DNSMessage{}

// // init subproccesor
// tracker := NewNewDomainTrackerTransform(config, logger.New(false), "test", 0, outChans)

// // init transforms
// _, err := tracker.GetTransforms()
// if err != nil {
// t.Error("fail to init transform", err)
// }

// // first send
// dm := dnsutils.GetFakeDNSMessage()
// if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep {
// t.Errorf("This domain should be new!")
// }

// if result, _ := tracker.trackNewDomain(&dm); result != ReturnError {
// t.Errorf("LRU Full")
// }

// // wait ttl for expiration
// time.Sleep(3 * time.Second)

// // recheck
// if result, _ := tracker.trackNewDomain(&dm); result != ReturnKeep {
// t.Errorf("recheck, this domain should be new!!")
// }
// }
21 changes: 14 additions & 7 deletions transformers/transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
)

var (
ReturnKeep = 1
ReturnDrop = 2
ReturnError = 0
ReturnKeep = 1
ReturnDrop = 2
)

type Subtransform struct {
Expand Down Expand Up @@ -63,8 +64,9 @@ type Transforms struct {
name string
instance int

availableTransforms []TransformEntry
activeTransforms []func(dm *dnsutils.DNSMessage) (int, error)
availableTransforms []TransformEntry
activeTransforms []TransformEntry
activeProcessTransforms []func(dm *dnsutils.DNSMessage) (int, error)
}

func NewTransforms(config *pkgconfig.ConfigTransformers, logger *logger.Logger, name string, nextWorkers []chan dnsutils.DNSMessage, instance int) Transforms {
Expand Down Expand Up @@ -102,6 +104,7 @@ func (p *Transforms) ReloadConfig(config *pkgconfig.ConfigTransformers) {

func (p *Transforms) Prepare() error {
// clean the slice
p.activeProcessTransforms = p.activeProcessTransforms[:0]
p.activeTransforms = p.activeTransforms[:0]
tranformsList := []string{}

Expand All @@ -111,8 +114,12 @@ func (p *Transforms) Prepare() error {
p.LogError("error on init subtransforms: %v", err)
continue
}
if len(subtransforms) > 0 {
p.activeTransforms = append(p.activeTransforms, transform)
}
for _, subtransform := range subtransforms {
p.activeTransforms = append(p.activeTransforms, subtransform.processFunc)
p.activeProcessTransforms = append(p.activeProcessTransforms, subtransform.processFunc)

tranformsList = append(tranformsList, subtransform.name)
}
}
Expand All @@ -124,7 +131,7 @@ func (p *Transforms) Prepare() error {
}

func (p *Transforms) Reset() {
for _, transform := range p.availableTransforms {
for _, transform := range p.activeTransforms {
transform.Reset()
}
}
Expand All @@ -139,7 +146,7 @@ func (p *Transforms) LogError(msg string, v ...interface{}) {
}

func (p *Transforms) ProcessMessage(dm *dnsutils.DNSMessage) (int, error) {
for _, transform := range p.activeTransforms {
for _, transform := range p.activeProcessTransforms {
if result, err := transform(dm); err != nil {
return ReturnKeep, fmt.Errorf("error on transform processing: %v", err.Error())
} else if result == ReturnDrop {
Expand Down
1 change: 1 addition & 0 deletions workers/dnsmessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ func (w *DNSMessage) StartCollect() {
for {
select {
case <-w.OnStop():
subprocessors.Reset()
return

// save the new config
Expand Down

0 comments on commit cafb880

Please sign in to comment.