Skip to content
This repository has been archived by the owner on Oct 23, 2024. It is now read-only.

Transparent proxying of external queries #307

Merged
merged 1 commit into from
Oct 7, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 9 additions & 60 deletions exchanger/exchanger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package exchanger

import (
"log"
"net"
"time"

"github.com/mesosphere/mesos-dns/logging"
Expand Down Expand Up @@ -36,24 +35,6 @@ func Decorate(ex Exchanger, ds ...Decorator) Exchanger {
return decorated
}

// Pred is a predicate function type for dns.Msgs.
type Pred func(*dns.Msg) bool

// While returns an Exchanger which attempts the given Exchangers while the given
// predicate function returns true for the returned dns.Msg, an error is returned,
// or all Exchangers are attempted, in which case the return values of the last
// one are returned.
func While(p Pred, exs ...Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) {
for _, ex := range exs {
if r, rtt, err = ex.Exchange(m, a); err != nil || !p(r) {
break
}
}
return
})
}

// ErrorLogging returns a Decorator which logs an Exchanger's errors to the given
// logger.
func ErrorLogging(l *log.Logger) Decorator {
Expand All @@ -70,50 +51,18 @@ func ErrorLogging(l *log.Logger) Decorator {
}

// Instrumentation returns a Decorator which instruments an Exchanger with the given
// counter.
func Instrumentation(c logging.Counter) Decorator {
return func(ex Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
defer c.Inc()
return ex.Exchange(m, a)
})
}
}

// A Recurser returns the addr (host:port) of the next DNS server to recurse a
// Msg to. Empty returns signal that further recursion isn't possible or needed.
type Recurser func(*dns.Msg) string

// Recurse is the default Mesos-DNS Recurser which returns an addr (host:port)
// only when the given dns.Msg doesn't contain authoritative answers and has at
// least one SOA record in its NS section.
func Recurse(r *dns.Msg) string {
if r.Authoritative && len(r.Answer) > 0 {
return ""
}

for _, ns := range r.Ns {
if soa, ok := ns.(*dns.SOA); ok {
return net.JoinHostPort(soa.Ns, "53")
}
}

return ""
}

// Recursion returns a Decorator which recurses until the given Recurser returns
// an empty string or max attempts have been reached.
func Recursion(max int, rec Recurser) Decorator {
// counters.
func Instrumentation(total, success, failure logging.Counter) Decorator {
return func(ex Exchanger) Exchanger {
return Func(func(m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) {
for i := 0; i <= max; i++ {
if r, rtt, err = ex.Exchange(m, a); err != nil {
break
} else if a = rec(r); a == "" {
break
defer func() {
if total.Inc(); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks strange because the increment of success depends on a negative outcome of total.Inc(). unless I'm missing something, total.Inc() should really be on its own line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

total.Inc() has no impact on the conditional. It runs unconditionally.

failure.Inc()
} else {
success.Inc()
}
}
return r, rtt, err
}()
return ex.Exchange(m, a)
})
}
}
186 changes: 40 additions & 146 deletions exchanger/exchanger_test.go
Original file line number Diff line number Diff line change
@@ -1,169 +1,63 @@
package exchanger

import (
"bytes"
"errors"
"net"
"reflect"
"log"
"testing"
"time"

. "github.com/mesosphere/mesos-dns/dnstest"
"github.com/mesosphere/mesos-dns/logging"
"github.com/miekg/dns"
)

func TestWhile(t *testing.T) {
for i, tt := range []struct {
pred Pred
exs []Exchanger
want exchanged
}{
{ // error
nil,
stubs(exchanged{nil, 0, errors.New("foo")}),
exchanged{nil, 0, errors.New("foo")},
},
{ // always true predicate
func(*dns.Msg) bool { return true },
stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}),
exchanged{nil, 1, nil},
},
{ // nil exchangers
nil,
nil,
exchanged{nil, 0, nil},
},
{ // empty exchangers
nil,
stubs(),
exchanged{nil, 0, nil},
},
{ // false predicate
func(calls int) Pred {
return func(*dns.Msg) bool {
calls++
return calls != 2
}
}(0),
stubs(exchanged{nil, 0, nil}, exchanged{nil, 1, nil}, exchanged{nil, 2, nil}),
exchanged{nil, 1, nil},
},
} {
var got exchanged
got.m, got.rtt, got.err = While(tt.pred, tt.exs...).Exchange(nil, "")
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
func TestErrorLogging(t *testing.T) {
{ // with error
var buf bytes.Buffer
_, _, _ = ErrorLogging(log.New(&buf, "", 0))(
stub(exchanged{err: errors.New("timeout")})).Exchange(nil, "1.2.3.4")

want := "timeout: exchanging (*dns.Msg)(nil) with \"1.2.3.4\"\n"
if got := buf.String(); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}
{ // no error
var buf bytes.Buffer
_, _, _ = ErrorLogging(log.New(&buf, "", 0))(
stub(exchanged{})).Exchange(nil, "1.2.3.4")

func TestRecurse(t *testing.T) {
for i, tt := range []struct {
*dns.Msg
want string
}{
{ // Authoritative with answers
Message(
Header(true, 0),
Answers(
A(RRHeader("localhost", dns.TypeA, 0), net.IPv6loopback.To4()),
),
NSs(
SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0),
),
),
"",
},
{ // Authoritative, empty answers, no SOA records
Message(
Header(true, 0),
NSs(
NS(RRHeader("", dns.TypeNS, 0), "next"),
),
),
"",
},
{ // Not authoritative, no SOA record
Message(Header(false, 0)),
"",
},
{ // Not authoritative, one SOA record
Message(
Header(false, 0),
NSs(SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0)),
),
"next:53",
},
{ // Authoritative, empty answers, one SOA record
Message(
Header(true, 0),
NSs(
NS(RRHeader("", dns.TypeNS, 0), "foo"),
SOA(RRHeader("", dns.TypeSOA, 0), "next", "", 0),
),
),
"next:53",
},
} {
if got := Recurse(tt.Msg); got != tt.want {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
if got, want := buf.String(), ""; got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}

func TestRecursion(t *testing.T) {
for i, tt := range []struct {
max int
rec Recurser
ex Exchanger
want exchanged
}{
{
0,
func(*dns.Msg) string { return "next" },
seq(stubs(exchanged{rtt: 1})...),
exchanged{rtt: 1},
},
{
1,
func(*dns.Msg) string { return "next" },
seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...),
exchanged{rtt: 1},
},
{
0,
nil,
seq(stubs(exchanged{err: errors.New("foo")})...),
exchanged{err: errors.New("foo")},
},
{
2,
func(calls int) Recurser {
return func(*dns.Msg) string {
if calls++; calls <= 1 {
return "next"
}
return ""
}
}(0),
seq(stubs(exchanged{rtt: 0}, exchanged{rtt: 1}, exchanged{rtt: 2})...),
exchanged{rtt: 1},
},
} {
var got exchanged
got.m, got.rtt, got.err = Recursion(tt.max, tt.rec)(tt.ex).Exchange(nil, "")
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("test #%d: got: %v, want: %v", i, got, tt.want)
func TestInstrumentation(t *testing.T) {
{ // with error
var total, success, failure logging.LogCounter
_, _, _ = Instrumentation(&total, &success, &failure)(
stub(exchanged{err: errors.New("timeout")})).Exchange(nil, "1.2.3.4")

want := []string{"1", "0", "1"}
for i, c := range []*logging.LogCounter{&total, &success, &failure} {
if got, want := c.String(), want[i]; got != want {
t.Errorf("test #%d: got %q, want %q", i, got, want)
}
}
}
}
{ // no error
var total, success, failure logging.LogCounter
_, _, _ = Instrumentation(&total, &success, &failure)(
stub(exchanged{})).Exchange(nil, "1.2.3.4")

func seq(exs ...Exchanger) Exchanger {
var i int
return Func(func(m *dns.Msg, a string) (*dns.Msg, time.Duration, error) {
ex := exs[i]
i++
return ex.Exchange(m, a)
})
want := []string{"1", "1", "0"}
for i, c := range []*logging.LogCounter{&total, &success, &failure} {
if got, want := c.String(), want[i]; got != want {
t.Errorf("test #%d: got %q, want %q", i, got, want)
}
}
}
}

func stubs(ed ...exchanged) []Exchanger {
Expand Down
49 changes: 49 additions & 0 deletions exchanger/forwarder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package exchanger

import (
"fmt"
"net"

"github.com/miekg/dns"
)

// A Forwarder is a DNS message forwarder that transparently proxies messages
// to DNS servers.
type Forwarder func(*dns.Msg, string) (*dns.Msg, error)

// Forward is an utility method that calls f itself.
func (f Forwarder) Forward(m *dns.Msg, proto string) (*dns.Msg, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't understand the need for this utility func at all since it adds no value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need, that's why it's an utility function. Code reads much better with it. If we define an interface, it would also implement it automatically.

return f(m, proto)
}

// NewForwarder returns a new Forwarder for the given addrs with the given
// Exchangers map which maps network protocols to Exchangers.
//
// Every message will be exchanged with each address until no error is returned.
// If no addresses or no matching protocol exchanger exist, a *ForwardError will
// be returned.
func NewForwarder(addrs []string, exs map[string]Exchanger) Forwarder {
return func(m *dns.Msg, proto string) (r *dns.Msg, err error) {
ex, ok := exs[proto]
if !ok || len(addrs) == 0 {
return nil, &ForwardError{Addrs: addrs, Proto: proto}
}
for _, a := range addrs {
if r, _, err = ex.Exchange(m, net.JoinHostPort(a, "53")); err == nil {
break
}
}
return
}
}

// A ForwardError is returned by Forwarders when they can't forward.
type ForwardError struct {
Addrs []string
Proto string
}

// Error implements the error interface.
func (e ForwardError) Error() string {
return fmt.Sprintf("can't forward to %v over %q", e.Addrs, e.Proto)
}
Loading