Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid race conditions in DNSSnooper's cached domains #2637

Merged
merged 4 commits into from
Jun 23, 2017
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions probe/endpoint/dns_snooper_linux_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"math"
"sync"
"time"

log "github.com/Sirupsen/logrus"
Expand All @@ -29,6 +30,14 @@ type DNSSnooper struct {
decodingErrorCounts map[string]uint64 // for limiting
}

type domainSet struct {
// Not worth using a RMutex since we don't expect a lot of contention
// and they are considerably larger (8 bytes vs 24 bytes), which would
// bloat the cache
sync.Mutex

This comment was marked as abuse.

This comment was marked as abuse.

This comment was marked as abuse.

This comment was marked as abuse.

This comment was marked as abuse.

domains map[string]struct{}
}

This comment was marked as abuse.


// NewDNSSnooper creates a new snooper of DNS queries
func NewDNSSnooper() (*DNSSnooper, error) {
pcapHandle, err := newPcapHandle()
Expand Down Expand Up @@ -102,9 +111,12 @@ func (s *DNSSnooper) CachedNamesForIP(ip string) []string {
return result
}

for domain := range domains.(map[string]struct{}) {
domainSet := domains.(domainSet)
domainSet.Lock()
for domain := range domainSet.domains {
result = append(result, domain)
}
domainSet.Unlock()

return result
}
Expand Down Expand Up @@ -269,10 +281,13 @@ func (s *DNSSnooper) processDNSMessage(dns *layers.DNS) {
log.Debugf("DNSSnooper: caught DNS lookup: %s -> %v", newDomain, ips)
for ip := range ips {
if existingDomains, err := s.reverseDNSCache.Get(ip); err != nil {
s.reverseDNSCache.Set(ip, map[string]struct{}{newDomain: {}})
s.reverseDNSCache.Set(ip, domainSet{domains: map[string]struct{}{newDomain: {}}})
} else {
// TODO: Be smarter about the expiration of entries with pre-existing associated domains
existingDomains.(map[string]struct{})[newDomain] = struct{}{}
domainSet := existingDomains.(domainSet)
domainSet.Lock()
domainSet.domains[newDomain] = struct{}{}
domainSet.Unlock()
}
}
}