Skip to content

Commit

Permalink
feat(dnsovergetaddrinfo): enable collecting CNAME
Browse files Browse the repository at this point in the history
Rather than discarding the CNAME, make sure we can actually
get the CNAME inside of the main LookupHost function.

Part of ooni/probe#1516.
  • Loading branch information
bassosimone committed Aug 22, 2022
1 parent 1addcf8 commit 4cd8a79
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 25 deletions.
38 changes: 23 additions & 15 deletions internal/netxlite/dnsovergetaddrinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (

// dnsOverGetaddrinfoTransport is a DNSTransport using getaddrinfo.
type dnsOverGetaddrinfoTransport struct {
testableTimeout time.Duration
testableLookupHost func(ctx context.Context, domain string) ([]string, error)
testableTimeout time.Duration
testableLookupANY func(ctx context.Context, domain string) ([]string, string, error)
}

var _ model.DNSTransport = &dnsOverGetaddrinfoTransport{}
Expand All @@ -27,7 +27,7 @@ func (txp *dnsOverGetaddrinfoTransport) RoundTrip(
if query.Type() != dns.TypeANY {
return nil, ErrNoDNSTransport
}
addrs, err := txp.lookup(ctx, query.Domain())
addrs, _, err := txp.lookup(ctx, query.Domain())
if err != nil {
return nil, err
}
Expand All @@ -43,30 +43,38 @@ type dnsOverGetaddrinfoResponse struct {
query model.DNSQuery
}

type dnsOverGetaddrinfoAddrsAndCNAME struct {
addrs []string
cname string
}

func (txp *dnsOverGetaddrinfoTransport) lookup(
ctx context.Context, hostname string) ([]string, error) {
ctx context.Context, hostname string) ([]string, string, error) {
// This code forces adding a shorter timeout to the domain name
// resolutions when using the system resolver. We have seen cases
// in which such a timeout becomes too large. One such case is
// described in https://github.com/ooni/probe/issues/1726.
addrsch, errch := make(chan []string, 1), make(chan error, 1)
addrsch, errch := make(chan *dnsOverGetaddrinfoAddrsAndCNAME, 1), make(chan error, 1)
ctx, cancel := context.WithTimeout(ctx, txp.timeout())
defer cancel()
go func() {
addrs, err := txp.lookupfn()(ctx, hostname)
addrs, cname, err := txp.lookupfn()(ctx, hostname)
if err != nil {
errch <- err
return
}
addrsch <- addrs
addrsch <- &dnsOverGetaddrinfoAddrsAndCNAME{
addrs: addrs,
cname: cname,
}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case addrs := <-addrsch:
return addrs, nil
return nil, "", ctx.Err()
case p := <-addrsch:
return p.addrs, p.cname, nil
case err := <-errch:
return nil, err
return nil, "", err
}
}

Expand All @@ -77,11 +85,11 @@ func (txp *dnsOverGetaddrinfoTransport) timeout() time.Duration {
return 15 * time.Second
}

func (txp *dnsOverGetaddrinfoTransport) lookupfn() func(ctx context.Context, domain string) ([]string, error) {
if txp.testableLookupHost != nil {
return txp.testableLookupHost
func (txp *dnsOverGetaddrinfoTransport) lookupfn() func(ctx context.Context, domain string) ([]string, string, error) {
if txp.testableLookupANY != nil {
return txp.testableLookupANY
}
return getaddrinfoLookupHost
return getaddrinfoLookupANY
}

func (txp *dnsOverGetaddrinfoTransport) RequiresPadding() bool {
Expand Down
20 changes: 10 additions & 10 deletions internal/netxlite/dnsovergetaddrinfo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func TestDNSOverGetaddrinfo(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("with invalid query type", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
testableLookupANY: func(ctx context.Context, domain string) ([]string, string, error) {
return []string{"8.8.8.8"}, "dns.google", nil
},
}
encoder := &DNSEncoderMiekg{}
Expand All @@ -73,8 +73,8 @@ func TestDNSOverGetaddrinfo(t *testing.T) {

t.Run("with success", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
testableLookupANY: func(ctx context.Context, domain string) ([]string, string, error) {
return []string{"8.8.8.8"}, "dns.google", nil
},
}
encoder := &DNSEncoderMiekg{}
Expand Down Expand Up @@ -122,10 +122,10 @@ func TestDNSOverGetaddrinfo(t *testing.T) {
done := make(chan interface{})
txp := &dnsOverGetaddrinfoTransport{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
testableLookupANY: func(ctx context.Context, domain string) ([]string, string, error) {
defer wg.Done()
<-done
return []string{"8.8.8.8"}, nil
return []string{"8.8.8.8"}, "dns.google", nil
},
}
encoder := &DNSEncoderMiekg{}
Expand All @@ -148,10 +148,10 @@ func TestDNSOverGetaddrinfo(t *testing.T) {
done := make(chan interface{})
txp := &dnsOverGetaddrinfoTransport{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
testableLookupANY: func(ctx context.Context, domain string) ([]string, string, error) {
defer wg.Done()
<-done
return nil, errors.New("no such host")
return nil, "", errors.New("no such host")
},
}
encoder := &DNSEncoderMiekg{}
Expand All @@ -170,8 +170,8 @@ func TestDNSOverGetaddrinfo(t *testing.T) {

t.Run("with NXDOMAIN", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, ErrOODNSNoSuchHost
testableLookupANY: func(ctx context.Context, domain string) ([]string, string, error) {
return nil, "", ErrOODNSNoSuchHost
},
}
encoder := &DNSEncoderMiekg{}
Expand Down

0 comments on commit 4cd8a79

Please sign in to comment.