diff --git a/exchanger/exchanger.go b/exchanger/exchanger.go index b4c642e1..2aa582c6 100644 --- a/exchanger/exchanger.go +++ b/exchanger/exchanger.go @@ -2,7 +2,6 @@ package exchanger import ( "log" - "net" "time" "github.com/mesosphere/mesos-dns/logging" @@ -36,24 +35,6 @@ func Decorate(ex Exchanger, ds ...Decorator) Exchanger { return decorated } -// Pred is a predicate function type for dns.Msgs. -type Pred func(*dns.Msg) bool - -// While returns an Exchanger which attempts the given Exchangers while the given -// predicate function returns true for the returned dns.Msg, an error is returned, -// or all Exchangers are attempted, in which case the return values of the last -// one are returned. -func While(p Pred, exs ...Exchanger) Exchanger { - return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) { - for _, ex := range exs { - if r, rtt, err = ex.Exchange(m, a); err != nil || !p(r) { - break - } - } - return - }) -} - // ErrorLogging returns a Decorator which logs an Exchanger's errors to the given // logger. func ErrorLogging(l *log.Logger) Decorator { @@ -70,50 +51,18 @@ func ErrorLogging(l *log.Logger) Decorator { } // Instrumentation returns a Decorator which instruments an Exchanger with the given -// counter. -func Instrumentation(c logging.Counter) Decorator { - return func(ex Exchanger) Exchanger { - return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { - defer c.Inc() - return ex.Exchange(m, a) - }) - } -} - -// A Recurser returns the addr (host:port) of the next DNS server to recurse a -// Msg to. Empty returns signal that further recursion isn't possible or needed. -type Recurser func(*dns.Msg) string - -// Recurse is the default Mesos-DNS Recurser which returns an addr (host:port) -// only when the given dns.Msg doesn't contain authoritative answers and has at -// least one SOA record in its NS section. -func Recurse(r *dns.Msg) string { - if r.Authoritative && len(r.Answer) > 0 { - return "" - } - - for _, ns := range r.Ns { - if soa, ok := ns.(*dns.SOA); ok { - return net.JoinHostPort(soa.Ns, "53") - } - } - - return "" -} - -// Recursion returns a Decorator which recurses until the given Recurser returns -// an empty string or max attempts have been reached. -func Recursion(max int, rec Recurser) Decorator { +// counters. +func Instrumentation(total, success, failure logging.Counter) Decorator { return func(ex Exchanger) Exchanger { return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) { - for i := 0; i <= max; i++ { - if r, rtt, err = ex.Exchange(m, a); err != nil { - break - } else if a = rec(r); a == "" { - break + defer func() { + if total.Inc(); err != nil { + failure.Inc() + } else { + success.Inc() } - } - return r, rtt, err + }() + return ex.Exchange(m, a) }) } } diff --git a/exchanger/exchanger_test.go b/exchanger/exchanger_test.go index 36af3ebb..2ca41bb6 100644 --- a/exchanger/exchanger_test.go +++ b/exchanger/exchanger_test.go @@ -1,169 +1,63 @@ package exchanger import ( + "bytes" "errors" - "net" - "reflect" + "log" "testing" "time" - . "github.com/mesosphere/mesos-dns/dnstest" + "github.com/mesos/mesos-dns/logging" "github.com/miekg/dns" ) -func TestWhile(t *testing.T) { - for i, tt := range []struct { - pred Pred - exs []Exchanger - want exchanged - }{ - { // error - nil, - stubs(exchanged{nil, 0, errors.New("foo")}), - exchanged{nil, 0, errors.New("foo")}, - }, - { // always true predicate - func(*dns.Msg) bool { return true }, - stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}), - exchanged{nil, 1, nil}, - }, - { // nil exchangers - nil, - nil, - exchanged{nil, 0, nil}, - }, - { // empty exchangers - nil, - stubs(), - exchanged{nil, 0, nil}, - }, - { // false predicate - func(calls int) Pred { - return func(*dns.Msg) bool { - calls++ - return calls != 2 - } - }(0), - stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}, exchanged{nil, 2, nil}), - exchanged{nil, 1, nil}, - }, - } { - var got exchanged - got.m, got.rtt, got.err = While(tt.pred, tt.exs...).Exchange(nil, "") - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want) +func TestErrorLogging(t *testing.T) { + { // with error + var buf bytes.Buffer + ErrorLogging(log.New(&buf, "", 0))( + stub(exchanged{err: errors.New("timeout")}), + ).Exchange(nil, "1.2.3.4") + + want := "timeout: exchanging (*dns.Msg)(nil) with \"1.2.3.4\"\n" + if got := buf.String(); got != want { + t.Errorf("got %q, want %q", got, want) } } -} + { // no error + var buf bytes.Buffer + ErrorLogging(log.New(&buf, "", 0))(stub(exchanged{})).Exchange(nil, "1.2.3.4") -func TestRecurse(t *testing.T) { - for i, tt := range []struct { - *dns.Msg - want string - }{ - { // Authoritative with answers - Message( - Header(true, 0), - Answers( - A(RRHeader("localhost", dns.TypeA, 0), net.IPv6loopback.To4()), - ), - NSs( - SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0), - ), - ), - "", - }, - { // Authoritative, empty answers, no SOA records - Message( - Header(true, 0), - NSs( - NS(RRHeader("", dns.TypeNS, 0), "next"), - ), - ), - "", - }, - { // Not authoritative, no SOA record - Message(Header(false, 0)), - "", - }, - { // Not authoritative, one SOA record - Message( - Header(false, 0), - NSs(SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0)), - ), - "next:53", - }, - { // Authoritative, empty answers, one SOA record - Message( - Header(true, 0), - NSs( - NS(RRHeader("", dns.TypeNS, 0), "foo"), - SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0), - ), - ), - "next:53", - }, - } { - if got := Recurse(tt.Msg); got != tt.want { - t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want) + if got, want := buf.String(), ""; got != want { + t.Errorf("got %q, want %q", got, want) } } } -func TestRecursion(t *testing.T) { - for i, tt := range []struct { - max int - rec Recurser - ex Exchanger - want exchanged - }{ - { - 0, - func(*dns.Msg) string { return "next" }, - seq(stubs(exchanged{rtt: 1})...), - exchanged{rtt: 1}, - }, - { - 1, - func(*dns.Msg) string { return "next" }, - seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...), - exchanged{rtt: 1}, - }, - { - 0, - nil, - seq(stubs(exchanged{err: errors.New("foo")})...), - exchanged{err: errors.New("foo")}, - }, - { - 2, - func(calls int) Recurser { - return func(*dns.Msg) string { - if calls++; calls <= 1 { - return "next" - } - return "" - } - }(0), - seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...), - exchanged{rtt: 1}, - }, - } { - var got exchanged - got.m, got.rtt, got.err = Recursion(tt.max, tt.rec)(tt.ex).Exchange(nil, "") - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want) +func TestInstrumentation(t *testing.T) { + { // with error + var total, success, failure logging.LogCounter + Instrumentation(&total, &success, &failure)( + stub(exchanged{err: errors.New("timeout")}), + ).Exchange(nil, "1.2.3.4") + + want := []string{"1", "0", "1"} + for i, c := range []*logging.LogCounter{&total, &success, &failure} { + if got, want := c.String(), want[i]; got != want { + t.Errorf("test #%d: got %q, want %q", i, got, want) + } } } -} + { // no error + var total, success, failure logging.LogCounter + Instrumentation(&total, &success, &failure)(stub(exchanged{})).Exchange(nil, "1.2.3.4") -func seq(exs ...Exchanger) Exchanger { - var i int - return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { - ex := exs[i] - i++ - return ex.Exchange(m, a) - }) + want := []string{"1", "1", "0"} + for i, c := range []*logging.LogCounter{&total, &success, &failure} { + if got, want := c.String(), want[i]; got != want { + t.Errorf("test #%d: got %q, want %q", i, got, want) + } + } + } } func stubs(ed ...exchanged) []Exchanger { diff --git a/exchanger/forwarder.go b/exchanger/forwarder.go new file mode 100644 index 00000000..8cc0dc72 --- /dev/null +++ b/exchanger/forwarder.go @@ -0,0 +1,49 @@ +package exchanger + +import ( + "fmt" + "net" + + "github.com/miekg/dns" +) + +// A Forwarder is a DNS message forwarder that transparently proxies messages +// to DNS servers. +type Forwarder func(*dns.Msg, string) (*dns.Msg, error) + +// Forward is an utility method that calls f itself. +func (f Forwarder) Forward(m *dns.Msg, proto string) (*dns.Msg, error) { + return f(m, proto) +} + +// NewForwarder returns a new Forwarder for the given addrs with the given +// Exchangers map which maps network protocols to Exchangers. +// +// Every message will be exchanged with each address until no error is returned. +// If no addresses or no matching protocol exchanger exist, a *ForwardError will +// be returned. +func NewForwarder(addrs []string, exs map[string]Exchanger) Forwarder { + return func(m *dns.Msg, proto string) (r *dns.Msg, err error) { + ex, ok := exs[proto] + if !ok || len(addrs) == 0 { + return nil, &ForwardError{Addrs: addrs, Proto: proto} + } + for _, a := range addrs { + if r, _, err = ex.Exchange(m, net.JoinHostPort(a, "53")); err == nil { + break + } + } + return + } +} + +// A ForwardError is returned by Forwarders when they can't forward. +type ForwardError struct { + Addrs []string + Proto string +} + +// Error implements the error interface. +func (e ForwardError) Error() string { + return fmt.Sprintf("can't forward to %v over %q", e.Addrs, e.Proto) +} diff --git a/exchanger/forwarder_test.go b/exchanger/forwarder_test.go new file mode 100644 index 00000000..9c3c49f0 --- /dev/null +++ b/exchanger/forwarder_test.go @@ -0,0 +1,95 @@ +package exchanger + +import ( + "errors" + "reflect" + "testing" + "time" + + "github.com/kylelemons/godebug/pretty" + . "github.com/mesosphere/mesos-dns/dnstest" + "github.com/miekg/dns" +) + +func TestForwarder(t *testing.T) { + exs := func(e exchanged, protos ...string) map[string]Exchanger { + es := make(map[string]Exchanger, len(protos)) + for _, proto := range protos { + es[proto] = stub(e) + } + return es + } + + msg := Message(Question("foo.bar", dns.TypeA)) + for i, tt := range []struct { + addrs []string + exs map[string]Exchanger + proto string + r *dns.Msg + err error + }{ + { // no matching protocol + nil, exs(exchanged{}, "udp"), "tcp", nil, &ForwardError{nil, "tcp"}, + }, + { // matching protocol, no addrs + nil, exs(exchanged{}, "udp"), "udp", nil, &ForwardError{nil, "udp"}, + }, + { // matching protocol, no addrs + []string{}, exs(exchanged{}, "udp"), "udp", nil, &ForwardError{[]string{}, "udp"}, + }, + { // matching protocol, one addr, no error exchanging + addrs: []string{"1.2.3.4"}, + exs: exs(exchanged{m: msg}, "udp"), + proto: "udp", + r: msg, + }, + { // matching protocol, one addr, error exchanging + addrs: []string{"1.2.3.4"}, + exs: exs(exchanged{err: errors.New("timeout")}, "udp"), + proto: "udp", + err: errors.New("timeout"), + }, + { // matching protocol, two addrs, error exchanging with the first only + addrs: []string{"1.2.3.4", "2.3.4.5"}, + exs: map[string]Exchanger{ + "udp": Func(func(_ *dns.Msg, a string) (*dns.Msg, time.Duration, error) { + switch a { + case "1.2.3.4": + return nil, 0, errors.New("timeout") + default: + return msg, 0, nil + } + }), + }, + proto: "udp", + r: msg, + }, + { // matching protocol, two addrs, error exchanging with all of them + addrs: []string{"1.2.3.4", "2.3.4.5"}, + exs: map[string]Exchanger{ + "udp": Func(func(_ *dns.Msg, a string) (*dns.Msg, time.Duration, error) { + switch a { + case "1.2.3.4": + return nil, 0, errors.New("timeout") + default: + return nil, 0, errors.New("eof") + } + }), + }, + proto: "udp", + err: errors.New("eof"), + }, + } { + var got forwarded + got.r, got.err = NewForwarder(tt.addrs, tt.exs).Forward(nil, tt.proto) + if want := (forwarded{r: tt.r, err: tt.err}); !reflect.DeepEqual(got, want) { + t.Logf("test #%d\n", i) + t.Error(pretty.Compare(got, want)) + } + } +} + +type forwarded struct { + r *dns.Msg + err error +} diff --git a/logging/logging.go b/logging/logging.go index eabdd848..0abd8f07 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -46,28 +46,28 @@ func (lc *LogCounter) String() string { // LogOut holds metrics captured in an instrumented runtime. type LogOut struct { - MesosRequests Counter - MesosSuccess Counter - MesosNXDomain Counter - MesosFailed Counter - NonMesosRequests Counter - NonMesosSuccess Counter - NonMesosNXDomain Counter - NonMesosFailed Counter - NonMesosRecursed Counter + MesosRequests Counter + MesosSuccess Counter + MesosNXDomain Counter + MesosFailed Counter + NonMesosRequests Counter + NonMesosSuccess Counter + NonMesosNXDomain Counter + NonMesosFailed Counter + NonMesosForwarded Counter } // CurLog is the default package level LogOut. var CurLog = LogOut{ - MesosRequests: &LogCounter{}, - MesosSuccess: &LogCounter{}, - MesosNXDomain: &LogCounter{}, - MesosFailed: &LogCounter{}, - NonMesosRequests: &LogCounter{}, - NonMesosSuccess: &LogCounter{}, - NonMesosNXDomain: &LogCounter{}, - NonMesosFailed: &LogCounter{}, - NonMesosRecursed: &LogCounter{}, + MesosRequests: &LogCounter{}, + MesosSuccess: &LogCounter{}, + MesosNXDomain: &LogCounter{}, + MesosFailed: &LogCounter{}, + NonMesosRequests: &LogCounter{}, + NonMesosSuccess: &LogCounter{}, + NonMesosNXDomain: &LogCounter{}, + NonMesosFailed: &LogCounter{}, + NonMesosForwarded: &LogCounter{}, } // PrintCurLog prints out the current LogOut and then resets diff --git a/resolver/resolver.go b/resolver/resolver.go index 2e861be7..d66ede90 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -29,9 +29,7 @@ type Resolver struct { rs *records.RecordGenerator rsLock sync.RWMutex rng *rand.Rand - - // pluggable external DNS resolution, mainly for unit testing - extResolver exchanger.Exchanger + exchanger.Forwarder } // New returns a Resolver with the given version and configuration. @@ -44,40 +42,41 @@ func New(version string, config records.Config) *Resolver { masters: append([]string{""}, config.Masters...), } - if !config.ExternalOn { - return r - } - timeout := 5 * time.Second if config.Timeout != 0 { timeout = time.Duration(config.Timeout) * time.Second } - r.extResolver = newClient(timeout) + rs := config.Resolvers + if !config.ExternalOn { + rs = rs[:0] + } + r.Forwarder = exchanger.NewForwarder(rs, exchangers(timeout, "udp", "tcp")) return r } -func newClient(timeout time.Duration) exchanger.Exchanger { - clients := make([]exchanger.Exchanger, 2) - for i, proto := range [...]string{"udp", "tcp"} { // See RFC5966 - clients[i] = &dns.Client{ - Net: proto, - DialTimeout: timeout, - ReadTimeout: timeout, - WriteTimeout: timeout, - } +func exchangers(timeout time.Duration, protos ...string) map[string]exchanger.Exchanger { + exs := make(map[string]exchanger.Exchanger, len(protos)) + for _, proto := range protos { + exs[proto] = exchanger.Decorate( + &dns.Client{ + Net: proto, + DialTimeout: timeout, + ReadTimeout: timeout, + WriteTimeout: timeout, + }, + exchanger.ErrorLogging(logging.Error), + exchanger.Instrumentation( + logging.CurLog.NonMesosForwarded, + logging.CurLog.NonMesosSuccess, + logging.CurLog.NonMesosFailed, + ), + ) } - return exchanger.Decorate( - exchanger.While(truncated, clients...), - exchanger.Recursion(3, exchanger.Recurse), - exchanger.ErrorLogging(logging.Error), - exchanger.Instrumentation(logging.CurLog.NonMesosRecursed), - ) + return exs } -func truncated(m *dns.Msg) bool { return m.Truncated } - // return the current (read-only) record set. attempts to write to the returned // object will likely result in a data race. func (res *Resolver) records() *records.RecordGenerator { @@ -246,52 +245,28 @@ func shuffleAnswers(rng *rand.Rand, answers []dns.RR) []dns.RR { return answers } -// HandleNonMesos handles non-mesos queries by recursing to a configured -// external resolver. +// HandleNonMesos handles non-mesos queries by forwarding to configured +// external DNS servers. func (res *Resolver) HandleNonMesos(w dns.ResponseWriter, r *dns.Msg) { - var err error - var m *dns.Msg - - // tracing info logging.CurLog.NonMesosRequests.Inc() - - // If external request are disabled - if res.extResolver == nil { - m = new(dns.Msg) - // set refused - m.SetRcode(r, 5) - } else { - for _, resolver := range res.config.Resolvers { - nameserver := net.JoinHostPort(resolver, "53") - m, _, err = res.extResolver.Exchange(r, nameserver) - if err == nil { - break - } - } - } - - // extResolver returns nil Msg sometimes cause of perf - if m == nil { - m = new(dns.Msg) - m.SetRcode(r, 2) - err = fmt.Errorf("failed external DNS lookup of %q: %v", r.Question[0].Name, err) - } + m, err := res.Forward(r, w.RemoteAddr().Network()) if err != nil { - logging.Error.Println(r.Question[0].Name) - logging.Error.Println(err) - logging.CurLog.NonMesosFailed.Inc() - } else { - // nxdomain - if len(m.Answer) == 0 { - logging.CurLog.NonMesosNXDomain.Inc() - } else { - logging.CurLog.NonMesosSuccess.Inc() - } + m = new(dns.Msg).SetRcode(r, rcode(err)) + } else if len(m.Answer) == 0 { + logging.CurLog.NonMesosNXDomain.Inc() } - reply(w, m) } +func rcode(err error) int { + switch err.(type) { + case *exchanger.ForwardError: + return dns.RcodeRefused + default: + return dns.RcodeServerFailure + } +} + // HandleMesos is a resolver request handler that responds to a resource // question with resource answer(s) // it can handle {A, SRV, ANY} diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index 0c49f424..479fa096 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -11,11 +11,9 @@ import ( "reflect" "strconv" "testing" - "time" "github.com/kylelemons/godebug/pretty" . "github.com/mesosphere/mesos-dns/dnstest" - "github.com/mesosphere/mesos-dns/exchanger" "github.com/mesosphere/mesos-dns/logging" "github.com/mesosphere/mesos-dns/records" "github.com/mesosphere/mesos-dns/records/labels" @@ -77,19 +75,19 @@ func TestShuffleAnswers(t *testing.T) { func TestHandlers(t *testing.T) { res := fakeDNS(t) - res.extResolver = exchanger.Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) { + res.Forwarder = func(m *dns.Msg, net string) (*dns.Msg, error) { rr1, err := res.formatA("google.com.", "1.1.1.1") if err != nil { - return nil, 0, err + return nil, err } rr2, err := res.formatA("google.com.", "2.2.2.2") if err != nil { - return nil, 0, err + return nil, err } msg := &dns.Msg{Answer: []dns.RR{rr1, rr2}} msg.SetReply(m) - return msg, 0, nil - }) + return msg, nil + } for i, tt := range []struct { dns.HandlerFunc