From 11543356b6b303ffa18613992cbf48aefc269f8f Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Sun, 3 Dec 2023 20:29:31 +0100 Subject: [PATCH] Bugfix in ECS forward (#1290) * fixed override bug in forward * set prettier as default formatter for yaml * added ecs to example config --- .devcontainer/devcontainer.json | 2 +- docs/config.yml | 7 ++ resolver/ecs_resolver.go | 55 +++++++++------ resolver/ecs_resolver_test.go | 115 +++++++++++++++++++++++++++----- 4 files changed, 141 insertions(+), 38 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 86d175a70..fc4be3c9c 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -50,7 +50,7 @@ "[go]": { "editor.defaultFormatter": "golang.go" }, - "[json][jsonc][github-actions-workflow]": { + "[yaml][json][jsonc][github-actions-workflow]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, "[markdown]": { diff --git a/docs/config.yml b/docs/config.yml index 9c4151623..1e21e0248 100644 --- a/docs/config.yml +++ b/docs/config.yml @@ -332,3 +332,10 @@ specialUseDomains: # optional: block recomended private TLDs # default: true rfc6762-appendixG: true + +# optional: configure extended client subnet (ECS) support +ecs: + # optional: if the request ecs option with a max sice mask the address will be used as client ip + useAsClient: true + # optional: if the request contains a ecs option it will be forwarded to the upstream resolver + forward: true diff --git a/resolver/ecs_resolver.go b/resolver/ecs_resolver.go index 049db114d..1b97eab69 100644 --- a/resolver/ecs_resolver.go +++ b/resolver/ecs_resolver.go @@ -53,17 +53,19 @@ func (r *ECSResolver) Resolve(ctx context.Context, request *model.Request) (*mod // Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set if r.cfg.UseAsClient && so != nil && ((so.Family == ecsFamilyIPv4 && so.SourceNetmask == ecsMaskIPv4) || (so.Family == ecsFamilyIPv6 && so.SourceNetmask == ecsMaskIPv6)) { + request.Log.Debugf("using request's edns0 address as internal client IP: %s", so.Address) request.ClientIP = so.Address } // Set the Edns0 subnet option if the client IP is IPv4 or IPv6 and the masks are set in the configuration if r.cfg.IPv4Mask > 0 || r.cfg.IPv6Mask > 0 { - r.setSubnet(request) + r.setSubnet(so, request) } // Remove the Edns0 subnet option if the client IP is IPv4 or IPv6 and the corresponding mask is not set // and the forwardEcs option is not enabled if r.cfg.IPv4Mask == 0 && r.cfg.IPv6Mask == 0 && so != nil && !r.cfg.Forward { + request.Log.Debug("remove edns0 subnet option") util.RemoveEdns0Option[*dns.EDNS0_SUBNET](request.Req) } } @@ -73,28 +75,30 @@ func (r *ECSResolver) Resolve(ctx context.Context, request *model.Request) (*mod // setSubnet appends the subnet information to the request as EDNS0 option // if the client IP is IPv4 or IPv6 and the corresponding mask is set in the configuration -func (r *ECSResolver) setSubnet(request *model.Request) { - e := new(dns.EDNS0_SUBNET) - e.Code = dns.EDNS0SUBNET - e.SourceScope = ecsSourceScope - - if ip := request.ClientIP.To4(); ip != nil && r.cfg.IPv4Mask > 0 { - mip, err := maskIP(ip, r.cfg.IPv4Mask) - if err == nil { - e.Family = ecsFamilyIPv4 - e.SourceNetmask = uint8(r.cfg.IPv4Mask) - e.Address = mip - util.SetEdns0Option(request.Req, e) +func (r *ECSResolver) setSubnet(so *dns.EDNS0_SUBNET, request *model.Request) { + var subIP net.IP + if so != nil && r.cfg.Forward && so.Address != nil { + subIP = so.Address + } else { + subIP = request.ClientIP + } + + var edsOption *dns.EDNS0_SUBNET + + if ip := subIP.To4(); ip != nil && r.cfg.IPv4Mask > 0 { + if mip, err := maskIP(ip, r.cfg.IPv4Mask); err == nil { + edsOption = newEdnsSubnetOption(mip, ecsFamilyIPv4, r.cfg.IPv4Mask) } - } else if ip := request.ClientIP.To16(); ip != nil && r.cfg.IPv6Mask > 0 { - mip, err := maskIP(ip, r.cfg.IPv6Mask) - if err == nil { - e.Family = ecsFamilyIPv6 - e.SourceNetmask = uint8(r.cfg.IPv6Mask) - e.Address = mip - util.SetEdns0Option(request.Req, e) + } else if ip := subIP.To16(); ip != nil && r.cfg.IPv6Mask > 0 { + if mip, err := maskIP(ip, r.cfg.IPv6Mask); err == nil { + edsOption = newEdnsSubnetOption(mip, ecsFamilyIPv6, r.cfg.IPv6Mask) } } + + if edsOption != nil { + request.Log.Debugf("set edns0 subnet option address: %s", edsOption.Address) + util.SetEdns0Option(request.Req, edsOption) + } } // maskIP masks the IP with the given mask and return an error if the mask is invalid @@ -103,3 +107,14 @@ func maskIP[maskType ECSMask](ip net.IP, mask maskType) (net.IP, error) { return mip.IP, err } + +// newEdnsSubnetOption( creates a new EDNS0 subnet option with the given IP, family and mask +func newEdnsSubnetOption[maskType ECSMask](ip net.IP, family uint16, mask maskType) *dns.EDNS0_SUBNET { + return &dns.EDNS0_SUBNET{ + Code: dns.EDNS0SUBNET, + SourceScope: ecsSourceScope, + Family: family, + SourceNetmask: uint8(mask), + Address: ip, + } +} diff --git a/resolver/ecs_resolver_test.go b/resolver/ecs_resolver_test.go index 1d947595e..d6e034ac2 100644 --- a/resolver/ecs_resolver_test.go +++ b/resolver/ecs_resolver_test.go @@ -26,9 +26,6 @@ var _ = Describe("EcsResolver", func() { err error origIP net.IP ecsIP net.IP - - ctx context.Context - cancelFn context.CancelFunc ) Describe("Type", func() { @@ -38,15 +35,12 @@ var _ = Describe("EcsResolver", func() { }) BeforeEach(func() { - ctx, cancelFn = context.WithCancel(context.Background()) - DeferCleanup(cancelFn) - err = defaults.Set(&sutConfig) Expect(err).Should(Succeed()) mockAnswer = new(dns.Msg) - origIP = net.ParseIP("1.2.3.4") - ecsIP = net.ParseIP("4.3.2.1") + origIP = net.ParseIP("1.2.3.4").To4() + ecsIP = net.ParseIP("4.3.2.1").To4() }) JustBeforeEach(func() { @@ -63,7 +57,7 @@ var _ = Describe("EcsResolver", func() { sut.Next(m) }) - When("ecs is disabled", func() { + When("ECS is disabled", func() { Describe("IsEnabled", func() { It("is false", func() { Expect(sut.IsEnabled()).Should(BeFalse()) @@ -71,7 +65,7 @@ var _ = Describe("EcsResolver", func() { }) }) - When("ecs is enabled", func() { + When("ECS is enabled", func() { BeforeEach(func() { sutConfig.UseAsClient = true }) @@ -82,12 +76,12 @@ var _ = Describe("EcsResolver", func() { }) }) - When("use ecs client ip is enabled", func() { + When("use ECS client ip is enabled", func() { BeforeEach(func() { sutConfig.UseAsClient = true }) - It("should change ClientIP with subnet 32", func() { + It("should change ClientIP with subnet 32", func(ctx context.Context) { request := newRequest("example.com.", A) request.ClientIP = origIP @@ -108,7 +102,7 @@ var _ = Describe("EcsResolver", func() { HaveReason("Test"))) }) - It("shouldn't change ClientIP with subnet 24", func() { + It("shouldn't change ClientIP with subnet 24", func(ctx context.Context) { request := newRequest("example.com.", A) request.ClientIP = origIP @@ -130,14 +124,13 @@ var _ = Describe("EcsResolver", func() { }) }) - When("forward ecs is enabled", func() { + When("add ECS information", func() { BeforeEach(func() { - sutConfig.Forward = true sutConfig.IPv4Mask = 32 sutConfig.IPv6Mask = 128 }) - It("should add Ecs information with subnet 32", func() { + It("should add ECS information with subnet 32", func(ctx context.Context) { request := newRequest("example.com.", A) request.ClientIP = origIP @@ -157,13 +150,101 @@ var _ = Describe("EcsResolver", func() { HaveReason("Test"))) }) - It("should add Ecs information with subnet 128", func() { + It("should add ECS information with subnet 128", func(ctx context.Context) { + request := newRequest("example.com.", AAAA) + request.ClientIP = net.ParseIP("2001:db8::68") + + m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) { + Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET)) + + return respondWith(mockAnswer), nil + } + + Expect(sut.Resolve(ctx, request)). + Should( + SatisfyAll( + HaveNoAnswer(), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + HaveReason("Test"))) + }) + }) + + When("forward ECS information", func() { + BeforeEach(func() { + sutConfig.IPv4Mask = 32 + sutConfig.IPv6Mask = 128 + sutConfig.Forward = true + }) + + It("should forward ECS information with subnet 32", func(ctx context.Context) { + request := newRequest("example.com.", A) + request.ClientIP = origIP + + addEcsOption(request.Req, ecsIP, ecsMaskIPv4) + + m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) { + Expect(req.ClientIP).Should(Equal(ecsIP)) + Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET)) + + so := util.GetEdns0Option[*dns.EDNS0_SUBNET](req.Req) + Expect(so.Address).Should(Equal(ecsIP)) + + return respondWith(mockAnswer), nil + } + + Expect(sut.Resolve(ctx, request)). + Should( + SatisfyAll( + HaveNoAnswer(), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + HaveReason("Test"))) + }) + + When("subnet mask is 24", func() { + BeforeEach(func() { + sutConfig.IPv4Mask = 24 + }) + + It("should modify ECS information", func(ctx context.Context) { + request := newRequest("example.com.", A) + request.ClientIP = origIP + + addEcsOption(request.Req, ecsIP, ecsMaskIPv4) + + m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) { + Expect(req.ClientIP).Should(Equal(ecsIP)) + Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET)) + + so := util.GetEdns0Option[*dns.EDNS0_SUBNET](req.Req) + Expect(so.Address).Should(Equal(net.ParseIP("4.3.2.0").To4())) + + return respondWith(mockAnswer), nil + } + + Expect(sut.Resolve(ctx, request)). + Should( + SatisfyAll( + HaveNoAnswer(), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + HaveReason("Test"))) + }) + }) + + It("should forward ECS information with subnet 128", func(ctx context.Context) { request := newRequest("example.com.", AAAA) request.ClientIP = net.ParseIP("2001:db8::68") + addEcsOption(request.Req, net.ParseIP("2001:db8::68"), 128) + m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) { Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET)) + so := util.GetEdns0Option[*dns.EDNS0_SUBNET](req.Req) + Expect(so.Address).Should(Equal(net.ParseIP("2001:db8::68"))) + return respondWith(mockAnswer), nil }