Skip to content

Commit

Permalink
feat(resolver): race UDP and TCP when connecting upstream
Browse files Browse the repository at this point in the history
Inspired by https://en.wikipedia.org/wiki/Happy_Eyeballs this should
improve latency and fixes the long standing behavior where a single
resolve attempt could take 2x the timeout.
UpstreamResolver.Resolve can still take more than the configured timeout
so maybe that can be improved by splitting the retry algorithm into its
own resolver type.
  • Loading branch information
ThinkChaos committed Dec 19, 2023
1 parent 343d38c commit df8c373
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 17 deletions.
71 changes: 59 additions & 12 deletions resolver/upstream_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,71 @@ func (r *dnsUpstreamClient) fmtURL(ip net.IP, port uint16, _ string) string {
func (r *dnsUpstreamClient) callExternal(
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
) (response *dns.Msg, rtt time.Duration, err error) {
if protocol == model.RequestProtocolTCP {
response, rtt, err = r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
if err != nil && r.udpClient != nil {
// try UDP as fallback
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" {
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
}
if r.udpClient == nil {
return r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
}

return r.raceClients(ctx, msg, upstreamURL, protocol)
}

func (r *dnsUpstreamClient) raceClients(
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
) (response *dns.Msg, rtt time.Duration, err error) {
type result struct {
proto model.RequestProtocol
msg *dns.Msg
rtt time.Duration
err error
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

// We don't explicitly close the channel, but since the buffer is big enough for all goroutines,
// it will be GC'ed and closed automatically.
ch := make(chan result, 2) //nolint:gomnd // TCP and UDP

exchange := func(client *dns.Client, proto model.RequestProtocol) {
msg, rtt, err := client.ExchangeContext(ctx, msg, upstreamURL)

ch <- result{proto, msg, rtt, err}
}

go exchange(r.tcpClient, model.RequestProtocolTCP)
go exchange(r.udpClient, model.RequestProtocolUDP)

// We don't care about a response too big for the downstream protocol: that's handled by `Server`,
// and returning a larger request from here might allow us to cache it.

res1 := <-ch
if res1.err == nil && !res1.msg.Truncated {
return res1.msg, res1.rtt, nil
}

res2 := <-ch
if res2.err == nil && !res2.msg.Truncated {
return res2.msg, res2.rtt, nil
}

resWhere := func(pred func(*result) bool) *result {
if pred(&res1) {
return &res1
}

return response, rtt, err
return &res2
}

if r.udpClient != nil {
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
// When both failed, return the result that used the same protocol as the downstream request
if res1.err != nil && res2.err != nil {
sameProto := resWhere(func(r *result) bool { return r.proto == protocol })

return sameProto.msg, sameProto.rtt, sameProto.err
}

return r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
// Only a single one failed, use the one that succeeded
successful := resWhere(func(r *result) bool { return r.err == nil })

return successful.msg, successful.rtt, nil
}

// NewUpstreamResolver creates new resolver instance
Expand Down
2 changes: 1 addition & 1 deletion resolver/upstream_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
sutConfig.Upstream = mockUpstream.Start()
})

It("should retry with UDP", func() {
It("should also try with UDP", func() {
req := newRequest("example.com.", A)
req.Protocol = RequestProtocolTCP

Expand Down
8 changes: 4 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d
response.Res.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired

// truncate if necessary
response.Res.Truncate(getMaxResponseSize(w.LocalAddr().Network(), request))
response.Res.Truncate(getMaxResponseSize(r))

// enable compression
response.Res.Compress = true
Expand All @@ -624,13 +624,13 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d
}

// returns EDNS UDP size or if not present, 512 for UDP and 64K for TCP
func getMaxResponseSize(network string, request *dns.Msg) int {
edns := request.IsEdns0()
func getMaxResponseSize(req *model.Request) int {
edns := req.Req.IsEdns0()
if edns != nil && edns.UDPSize() > 0 {
return int(edns.UDPSize())
}

if network == "tcp" {
if req.Protocol == model.RequestProtocolTCP {
return dns.MaxMsgSize
}

Expand Down

0 comments on commit df8c373

Please sign in to comment.