Skip to content

Commit

Permalink
Add MinDNSResolutionRate Option
Browse files Browse the repository at this point in the history
  • Loading branch information
HomayoonAlimohammadi committed Feb 4, 2024
1 parent 03e76b3 commit ffde661
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 29 deletions.
12 changes: 12 additions & 0 deletions dialoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type dialOptions struct {
resolvers []resolver.Builder
idleTimeout time.Duration
recvBufferPool SharedBufferPool
minDNSResolutionRate *time.Duration
}

// DialOption configures how we set up the connection.
Expand Down Expand Up @@ -711,6 +712,17 @@ func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return withRecvBufferPool(bufferPool)
}

// WithMinDNSResolutionRate sets the default minimum rate at which DNS re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
//
// Using this option overwrites the default [minResolutionRate] specified
// in the dns resolver.
func WithMinDNSResolutionRate(d time.Duration) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.minDNSResolutionRate = &d
})

Check warning on line 723 in dialoptions.go

View check run for this annotation

Codecov / codecov/patch

dialoptions.go#L720-L723

Added lines #L720 - L723 were not covered by tests
}

func withRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.recvBufferPool = bufferPool
Expand Down
27 changes: 19 additions & 8 deletions internal/resolver/dns/dns_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ const (
txtAttribute = "grpc_config="
)

var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, _ string) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(ctx, network, address)
var (
addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, _ string) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(ctx, network, address)
}
}
}
// minResolutionRate is the minimum rate at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
minResolutionRate = 30 * time.Second // this is the default value and can be changed via BuildOptions
)

var newNetResolver = func(authority string) (internal.NetResolver, error) {
if authority == "" {
Expand Down Expand Up @@ -113,6 +118,10 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
return deadResolver{}, nil
}

if opts.MinDNSResolutionRate != nil {
minResolutionRate = *opts.MinDNSResolutionRate
}

// DNS address (non-IP).
ctx, cancel := context.WithCancel(context.Background())
d := &dnsResolver{
Expand All @@ -123,6 +132,7 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
cc: cc,
rn: make(chan struct{}, 1),
disableServiceConfig: opts.DisableServiceConfig,
minResolutionRate: minResolutionRate,
}

d.resolver, err = internal.NewNetResolver(target.URL.Host)
Expand Down Expand Up @@ -167,6 +177,7 @@ type dnsResolver struct {
// replaceNetFunc (WRITE the lookup function pointers).
wg sync.WaitGroup
disableServiceConfig bool
minResolutionRate time.Duration
}

// ResolveNow invoke an immediate resolution of the target that this
Expand Down Expand Up @@ -198,10 +209,10 @@ func (d *dnsResolver) watcher() {

var waitTime time.Duration
if err == nil {
// Success resolving, wait for the next ResolveNow. However, also wait 30
// seconds at the very least to prevent constantly re-resolving.
// Success resolving, wait for the next ResolveNow. However, also wait for
// [minResolutionRate] seconds at the very least to prevent constantly re-resolving.
backoffIndex = 1
waitTime = internal.MinResolutionRate
waitTime = d.minResolutionRate
select {
case <-d.ctx.Done():
return
Expand Down
28 changes: 11 additions & 17 deletions internal/resolver/dns/dns_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,6 @@ func overrideNetResolver(t *testing.T, r *testNetResolver) {
t.Cleanup(func() { dnsinternal.NewNetResolver = origNetResolver })
}

// Override the DNS Min Res Rate used by the resolver.
func overrideResolutionRate(t *testing.T, d time.Duration) {
origMinResRate := dnsinternal.MinResolutionRate
dnsinternal.MinResolutionRate = d
t.Cleanup(func() { dnsinternal.MinResolutionRate = origMinResRate })
}

// Override the timer used by the DNS resolver to fire after a duration of d.
func overrideTimeAfterFunc(t *testing.T, d time.Duration) {
origTimeAfter := dnsinternal.TimeAfterFunc
Expand Down Expand Up @@ -109,7 +102,7 @@ func enableSRVLookups(t *testing.T) {

// Builds a DNS resolver for target and returns a couple of channels to read the
// state and error pushed by the resolver respectively.
func buildResolverWithTestClientConn(t *testing.T, target string) (resolver.Resolver, chan resolver.State, chan error) {
func buildResolverWithTestClientConn(t *testing.T, target string, buildOptions resolver.BuildOptions) (resolver.Resolver, chan resolver.State, chan error) {
t.Helper()

b := resolver.Get("dns")
Expand All @@ -135,7 +128,7 @@ func buildResolverWithTestClientConn(t *testing.T, target string) (resolver.Reso
}

tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF, ReportErrorF: reportErrorF}
r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", target))}, tcc, resolver.BuildOptions{})
r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", target))}, tcc, buildOptions)
if err != nil {
t.Fatalf("Failed to build DNS resolver for target %q: %v\n", target, err)
}
Expand Down Expand Up @@ -504,7 +497,7 @@ func (s) TestDNSResolver_Basic(t *testing.T) {
txtLookupTable: test.txtLookupTable,
})
enableSRVLookups(t)
_, stateCh, _ := buildResolverWithTestClientConn(t, test.target)
_, stateCh, _ := buildResolverWithTestClientConn(t, test.target, resolver.BuildOptions{})

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
Expand Down Expand Up @@ -635,7 +628,6 @@ func (s) TestDNSResolver_ExponentialBackoff(t *testing.T) {
func (s) TestDNSResolver_ResolveNow(t *testing.T) {
const target = "foo.bar.com"

overrideResolutionRate(t, 0)
overrideTimeAfterFunc(t, 0)
tr := &testNetResolver{
hostLookupTable: map[string][]string{
Expand All @@ -647,7 +639,8 @@ func (s) TestDNSResolver_ResolveNow(t *testing.T) {
}
overrideNetResolver(t, tr)

r, stateCh, _ := buildResolverWithTestClientConn(t, target)
var minResolutionRate time.Duration = 0
r, stateCh, _ := buildResolverWithTestClientConn(t, target, resolver.BuildOptions{MinDNSResolutionRate: &minResolutionRate})

// Verify that the first update pushed by the resolver matches expectations.
wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
Expand Down Expand Up @@ -738,9 +731,10 @@ func (s) TestIPResolver(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
overrideResolutionRate(t, 0)
overrideTimeAfterFunc(t, 2*defaultTestTimeout)
r, stateCh, _ := buildResolverWithTestClientConn(t, test.target)

var minResolutionRate time.Duration = 0
r, stateCh, _ := buildResolverWithTestClientConn(t, test.target, resolver.BuildOptions{MinDNSResolutionRate: &minResolutionRate})

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
Expand Down Expand Up @@ -943,7 +937,7 @@ func (s) TestTXTError(t *testing.T) {
// There is no entry for "ipv4.single.fake" in the txtLookupTbl
// maintained by the fake net.Resolver. So, a TXT lookup for this
// name will return an error.
_, stateCh, _ := buildResolverWithTestClientConn(t, "ipv4.single.fake")
_, stateCh, _ := buildResolverWithTestClientConn(t, "ipv4.single.fake", resolver.BuildOptions{})

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
Expand Down Expand Up @@ -1092,7 +1086,7 @@ func (s) TestRateLimitedResolve(t *testing.T) {
}
overrideNetResolver(t, tr)

r, stateCh, _ := buildResolverWithTestClientConn(t, target)
r, stateCh, _ := buildResolverWithTestClientConn(t, target, resolver.BuildOptions{})

// Wait for the first resolution request to be done. This happens as part
// of the first iteration of the for loop in watcher().
Expand Down Expand Up @@ -1171,7 +1165,7 @@ func (s) TestReportError(t *testing.T) {
overrideNetResolver(t, &testNetResolver{})

const target = "notfoundaddress"
_, _, errorCh := buildResolverWithTestClientConn(t, target)
_, _, errorCh := buildResolverWithTestClientConn(t, target, resolver.BuildOptions{})

// Should receive first error.
ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
Expand Down
4 changes: 0 additions & 4 deletions internal/resolver/dns/internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ var (

// The following vars are overridden from tests.
var (
// MinResolutionRate is the minimum rate at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
MinResolutionRate = 30 * time.Second

// TimeAfterFunc is used by the DNS resolver to wait for the given duration
// to elapse. In non-test code, this is implemented by time.After. In test
// code, this can be used to control the amount of time the resolver is
Expand Down
5 changes: 5 additions & 0 deletions resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net"
"net/url"
"strings"
"time"

"google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -168,6 +169,10 @@ type BuildOptions struct {
// field. In most cases though, it is not appropriate, and this field may
// be ignored.
Dialer func(context.Context, string) (net.Conn, error)
// MinDNSResolutionRate is the minimum rate at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
// Pointer was used to differentiate not-given from default value
MinDNSResolutionRate *time.Duration
}

// An Endpoint is one network endpoint, or server, which may have multiple
Expand Down
1 change: 1 addition & 0 deletions resolver_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func (ccr *ccResolverWrapper) start() error {
DialCreds: ccr.cc.dopts.copts.TransportCredentials,
CredsBundle: ccr.cc.dopts.copts.CredsBundle,
Dialer: ccr.cc.dopts.copts.Dialer,
MinDNSResolutionRate: ccr.cc.dopts.minDNSResolutionRate,
}
var err error
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)
Expand Down

0 comments on commit ffde661

Please sign in to comment.