Skip to content

Commit

Permalink
net: make cgo resolver work more accurately with network parameter
Browse files Browse the repository at this point in the history
Unlike the go resolver, the existing cgo resolver exchanges both DNS A
and AAAA RR queries unconditionally and causes unreasonable connection
setup latencies to applications using the cgo resolver.

This change adds new argument (`network`) in all functions through the
series of calls: from Resolver.internetAddrList to cgoLookupIPCNAME.

Benefit: no redundant DNS calls if certain IP version is used IPv4/IPv6
(no `AAAA` DNS requests if used tcp4, udp4, ip4 network. And vice
versa: no `A` DNS requests if used tcp6, udp6, ip6 network)

Fixes #25947

Change-Id: I39edbd726d82d6133fdada4d06cd90d401e7e669
Reviewed-on: https://go-review.googlesource.com/c/120215
Reviewed-by: Brad Fitzpatrick <[email protected]>
  • Loading branch information
ekalinin authored and bradfitz committed Oct 25, 2018
1 parent fc4f2e5 commit c659be4
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/net/cgo_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err
return 0, nil, false
}

func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) {
func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error, completed bool) {
return nil, nil, false
}

Expand Down
37 changes: 21 additions & 16 deletions src/net/cgo_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type reverseLookupResult struct {
}

func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error, completed bool) {
addrs, err, completed := cgoLookupIP(ctx, name)
addrs, err, completed := cgoLookupIP(ctx, "ip", name)
for _, addr := range addrs {
hosts = append(hosts, addr.String())
}
Expand All @@ -69,13 +69,11 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err
default:
return 0, &DNSError{Err: "unknown network", Name: network + "/" + service}, true
}
if len(network) >= 4 {
switch network[3] {
case '4':
hints.ai_family = C.AF_INET
case '6':
hints.ai_family = C.AF_INET6
}
switch ipVersion(network) {
case '4':
hints.ai_family = C.AF_INET
case '6':
hints.ai_family = C.AF_INET6
}
if ctx.Done() == nil {
port, err := cgoLookupServicePort(&hints, network, service)
Expand Down Expand Up @@ -135,13 +133,20 @@ func cgoPortLookup(result chan<- portLookupResult, hints *C.struct_addrinfo, net
result <- portLookupResult{port, err}
}

func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error) {
func cgoLookupIPCNAME(network, name string) (addrs []IPAddr, cname string, err error) {
acquireThread()
defer releaseThread()

var hints C.struct_addrinfo
hints.ai_flags = cgoAddrInfoFlags
hints.ai_socktype = C.SOCK_STREAM
hints.ai_family = C.AF_UNSPEC
switch ipVersion(network) {
case '4':
hints.ai_family = C.AF_INET
case '6':
hints.ai_family = C.AF_INET6
}

h := make([]byte, len(name)+1)
copy(h, name)
Expand Down Expand Up @@ -197,18 +202,18 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error) {
return addrs, cname, nil
}

func cgoIPLookup(result chan<- ipLookupResult, name string) {
addrs, cname, err := cgoLookupIPCNAME(name)
func cgoIPLookup(result chan<- ipLookupResult, network, name string) {
addrs, cname, err := cgoLookupIPCNAME(network, name)
result <- ipLookupResult{addrs, cname, err}
}

func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) {
func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error, completed bool) {
if ctx.Done() == nil {
addrs, _, err = cgoLookupIPCNAME(name)
addrs, _, err = cgoLookupIPCNAME(network, name)
return addrs, err, true
}
result := make(chan ipLookupResult, 1)
go cgoIPLookup(result, name)
go cgoIPLookup(result, network, name)
select {
case r := <-result:
return r.addrs, r.err, true
Expand All @@ -219,11 +224,11 @@ func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, c

func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
if ctx.Done() == nil {
_, cname, err = cgoLookupIPCNAME(name)
_, cname, err = cgoLookupIPCNAME("ip", name)
return cname, err, true
}
result := make(chan ipLookupResult, 1)
go cgoIPLookup(result, name)
go cgoIPLookup(result, "ip", name)
select {
case r := <-result:
return r.cname, r.err, true
Expand Down
4 changes: 2 additions & 2 deletions src/net/cgo_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
func TestCgoLookupIP(t *testing.T) {
defer dnsWaitGroup.Wait()
ctx := context.Background()
_, err, ok := cgoLookupIP(ctx, "localhost")
_, err, ok := cgoLookupIP(ctx, "ip", "localhost")
if !ok {
t.Errorf("cgoLookupIP must not be a placeholder")
}
Expand All @@ -28,7 +28,7 @@ func TestCgoLookupIPWithCancel(t *testing.T) {
defer dnsWaitGroup.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err, ok := cgoLookupIP(ctx, "localhost")
_, err, ok := cgoLookupIP(ctx, "ip", "localhost")
if !ok {
t.Errorf("cgoLookupIP must not be a placeholder")
}
Expand Down
4 changes: 2 additions & 2 deletions src/net/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func TestDialParallel(t *testing.T) {
}
}

func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
func lookupSlowFast(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
switch host {
case "slow6loopback4":
// Returns a slow IPv6 address, and a local IPv4 address.
Expand All @@ -355,7 +355,7 @@ func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPA
{IP: ParseIP("127.0.0.1")},
}, nil
default:
return fn(ctx, host)
return fn(ctx, network, host)
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/net/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func TestDialError(t *testing.T) {

origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
testHookLookupIP = func(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "dial error test", Name: "name", Server: "server", IsTimeout: true}
}
sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
Expand Down Expand Up @@ -293,7 +293,7 @@ func TestListenError(t *testing.T) {

origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
testHookLookupIP = func(_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
}
sw.Set(socktest.FilterListen, func(so *socktest.Status) (socktest.AfterFilter, error) {
Expand Down Expand Up @@ -353,7 +353,7 @@ func TestListenPacketError(t *testing.T) {

origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
testHookLookupIP = func(_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
}

Expand Down
5 changes: 3 additions & 2 deletions src/net/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ var (
testHookHostsPath = "/etc/hosts"
testHookLookupIP = func(
ctx context.Context,
fn func(context.Context, string) ([]IPAddr, error),
fn func(context.Context, string, string) ([]IPAddr, error),
network string,
host string,
) ([]IPAddr, error) {
return fn(ctx, host)
return fn(ctx, network, host)
}
testHookSetKeepAlive = func() {}
)
10 changes: 5 additions & 5 deletions src/net/http/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3825,9 +3825,9 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
}

// Install a fake DNS server.
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
if host != "dns-is-faked.golang" {
t.Errorf("unexpected DNS host lookup for %q", host)
t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
return nil, nil
}
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
Expand Down Expand Up @@ -4176,7 +4176,7 @@ func TestTransportMaxIdleConns(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
})

Expand Down Expand Up @@ -4416,9 +4416,9 @@ func testTransportIDNA(t *testing.T, h2 bool) {
}

// Install a fake DNS server.
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
if host != punyDomain {
t.Errorf("got DNS host lookup for %q; want %q", host, punyDomain)
t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
return nil, nil
}
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
Expand Down
2 changes: 1 addition & 1 deletion src/net/ipsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addr
}

// Try as a literal IP address, then as a DNS name.
ips, err := r.LookupIPAddr(ctx, host)
ips, err := r.lookupIPAddr(ctx, net, host)
if err != nil {
return nil, err
}
Expand Down
30 changes: 28 additions & 2 deletions src/net/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ func lookupPortMap(network, service string) (port int, error error) {
return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
}

// ipVersion returns the provided network's IP version: '4', '6' or 0
// if network does not end in a '4' or '6' byte.
func ipVersion(network string) byte {
if network == "" {
return 0
}
n := network[len(network)-1]
if n != '4' && n != '6' {
n = 0
}
return n
}

// DefaultResolver is the resolver used by the package-level Lookup
// functions and by Dialers without a specified Resolver.
var DefaultResolver = &Resolver{}
Expand Down Expand Up @@ -189,6 +202,12 @@ func LookupIP(host string) ([]IP, error) {
// LookupIPAddr looks up host using the local resolver.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
return r.lookupIPAddr(ctx, "ip", host)
}

// lookupIPAddr looks up host using the local resolver and particular network.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) {
// Make sure that no matter what we do later, host=="" is rejected.
// parseIP, for example, does accept empty strings.
if host == "" {
Expand All @@ -205,7 +224,7 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
// can be overridden by tests. This is needed by net/http, so it
// uses a context key instead of unexported variables.
resolverFunc := r.lookupIP
if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string) ([]IPAddr, error)); alt != nil {
if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string, string) ([]IPAddr, error)); alt != nil {
resolverFunc = alt
}

Expand All @@ -218,7 +237,7 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
dnsWaitGroup.Add(1)
ch, called := r.getLookupGroup().DoChan(host, func() (interface{}, error) {
defer dnsWaitGroup.Done()
return testHookLookupIP(lookupGroupCtx, resolverFunc, host)
return testHookLookupIP(lookupGroupCtx, resolverFunc, network, host)
})
if !called {
dnsWaitGroup.Done()
Expand Down Expand Up @@ -289,6 +308,13 @@ func LookupPort(network, service string) (port int, err error) {
func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
port, needsLookup := parsePort(service)
if needsLookup {
switch network {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
case "": // a hint wildcard for Go 1.0 undocumented behavior
network = "ip"
default:
return 0, &AddrError{Err: "unknown network", Addr: network}
}
port, err = r.lookupPort(ctx, network, service)
if err != nil {
return 0, err
Expand Down
2 changes: 1 addition & 1 deletion src/net/lookup_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, e
return nil, syscall.ENOPROTOOPT
}

func (*Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
func (*Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
return nil, syscall.ENOPROTOOPT
}

Expand Down
2 changes: 1 addition & 1 deletion src/net/lookup_plan9.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ loop:
return
}

func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
func (r *Resolver) lookupIP(ctx context.Context, _, host string) (addrs []IPAddr, err error) {
lits, err := r.lookupHost(ctx, host)
if err != nil {
return
Expand Down
30 changes: 28 additions & 2 deletions src/net/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ import (
"time"
)

func lookupLocalhost(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
func lookupLocalhost(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
switch host {
case "localhost":
return []IPAddr{
{IP: IPv4(127, 0, 0, 1)},
{IP: IPv6loopback},
}, nil
default:
return fn(ctx, host)
return fn(ctx, network, host)
}
}

Expand Down Expand Up @@ -1008,3 +1008,29 @@ func TestConcurrentPreferGoResolversDial(t *testing.T) {
}
}
}

var ipVersionTests = []struct {
network string
version byte
}{
{"tcp", 0},
{"tcp4", '4'},
{"tcp6", '6'},
{"udp", 0},
{"udp4", '4'},
{"udp6", '6'},
{"ip", 0},
{"ip4", '4'},
{"ip6", '6'},
{"ip7", 0},
{"", 0},
}

func TestIPVersion(t *testing.T) {
for _, tt := range ipVersionTests {
if version := ipVersion(tt.network); version != tt.version {
t.Errorf("Family for: %s. Expected: %s, Got: %s", tt.network,
string(tt.version), string(version))
}
}
}
4 changes: 2 additions & 2 deletions src/net/lookup_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string,
return r.goLookupHostOrder(ctx, host, order)
}

func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
if r.preferGo() {
return r.goLookupIP(ctx, host)
}
order := systemConf().hostLookupOrder(r, host)
if order == hostLookupCgo {
if addrs, err, ok := cgoLookupIP(ctx, host); ok {
if addrs, err, ok := cgoLookupIP(ctx, network, host); ok {
return addrs, err
}
// cgo not available (or netgo); fall back to Go's DNS resolver
Expand Down
14 changes: 11 additions & 3 deletions src/net/lookup_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func lookupProtocol(ctx context.Context, name string) (int, error) {
}

func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error) {
ips, err := r.lookupIP(ctx, name)
ips, err := r.lookupIP(ctx, "ip", name)
if err != nil {
return nil, err
}
Expand All @@ -76,14 +76,22 @@ func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error
return addrs, nil
}

func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) {
// TODO(bradfitz,brainman): use ctx more. See TODO below.

var family int32 = syscall.AF_UNSPEC
switch ipVersion(network) {
case '4':
family = syscall.AF_INET
case '6':
family = syscall.AF_INET6
}

getaddr := func() ([]IPAddr, error) {
acquireThread()
defer releaseThread()
hints := syscall.AddrinfoW{
Family: syscall.AF_UNSPEC,
Family: family,
Socktype: syscall.SOCK_STREAM,
Protocol: syscall.IPPROTO_IP,
}
Expand Down
Loading

0 comments on commit c659be4

Please sign in to comment.