diff --git a/cache/stringcache/chained_grouped_cache.go b/cache/stringcache/chained_grouped_cache.go index d2deb27a1..f3f651a64 100644 --- a/cache/stringcache/chained_grouped_cache.go +++ b/cache/stringcache/chained_grouped_cache.go @@ -1,11 +1,5 @@ package stringcache -import ( - "sort" - - "golang.org/x/exp/maps" -) - type ChainedGroupedCache struct { caches []GroupedStringCache } @@ -25,20 +19,16 @@ func (c *ChainedGroupedCache) ElementCount(group string) int { return sum } -func (c *ChainedGroupedCache) Contains(searchString string, groups []string) []string { - groupMatchedMap := make(map[string]struct{}, len(groups)) +func (c *ChainedGroupedCache) Contains(searchString string, groups []string) map[string]string { + result := make(map[string]string, len(groups)) for _, cache := range c.caches { - for _, group := range cache.Contains(searchString, groups) { - groupMatchedMap[group] = struct{}{} + for group, rule := range cache.Contains(searchString, groups) { + result[group] = rule } } - matchedGroups := maps.Keys(groupMatchedMap) - - sort.Strings(matchedGroups) - - return matchedGroups + return result } func (c *ChainedGroupedCache) Refresh(group string) GroupFactory { diff --git a/cache/stringcache/chained_grouped_cache_test.go b/cache/stringcache/chained_grouped_cache_test.go index 8fe2432ae..db7c7a0c4 100644 --- a/cache/stringcache/chained_grouped_cache_test.go +++ b/cache/stringcache/chained_grouped_cache_test.go @@ -4,6 +4,7 @@ import ( "github.com/0xERR0R/blocky/cache/stringcache" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "golang.org/x/exp/maps" ) var _ = Describe("Chained grouped cache", func() { @@ -58,8 +59,8 @@ var _ = Describe("Chained grouped cache", func() { It("should find strings", func() { factory.Finish() - Expect(cache.Contains("string1", []string{"group1"})).Should(ConsistOf("group1")) - Expect(cache.Contains("string2", []string{"group1", "someOtherGroup"})).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("string1", []string{"group1"}))).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("string2", []string{"group1", "someOtherGroup"}))).Should(ConsistOf("group1")) }) }) }) @@ -86,9 +87,9 @@ var _ = Describe("Chained grouped cache", func() { It("should contain 4 elements in 2 groups", func() { Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 2)) Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 2)) - Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(ConsistOf("group1")) - Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2")) - Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group1", "group2")) + Expect(maps.Keys(cache.Contains("g1", []string{"group1", "group2"}))).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("g2", []string{"group1", "group2"}))).Should(ConsistOf("group2")) + Expect(maps.Keys(cache.Contains("both", []string{"group1", "group2"}))).Should(ConsistOf("group1", "group2")) }) It("should replace group content on refresh", func() { @@ -98,10 +99,10 @@ var _ = Describe("Chained grouped cache", func() { Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1)) Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 2)) - Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(BeEmpty()) - Expect(cache.Contains("newString", []string{"group1", "group2"})).Should(ConsistOf("group1")) - Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2")) - Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group2")) + Expect(maps.Keys(cache.Contains("g1", []string{"group1", "group2"}))).Should(BeEmpty()) + Expect(maps.Keys(cache.Contains("newString", []string{"group1", "group2"}))).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("g2", []string{"group1", "group2"}))).Should(ConsistOf("group2")) + Expect(maps.Keys(cache.Contains("both", []string{"group1", "group2"}))).Should(ConsistOf("group2")) }) It("should replace empty groups on refresh", func() { diff --git a/cache/stringcache/grouped_cache_interface.go b/cache/stringcache/grouped_cache_interface.go index 53dec3c05..5e2962760 100644 --- a/cache/stringcache/grouped_cache_interface.go +++ b/cache/stringcache/grouped_cache_interface.go @@ -2,8 +2,8 @@ package stringcache type GroupedStringCache interface { // Contains checks if one or more groups in the cache contains the search string. - // Returns group(s) containing the string or empty slice if string was not found - Contains(searchString string, groups []string) []string + // Returns group(s) containing the string as keys, and the rule that caused the block as value, empty map if string was not found + Contains(searchString string, groups []string) map[string]string // Refresh creates new factory for the group to be refreshed. // Calling Finish on the factory will perform the group refresh. diff --git a/cache/stringcache/in_memory_grouped_cache.go b/cache/stringcache/in_memory_grouped_cache.go index db0e01272..5c052cd87 100644 --- a/cache/stringcache/in_memory_grouped_cache.go +++ b/cache/stringcache/in_memory_grouped_cache.go @@ -43,16 +43,19 @@ func (c *InMemoryGroupedCache) ElementCount(group string) int { return cache.elementCount() } -func (c *InMemoryGroupedCache) Contains(searchString string, groups []string) []string { - var result []string +func (c *InMemoryGroupedCache) Contains(searchString string, groups []string) map[string]string { + result := make(map[string]string, len(groups)) for _, group := range groups { c.lock.RLock() cache, found := c.caches[group] c.lock.RUnlock() - if found && cache.contains(searchString) { - result = append(result, group) + if found { + match, rule := cache.contains(searchString) + if match { + result[group] = rule + } } } diff --git a/cache/stringcache/in_memory_grouped_cache_test.go b/cache/stringcache/in_memory_grouped_cache_test.go index 938d7b8a4..a58563707 100644 --- a/cache/stringcache/in_memory_grouped_cache_test.go +++ b/cache/stringcache/in_memory_grouped_cache_test.go @@ -4,6 +4,7 @@ import ( "github.com/0xERR0R/blocky/cache/stringcache" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "golang.org/x/exp/maps" ) var _ = Describe("In-Memory grouped cache", func() { @@ -65,8 +66,8 @@ var _ = Describe("In-Memory grouped cache", func() { It("should find strings", func() { factory.Finish() - Expect(cache.Contains("string1", []string{"group1"})).Should(ConsistOf("group1")) - Expect(cache.Contains("string2", []string{"group1", "someOtherGroup"})).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("string1", []string{"group1"}))).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("string2", []string{"group1", "someOtherGroup"}))).Should(ConsistOf("group1")) }) }) When("Regex grouped cache is used", func() { @@ -81,9 +82,9 @@ var _ = Describe("In-Memory grouped cache", func() { It("should ignore non-regex", func() { Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1)) - Expect(cache.Contains("string1", []string{"group1"})).Should(BeEmpty()) - Expect(cache.Contains("string2", []string{"group1"})).Should(ConsistOf("group1")) - Expect(cache.Contains("shouldalsomatchstring2", []string{"group1"})).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("string1", []string{"group1"}))).Should(BeEmpty()) + Expect(maps.Keys(cache.Contains("string2", []string{"group1"}))).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("shouldalsomatchstring2", []string{"group1"}))).Should(ConsistOf("group1")) }) }) When("Wildcard grouped cache is used", func() { @@ -99,10 +100,10 @@ var _ = Describe("In-Memory grouped cache", func() { It("should ignore non-wildcard", func() { Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1)) - Expect(cache.Contains("string1", []string{"group1"})).Should(BeEmpty()) - Expect(cache.Contains("string2", []string{"group1"})).Should(BeEmpty()) - Expect(cache.Contains("string3", []string{"group1"})).Should(ConsistOf("group1")) - Expect(cache.Contains("shouldalsomatch.string3", []string{"group1"})).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("string1", []string{"group1"}))).Should(BeEmpty()) + Expect(maps.Keys(cache.Contains("string2", []string{"group1"}))).Should(BeEmpty()) + Expect(maps.Keys(cache.Contains("string3", []string{"group1"}))).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("shouldalsomatch.string3", []string{"group1"}))).Should(ConsistOf("group1")) }) }) }) @@ -126,9 +127,9 @@ var _ = Describe("In-Memory grouped cache", func() { It("should contain 4 elements in 2 groups", func() { Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 2)) Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 2)) - Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(ConsistOf("group1")) - Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2")) - Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group1", "group2")) + Expect(maps.Keys(cache.Contains("g1", []string{"group1", "group2"}))).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("g2", []string{"group1", "group2"}))).Should(ConsistOf("group2")) + Expect(maps.Keys(cache.Contains("both", []string{"group1", "group2"}))).Should(ConsistOf("group1", "group2")) }) It("Should replace group content on refresh", func() { @@ -138,10 +139,10 @@ var _ = Describe("In-Memory grouped cache", func() { Expect(cache.ElementCount("group1")).Should(BeNumerically("==", 1)) Expect(cache.ElementCount("group2")).Should(BeNumerically("==", 2)) - Expect(cache.Contains("g1", []string{"group1", "group2"})).Should(BeEmpty()) - Expect(cache.Contains("newString", []string{"group1", "group2"})).Should(ConsistOf("group1")) - Expect(cache.Contains("g2", []string{"group1", "group2"})).Should(ConsistOf("group2")) - Expect(cache.Contains("both", []string{"group1", "group2"})).Should(ConsistOf("group2")) + Expect(maps.Keys(cache.Contains("g1", []string{"group1", "group2"}))).Should(BeEmpty()) + Expect(maps.Keys(cache.Contains("newString", []string{"group1", "group2"}))).Should(ConsistOf("group1")) + Expect(maps.Keys(cache.Contains("g2", []string{"group1", "group2"}))).Should(ConsistOf("group2")) + Expect(maps.Keys(cache.Contains("both", []string{"group1", "group2"}))).Should(ConsistOf("group2")) }) }) }) diff --git a/cache/stringcache/string_caches.go b/cache/stringcache/string_caches.go index 001118cdf..c72f348c8 100644 --- a/cache/stringcache/string_caches.go +++ b/cache/stringcache/string_caches.go @@ -11,7 +11,7 @@ import ( type stringCache interface { elementCount() int - contains(searchString string) bool + contains(searchString string) (bool, string) } type cacheFactory interface { @@ -36,12 +36,12 @@ func (cache stringMap) elementCount() int { return count } -func (cache stringMap) contains(searchString string) bool { +func (cache stringMap) contains(searchString string) (bool, string) { normalized := normalizeEntry(searchString) searchLen := len(normalized) if searchLen == 0 { - return false + return false, "" } searchBucketLen := len(cache[searchLen]) / searchLen @@ -54,11 +54,11 @@ func (cache stringMap) contains(searchString string) bool { if blockRule == normalized { log.PrefixedLog("string_map").Debugf("block rule '%s' matched with '%s'", blockRule, searchString) - return true + return true, blockRule } } - return false + return false, "" } type stringCacheFactory struct { @@ -134,16 +134,16 @@ func (cache regexCache) elementCount() int { return len(cache) } -func (cache regexCache) contains(searchString string) bool { +func (cache regexCache) contains(searchString string) (bool, string) { for _, regex := range cache { if regex.MatchString(searchString) { log.PrefixedLog("regex_cache").Debugf("regex '%s' matched with '%s'", regex, searchString) - return true + return true, regex.String() } } - return false + return false, "" } type regexCacheFactory struct { @@ -197,8 +197,12 @@ func (cache wildcardCache) elementCount() int { return cache.cnt } -func (cache wildcardCache) contains(domain string) bool { - return cache.trie.HasParentOf(domain) +func (cache wildcardCache) contains(domain string) (bool, string) { + contains, path := cache.trie.HasParentOf(domain) + rule := strings.Join(path, ".") + rule = strings.Trim(rule, ".") + + return contains, rule } type wildcardCacheFactory struct { diff --git a/cache/stringcache/string_caches_benchmark_test.go b/cache/stringcache/string_caches_benchmark_test.go index 97d8fa639..fc0acbf5a 100644 --- a/cache/stringcache/string_caches_benchmark_test.go +++ b/cache/stringcache/string_caches_benchmark_test.go @@ -180,7 +180,7 @@ func benchmarkCache(b *testing.B, data []string, newFactory func() cacheFactory) // - wildcards and regexes need a plain string query // - all benchmarks will do the same number of queries for _, s := range stringTestData { - if !cache.contains(s) { + if contains, _ := cache.contains(s); !contains { b.Fatalf("cache is missing value from stringTestData: %s", s) } } diff --git a/cache/stringcache/string_caches_test.go b/cache/stringcache/string_caches_test.go index 2c6d81713..ddf470cae 100644 --- a/cache/stringcache/string_caches_test.go +++ b/cache/stringcache/string_caches_test.go @@ -38,17 +38,25 @@ var _ = Describe("Caches", func() { }) It("should match if StringCache contains exact string", func() { - Expect(cache.contains("apple.com")).Should(BeTrue()) - Expect(cache.contains("google.com")).Should(BeTrue()) - Expect(cache.contains("www.google.com")).Should(BeFalse()) - Expect(cache.contains("")).Should(BeFalse()) + contains, _ := cache.contains("apple.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("google.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("www.google.com") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("") + Expect(contains).Should(BeFalse()) }) It("should match case-insensitive", func() { - Expect(cache.contains("aPPle.com")).Should(BeTrue()) - Expect(cache.contains("google.COM")).Should(BeTrue()) - Expect(cache.contains("www.google.com")).Should(BeFalse()) - Expect(cache.contains("")).Should(BeFalse()) + contains, _ := cache.contains("aPPle.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("google.COM") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("www.google.com") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("") + Expect(contains).Should(BeFalse()) }) It("should return correct element count", func() { @@ -88,18 +96,30 @@ var _ = Describe("Caches", func() { }) It("should match if one regex in StringCache matches string", func() { - Expect(cache.contains("google.com")).Should(BeTrue()) - Expect(cache.contains("google.coma")).Should(BeTrue()) - Expect(cache.contains("agoogle.com")).Should(BeTrue()) - Expect(cache.contains("www.google.com")).Should(BeTrue()) - Expect(cache.contains("apple.com")).Should(BeTrue()) - Expect(cache.contains("apple.de")).Should(BeTrue()) - Expect(cache.contains("apple.it")).Should(BeFalse()) - Expect(cache.contains("www.apple.com")).Should(BeFalse()) - Expect(cache.contains("applecom")).Should(BeFalse()) - Expect(cache.contains("www.amazon.com")).Should(BeTrue()) - Expect(cache.contains("amazon.com")).Should(BeTrue()) - Expect(cache.contains("myamazon.com")).Should(BeTrue()) + contains, _ := cache.contains("google.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("google.coma") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("agoogle.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("www.google.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("apple.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("apple.de") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("apple.it") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("www.apple.com") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("applecom") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("www.amazon.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("amazon.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("myamazon.com") + Expect(contains).Should(BeTrue()) }) It("should return correct element count", func() { @@ -140,28 +160,42 @@ var _ = Describe("Caches", func() { It("should match if one regex in StringCache matches string", func() { // first entry - Expect(cache.contains("example.com")).Should(BeTrue()) - Expect(cache.contains("www.example.com")).Should(BeTrue()) + contains, _ := cache.contains("example.com") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("www.example.com") + Expect(contains).Should(BeTrue()) // look alikes - Expect(cache.contains("com")).Should(BeFalse()) - Expect(cache.contains("example.coma")).Should(BeFalse()) - Expect(cache.contains("an-example.com")).Should(BeFalse()) - Expect(cache.contains("examplecom")).Should(BeFalse()) + contains, _ = cache.contains("com") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("example.coma") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("an-example.com") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("examplecom") + Expect(contains).Should(BeFalse()) // other entry - Expect(cache.contains("example.org")).Should(BeTrue()) - Expect(cache.contains("www.example.org")).Should(BeTrue()) + contains, _ = cache.contains("example.org") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("www.example.org") + Expect(contains).Should(BeTrue()) // unrelated - Expect(cache.contains("example.net")).Should(BeFalse()) - Expect(cache.contains("www.example.net")).Should(BeFalse()) + contains, _ = cache.contains("example.net") + Expect(contains).Should(BeFalse()) + contains, _ = cache.contains("www.example.net") + Expect(contains).Should(BeFalse()) // third entry - Expect(cache.contains("blocked")).Should(BeTrue()) - Expect(cache.contains("sub.blocked")).Should(BeTrue()) - Expect(cache.contains("sub.sub.blocked")).Should(BeTrue()) - Expect(cache.contains("example.blocked")).Should(BeTrue()) + contains, _ = cache.contains("blocked") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("sub.blocked") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("sub.sub.blocked") + Expect(contains).Should(BeTrue()) + contains, _ = cache.contains("example.blocked") + Expect(contains).Should(BeTrue()) }) It("should return correct element count", func() { diff --git a/lists/list_cache.go b/lists/list_cache.go index 843b6bb7e..9450802d8 100644 --- a/lists/list_cache.go +++ b/lists/list_cache.go @@ -32,7 +32,7 @@ type ListCacheType int // Matcher checks if a domain is in a list type Matcher interface { // Match matches passed domain name against cached list entries - Match(domain string, groupsToCheck []string) (groups []string) + Match(domain string, groupsToCheck []string) (matches map[string]string) } // ListCache generic cache of strings divided in groups @@ -104,7 +104,8 @@ func logger() *logrus.Entry { } // Match matches passed domain name against cached list entries -func (b *ListCache) Match(domain string, groupsToCheck []string) (groups []string) { +// Returns group(s) containing the string as keys, and the rule that caused the block as value +func (b *ListCache) Match(domain string, groupsToCheck []string) (matches map[string]string) { return b.groupedCache.Contains(domain, groupsToCheck) } diff --git a/lists/list_cache_test.go b/lists/list_cache_test.go index b937b15e9..6f8130207 100644 --- a/lists/list_cache_test.go +++ b/lists/list_cache_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "golang.org/x/exp/maps" "io" "net/http/httptest" "os" @@ -95,14 +96,14 @@ var _ = Describe("ListCache", func() { When("Query with empty", func() { It("should not panic", func() { - group := sut.Match("", []string{"gr0"}) - Expect(group).Should(BeEmpty()) + matches := sut.Match("", []string{"gr0"}) + Expect(matches).Should(BeEmpty()) }) }) It("should not match anything", func() { - group := sut.Match("google.com", []string{"gr1"}) - Expect(group).Should(BeEmpty()) + matches := sut.Match("google.com", []string{"gr1"}) + Expect(matches).Should(BeEmpty()) }) }) When("List becomes empty on refresh", func() { @@ -118,14 +119,14 @@ var _ = Describe("ListCache", func() { }) It("should delete existing elements from group cache", func(ctx context.Context) { - group := sut.Match("blocked1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) err := sut.refresh(ctx) Expect(err).Should(Succeed()) - group = sut.Match("blocked1.com", []string{"gr1"}) - Expect(group).Should(BeEmpty()) + matches = sut.Match("blocked1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(BeEmpty()) }) }) When("List has invalid lines", func() { @@ -142,11 +143,11 @@ var _ = Describe("ListCache", func() { }) It("should still other domains", func() { - group := sut.Match("inlinedomain1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("inlinedomain1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) - group = sut.Match("inlinedomain2.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches = sut.Match("inlinedomain2.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) }) When("a temporary/transient err occurs on download", func() { @@ -166,23 +167,23 @@ var _ = Describe("ListCache", func() { It("should not delete existing elements from group cache", func(ctx context.Context) { By("Lists loaded without timeout", func() { Eventually(func(g Gomega) { - group := sut.Match("blocked1.com", []string{"gr1"}) - g.Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1"}) + g.Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }, "1s").Should(Succeed()) }) Expect(sut.refresh(ctx)).Should(HaveOccurred()) By("List couldn't be loaded due to timeout", func() { - group := sut.Match("blocked1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) _ = sut.Refresh() By("List couldn't be loaded due to timeout", func() { - group := sut.Match("blocked1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) }) }) @@ -201,15 +202,15 @@ var _ = Describe("ListCache", func() { It("should keep existing elements from group cache", func(ctx context.Context) { By("Lists loaded without err", func() { - group := sut.Match("blocked1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) Expect(sut.refresh(ctx)).Should(HaveOccurred()) By("Lists from first load is kept", func() { - group := sut.Match("blocked1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) }) }) @@ -222,14 +223,14 @@ var _ = Describe("ListCache", func() { }) It("should download the list and match against", func() { - group := sut.Match("blocked1.com", []string{"gr1", "gr2"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1", "gr2"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) - group = sut.Match("blocked1a.com", []string{"gr1", "gr2"}) - Expect(group).Should(ContainElement("gr1")) + matches = sut.Match("blocked1a.com", []string{"gr1", "gr2"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) - group = sut.Match("blocked1a.com", []string{"gr2"}) - Expect(group).Should(ContainElement("gr2")) + matches = sut.Match("blocked1a.com", []string{"gr2"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr2")) }) }) When("Configuration has some faulty urls", func() { @@ -241,14 +242,14 @@ var _ = Describe("ListCache", func() { }) It("should download the list and match against", func() { - group := sut.Match("blocked1.com", []string{"gr1", "gr2"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1", "gr2"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) - group = sut.Match("blocked1a.com", []string{"gr1", "gr2"}) - Expect(group).Should(ContainElements("gr1", "gr2")) + matches = sut.Match("blocked1a.com", []string{"gr1", "gr2"}) + Expect(maps.Keys(matches)).Should(ContainElements("gr1", "gr2")) - group = sut.Match("blocked1a.com", []string{"gr2"}) - Expect(group).Should(ContainElement("gr2")) + matches = sut.Match("blocked1a.com", []string{"gr2"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr2")) }) }) When("List will be updated", func() { @@ -265,8 +266,8 @@ var _ = Describe("ListCache", func() { }) It("event should be fired and contain count of elements in downloaded lists", func() { - group := sut.Match("blocked1.com", []string{}) - Expect(group).Should(BeEmpty()) + matches := sut.Match("blocked1.com", []string{}) + Expect(matches).Should(BeEmpty()) Expect(resultCnt).Should(Equal(3)) }) }) @@ -282,14 +283,14 @@ var _ = Describe("ListCache", func() { Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(3)) Expect(sut.groupedCache.ElementCount("gr2")).Should(Equal(2)) - group := sut.Match("blocked1.com", []string{"gr1", "gr2"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("blocked1.com", []string{"gr1", "gr2"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) - group = sut.Match("blocked1a.com", []string{"gr1", "gr2"}) - Expect(group).Should(ContainElement("gr1")) + matches = sut.Match("blocked1a.com", []string{"gr1", "gr2"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) - group = sut.Match("blocked1a.com", []string{"gr2"}) - Expect(group).Should(ContainElement("gr2")) + matches = sut.Match("blocked1a.com", []string{"gr2"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr2")) }) }) When("group with bigger files", func() { @@ -325,11 +326,11 @@ var _ = Describe("ListCache", func() { It("should match", func() { Expect(sut.groupedCache.ElementCount("gr1")).Should(Equal(2)) - group := sut.Match("inlinedomain1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("inlinedomain1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) - group = sut.Match("inlinedomain2.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches = sut.Match("inlinedomain2.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) }) When("Text file can't be parsed", func() { @@ -345,8 +346,8 @@ var _ = Describe("ListCache", func() { }) It("should still match already imported strings", func() { - group := sut.Match("inlinedomain1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("inlinedomain1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) }) When("Text file has too many errors", func() { @@ -372,8 +373,8 @@ var _ = Describe("ListCache", func() { }) It("should still parse the domain", func() { - group := sut.Match("inlinedomain1.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("inlinedomain1.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) }) When("inline regex content is defined", func() { @@ -384,11 +385,11 @@ var _ = Describe("ListCache", func() { }) It("should match", func() { - group := sut.Match("apple.com", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches := sut.Match("apple.com", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) - group = sut.Match("apple.de", []string{"gr1"}) - Expect(group).Should(ContainElement("gr1")) + matches = sut.Match("apple.de", []string{"gr1"}) + Expect(maps.Keys(matches)).Should(ContainElement("gr1")) }) }) }) diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go index 57cbea7d7..0aacd270c 100644 --- a/resolver/blocking_resolver.go +++ b/resolver/blocking_resolver.go @@ -371,8 +371,8 @@ func (r *BlockingResolver) handleDenylist(ctx context.Context, groupsToCheck []s domain := util.ExtractDomain(question) logger := logger.WithField("domain", domain) - if groups := r.matches(groupsToCheck, r.allowlistMatcher, domain); len(groups) > 0 { - logger.WithField("groups", groups).Debugf("domain is allowlisted") + if matches := r.matches(groupsToCheck, r.allowlistMatcher, domain); len(matches) > 0 { + logger.WithField("groups", maps.Keys(matches)).Debugf("domain is allowlisted") resp, err := r.next.Resolve(ctx, request) @@ -385,8 +385,14 @@ func (r *BlockingResolver) handleDenylist(ctx context.Context, groupsToCheck []s return true, resp, err } - if groups := r.matches(groupsToCheck, r.denylistMatcher, domain); len(groups) > 0 { - resp, err := r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", strings.Join(groups, ","))) + if matches := r.matches(groupsToCheck, r.denylistMatcher, domain); len(matches) > 0 { + var entries []string + for group, rule := range matches { + entries = append(entries, strings.Join([]string{group, rule}, ":")) + } + + msg := fmt.Sprintf("BLOCKED (%s)", strings.Join(entries, ", ")) + resp, err := r.handleBlocked(logger, request, question, msg) return true, resp, err } @@ -415,11 +421,17 @@ func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) if len(entryToCheck) > 0 { logger := logger.WithField("response_entry", entryToCheck) - if groups := r.matches(groupsToCheck, r.allowlistMatcher, entryToCheck); len(groups) > 0 { - logger.WithField("groups", groups).Debugf("%s is allowlisted", tName) - } else if groups := r.matches(groupsToCheck, r.denylistMatcher, entryToCheck); len(groups) > 0 { - return r.handleBlocked(logger, request, request.Req.Question[0], fmt.Sprintf("BLOCKED %s (%s)", tName, - strings.Join(groups, ","))) + if matches := r.matches(groupsToCheck, r.allowlistMatcher, entryToCheck); len(matches) > 0 { + logger.WithField("groups", maps.Keys(matches)).Debugf("%s is allowlisted", tName) + } else if matches := r.matches(groupsToCheck, r.denylistMatcher, entryToCheck); len(matches) > 0 { + var entries []string + for group, rule := range matches { + entries = append(entries, strings.Join([]string{group, rule}, ":")) + } + + msg := fmt.Sprintf("BLOCKED %s (%s)", tName, strings.Join(entries, ", ")) + + return r.handleBlocked(logger, request, request.Req.Question[0], msg) } } } @@ -515,12 +527,12 @@ func (r *BlockingResolver) groupsToCheckForClient(request *model.Request) []stri func (r *BlockingResolver) matches(groupsToCheck []string, m lists.Matcher, domain string, -) (group []string) { +) (matches map[string]string) { if len(groupsToCheck) > 0 { return m.Match(domain, groupsToCheck) } - return []string{} + return nil } type blockHandler interface { diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index 64abc7329..eef3cb086 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -223,7 +223,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -234,7 +234,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -245,7 +245,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -256,7 +256,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("blocked2.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr2)"), + HaveReason("BLOCKED (gr2:blocked2.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -267,7 +267,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", AAAA, "::"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -289,7 +289,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -329,7 +329,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -340,7 +340,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -354,7 +354,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -365,7 +365,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("blocked2.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr2)"), + HaveReason("BLOCKED (gr2:blocked2.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -378,7 +378,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (gr1)"), + HaveReason("BLOCKED (gr1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -392,7 +392,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { BeDNSRecord("blocked3.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -419,7 +419,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveNoAnswer(), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeNameError), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), )) }) }) @@ -446,7 +446,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveTTL(BeNumerically("==", 1234)), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeSuccess), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), )) }) @@ -463,7 +463,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveTTL(BeNumerically("==", 1234)), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeSuccess), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), )) }) }) @@ -491,7 +491,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeSuccess), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), )) }) @@ -503,7 +503,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeSuccess), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), )) }) }) @@ -530,7 +530,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeSuccess), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), )) }) }) @@ -549,7 +549,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeSuccess), - HaveReason("BLOCKED IP (defaultGroup)"), + HaveReason("BLOCKED IP (defaultGroup:123.145.123.145)"), )) }) }) @@ -569,7 +569,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeSuccess), - HaveReason("BLOCKED IP (defaultGroup)"), + HaveReason("BLOCKED IP (defaultGroup:2001:db8:85a3:8d3::370:7344)"), )) }) }) @@ -592,7 +592,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveTTL(BeNumerically("==", 21600)), HaveResponseType(ResponseTypeBLOCKED), HaveReturnCode(dns.RcodeSuccess), - HaveReason("BLOCKED CNAME (defaultGroup)"), + HaveReason("BLOCKED CNAME (defaultGroup:badcnamedomain.com)"), )) }) }) @@ -819,7 +819,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -830,7 +830,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (group1)"), + HaveReason("BLOCKED (group1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -893,7 +893,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (group1)"), + HaveReason("BLOCKED (group1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -908,7 +908,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -918,7 +918,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (group1)"), + HaveReason("BLOCKED (group1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -934,7 +934,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { Eventually(enabled, "1s").Should(Receive(BeFalse())) }) - By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() { + By("perform the same query again to ensure that this query will not be blocked (defaultGroup:blocked3.com)", func() { // now is blocking disabled, query the url again Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( @@ -947,7 +947,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { m.AssertExpectations(GinkgoT()) m.AssertNumberOfCalls(GinkgoT(), "Resolve", 1) }) - By("perform the same query again to ensure that this query will not be blocked (group1)", func() { + By("perform the same query again to ensure that this query will not be blocked (group1:domain1.com)", func() { // now is blocking disabled, query the url again Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( @@ -974,7 +974,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), HaveReturnCode(dns.RcodeSuccess), )) @@ -983,7 +983,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (group1)"), + HaveReason("BLOCKED (group1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -998,7 +998,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -1008,7 +1008,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (group1)"), + HaveReason("BLOCKED (group1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) @@ -1024,18 +1024,18 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { Eventually(enabled, "1s").Should(Receive(BeFalse())) }) - By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() { + By("perform the same query again to ensure that this query will not be blocked (defaultGroup:blocked3.com)", func() { // now is blocking disabled, query the url again Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) - By("perform the same query again to ensure that this query will not be blocked (group1)", func() { + By("perform the same query again to ensure that this query will not be blocked (group1:domain1.com)", func() { // now is blocking disabled, query the url again Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( @@ -1062,7 +1062,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (defaultGroup)"), + HaveReason("BLOCKED (defaultGroup:blocked3.com)"), HaveReturnCode(dns.RcodeSuccess), )) @@ -1071,7 +1071,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), HaveResponseType(ResponseTypeBLOCKED), - HaveReason("BLOCKED (group1)"), + HaveReason("BLOCKED (group1:domain1.com)"), HaveReturnCode(dns.RcodeSuccess), )) }) diff --git a/trie/trie.go b/trie/trie.go index 1fc315d34..1ff299027 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -2,7 +2,6 @@ package trie import ( "github.com/0xERR0R/blocky/log" - "strings" ) // Trie stores a set of strings and can quickly check @@ -36,12 +35,12 @@ func (t *Trie) Insert(key string) { t.root.insert(key, t.split) } -func (t *Trie) HasParentOf(key string) bool { +func (t *Trie) HasParentOf(key string) (bool, []string) { return t.root.hasParentOf(key, t.split) } type node interface { - hasParentOf(key string, split SplitFunc) bool + hasParentOf(key string, split SplitFunc) (bool, []string) } // We save memory by not keeping track of children of @@ -96,7 +95,8 @@ func (n *parent) insert(key string, split SplitFunc) { continue case terminal: - if child.hasParentOf(rest, split) { + hasParent, _ := child.hasParentOf(rest, split) + if hasParent { // Found a parent/"prefix" in the set return } @@ -112,17 +112,16 @@ func (n *parent) insert(key string, split SplitFunc) { } } -func (n *parent) hasParentOf(key string, split SplitFunc) bool { +func (n *parent) hasParentOf(key string, split SplitFunc) (bool, []string) { searchString := key - rule := "" + path := make([]string, 0, len(key)) // over-allocating is better than allocating in a loop for { label, rest := split(key) - rule = strings.Join([]string{label, rule}, ".") child, ok := n.children[label] if !ok { - return false + return false, nil } switch child := child.(type) { @@ -130,25 +129,28 @@ func (n *parent) hasParentOf(key string, split SplitFunc) bool { if len(rest) == 0 { // The trie only contains children/"suffixes" of the // key we're searching for - return false + return false, nil } // Continue down the trie key = rest n = child + path = append(path, label) + continue case terminal: // Continue down the trie - matched := child.hasParentOf(rest, split) + matched, _ := child.hasParentOf(rest, split) if matched { - rule = strings.Join([]string{child.String(), rule}, ".") - rule = strings.Trim(rule, ".") - log.PrefixedLog("trie").Debugf("wildcard block rule '%s' matched with '%s'", rule, searchString) + path = append(path, child.String()) + log.PrefixedLog("trie").Debugf("wildcard path '%v' matched with '%s'", path, searchString) + + return matched, path } - return matched + return false, nil } } } @@ -159,10 +161,10 @@ func (t terminal) String() string { return string(t) } -func (t terminal) hasParentOf(searchKey string, split SplitFunc) bool { +func (t terminal) hasParentOf(searchKey string, split SplitFunc) (bool, []string) { tKey := t.String() if tKey == "" { - return true + return true, nil } for { @@ -170,18 +172,18 @@ func (t terminal) hasParentOf(searchKey string, split SplitFunc) bool { searchLabel, searchRest := split(searchKey) if searchLabel != tLabel { - return false + return false, nil } if len(tRest) == 0 { // Found a parent/"prefix" in the set - return true + return true, nil } if len(searchRest) == 0 { // The trie only contains children/"suffixes" of the // key we're searching for - return false + return false, nil } // Continue down the trie diff --git a/trie/trie_test.go b/trie/trie_test.go index 3a2ce085f..2c2412ce6 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -37,13 +37,16 @@ var _ = Describe("Trie", func() { ) BeforeEach(func() { - Expect(sut.HasParentOf(domainOk)).Should(BeFalse()) + hasParent, _ := sut.HasParentOf(domainOk) + Expect(hasParent).Should(BeFalse()) sut.Insert(domainOk) - Expect(sut.HasParentOf(domainOk)).Should(BeTrue()) + hasParent, _ = sut.HasParentOf(domainOk) + Expect(hasParent).Should(BeTrue()) }) AfterEach(func() { - Expect(sut.HasParentOf(domainOk)).Should(BeTrue()) + hasParent, _ := sut.HasParentOf(domainOk) + Expect(hasParent).Should(BeTrue()) }) It("should be found", func() {}) @@ -51,41 +54,49 @@ var _ = Describe("Trie", func() { It("should contain subdomains", func() { subdomain := "www." + domainOk - Expect(sut.HasParentOf(subdomain)).Should(BeTrue()) + hasParent, _ := sut.HasParentOf(subdomain) + Expect(hasParent).Should(BeTrue()) }) It("should support inserting subdomains", func() { subdomain := "www." + domainOk - Expect(sut.HasParentOf(subdomain)).Should(BeTrue()) + hasParent, _ := sut.HasParentOf(subdomain) + Expect(hasParent).Should(BeTrue()) sut.Insert(subdomain) - Expect(sut.HasParentOf(subdomain)).Should(BeTrue()) + hasParent, _ = sut.HasParentOf(subdomain) + Expect(hasParent).Should(BeTrue()) }) It("should not find unrelated", func() { - Expect(sut.HasParentOf(domainKo)).Should(BeFalse()) + hasParent, _ := sut.HasParentOf(domainKo) + Expect(hasParent).Should(BeFalse()) }) It("should not find uninserted parent", func() { - Expect(sut.HasParentOf(domainOkTLD)).Should(BeFalse()) + hasParent, _ := sut.HasParentOf(domainOkTLD) + Expect(hasParent).Should(BeFalse()) }) It("should not find deep uninserted parent", func() { sut.Insert("sub.sub.sub.test") - Expect(sut.HasParentOf("sub.sub.test")).Should(BeFalse()) + hasParent, _ := sut.HasParentOf("sub.sub.test") + Expect(hasParent).Should(BeFalse()) }) It("should find inserted parent", func() { sut.Insert(domainOkTLD) - Expect(sut.HasParentOf(domainOkTLD)).Should(BeTrue()) + hasParent, _ := sut.HasParentOf(domainOkTLD) + Expect(hasParent).Should(BeTrue()) }) It("should insert sibling", func() { sibling := "other." + domainOkTLD sut.Insert(sibling) - Expect(sut.HasParentOf(sibling)).Should(BeTrue()) + hasParent, _ := sut.HasParentOf(sibling) + Expect(hasParent).Should(BeTrue()) }) It("should insert grand-children siblings", func() { @@ -94,14 +105,20 @@ var _ = Describe("Trie", func() { xyzSub := "xyz." + base sut.Insert(abcSub) - Expect(sut.HasParentOf(abcSub)).Should(BeTrue()) - Expect(sut.HasParentOf(xyzSub)).Should(BeFalse()) - Expect(sut.HasParentOf(base)).Should(BeFalse()) + hasParent, _ := sut.HasParentOf(abcSub) + Expect(hasParent).Should(BeTrue()) + hasParent, _ = sut.HasParentOf(xyzSub) + Expect(hasParent).Should(BeFalse()) + hasParent, _ = sut.HasParentOf(base) + Expect(hasParent).Should(BeFalse()) sut.Insert(xyzSub) - Expect(sut.HasParentOf(xyzSub)).Should(BeTrue()) - Expect(sut.HasParentOf(abcSub)).Should(BeTrue()) - Expect(sut.HasParentOf(base)).Should(BeFalse()) + hasParent, _ = sut.HasParentOf(xyzSub) + Expect(hasParent).Should(BeTrue()) + hasParent, _ = sut.HasParentOf(abcSub) + Expect(hasParent).Should(BeTrue()) + hasParent, _ = sut.HasParentOf(base) + Expect(hasParent).Should(BeFalse()) }) }) })