Skip to content

Commit

Permalink
Support TsigProvider for Server and Transfer (#1331)
Browse files Browse the repository at this point in the history
Automatically submitted.
  • Loading branch information
tmthrgd authored Feb 5, 2022
1 parent 51afb90 commit 33e6400
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 49 deletions.
32 changes: 12 additions & 20 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ type Conn struct {
tsigRequestMAC string
}

func (co *Conn) tsigProvider() TsigProvider {
if co.TsigProvider != nil {
return co.TsigProvider
}
// tsigSecretProvider will return ErrSecret if co.TsigSecret is nil.
return tsigSecretProvider(co.TsigSecret)
}

// A Client defines parameters for a DNS client.
type Client struct {
Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
Expand Down Expand Up @@ -271,15 +279,8 @@ func (co *Conn) ReadMsg() (*Msg, error) {
return m, err
}
if t := m.IsTsig(); t != nil {
if co.TsigProvider != nil {
err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false)
} else {
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return m, ErrSecret
}
// Need to work on the original message p, as that was used to calculate the tsig.
err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
}
// Need to work on the original message p, as that was used to calculate the tsig.
err = tsigVerifyProvider(p, co.tsigProvider(), co.tsigRequestMAC, false)
}
return m, err
}
Expand Down Expand Up @@ -356,17 +357,8 @@ func (co *Conn) Read(p []byte) (n int, err error) {
func (co *Conn) WriteMsg(m *Msg) (err error) {
var out []byte
if t := m.IsTsig(); t != nil {
mac := ""
if co.TsigProvider != nil {
out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false)
} else {
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return ErrSecret
}
out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
}
// Set for the next read, although only used in zone transfers
co.tsigRequestMAC = mac
// Set tsigRequestMAC for the next read, although only used in zone transfers.
out, co.tsigRequestMAC, err = tsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
} else {
out, err = m.Pack()
}
Expand Down
42 changes: 25 additions & 17 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ type response struct {
tsigTimersOnly bool
tsigStatus error
tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets
udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
tsigProvider TsigProvider
udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
}

// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
Expand Down Expand Up @@ -211,6 +211,8 @@ type Server struct {
WriteTimeout time.Duration
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
IdleTimeout func() time.Duration
// An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
TsigProvider TsigProvider
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
TsigSecret map[string]string
// If NotifyStartedFunc is set it is called once the server has started listening.
Expand Down Expand Up @@ -238,6 +240,16 @@ type Server struct {
udpPool sync.Pool
}

func (srv *Server) tsigProvider() TsigProvider {
if srv.TsigProvider != nil {
return srv.TsigProvider
}
if srv.TsigSecret != nil {
return tsigSecretProvider(srv.TsigSecret)
}
return nil
}

func (srv *Server) isStarted() bool {
srv.lock.RLock()
started := srv.started
Expand Down Expand Up @@ -526,7 +538,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error {

// Serve a new TCP connection.
func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
w := &response{tsigSecret: srv.TsigSecret, tcp: rw}
w := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
Expand Down Expand Up @@ -581,7 +593,7 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {

// Serve a new UDP request.
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
Expand Down Expand Up @@ -632,15 +644,11 @@ func (srv *Server) serveDNS(m []byte, w *response) {
}

w.tsigStatus = nil
if w.tsigSecret != nil {
if w.tsigProvider != nil {
if t := req.IsTsig(); t != nil {
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
w.tsigStatus = TsigVerify(m, secret, "", false)
} else {
w.tsigStatus = ErrSecret
}
w.tsigStatus = tsigVerifyProvider(m, w.tsigProvider, "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
w.tsigRequestMAC = t.MAC
}
}

Expand Down Expand Up @@ -718,9 +726,9 @@ func (w *response) WriteMsg(m *Msg) (err error) {
}

var data []byte
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil {
data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
if err != nil {
return err
}
Expand Down
18 changes: 18 additions & 0 deletions tsig.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,24 @@ func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error {
return nil
}

type tsigSecretProvider map[string]string

func (ts tsigSecretProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
key, ok := ts[t.Hdr.Name]
if !ok {
return nil, ErrSecret
}
return tsigHMACProvider(key).Generate(msg, t)
}

func (ts tsigSecretProvider) Verify(msg []byte, t *TSIG) error {
key, ok := ts[t.Hdr.Name]
if !ok {
return ErrSecret
}
return tsigHMACProvider(key).Verify(msg, t)
}

// TSIG is the RR the holds the transaction signature of a message.
// See RFC 2845 and RFC 4635.
type TSIG struct {
Expand Down
27 changes: 16 additions & 11 deletions xfr.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,22 @@ type Transfer struct {
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
tsigTimersOnly bool
}

// Think we need to away to stop the transfer
func (t *Transfer) tsigProvider() TsigProvider {
if t.TsigProvider != nil {
return t.TsigProvider
}
if t.TsigSecret != nil {
return tsigSecretProvider(t.TsigSecret)
}
return nil
}

// TODO: Think we need to away to stop the transfer

// In performs an incoming transfer with the server in a.
// If you would like to set the source IP, or some other attribute
Expand Down Expand Up @@ -224,12 +235,9 @@ func (t *Transfer) ReadMsg() (*Msg, error) {
if err := m.Unpack(p); err != nil {
return nil, err
}
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
return m, ErrSecret
}
if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
// Need to work on the original message p, as that was used to calculate the tsig.
err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
err = tsigVerifyProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly)
t.tsigRequestMAC = ts.MAC
}
return m, err
Expand All @@ -238,11 +246,8 @@ func (t *Transfer) ReadMsg() (*Msg, error) {
// WriteMsg writes a message through the transfer connection t.
func (t *Transfer) WriteMsg(m *Msg) (err error) {
var out []byte
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
return ErrSecret
}
out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
out, t.tsigRequestMAC, err = tsigGenerateProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly)
} else {
out, err = m.Pack()
}
Expand Down
56 changes: 55 additions & 1 deletion xfr_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package dns

import "testing"
import (
"testing"
"time"
)

var (
tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
Expand Down Expand Up @@ -127,3 +130,54 @@ func axfrTestingSuite(t *testing.T, addrstr string) {
}
}
}

func axfrTestingSuiteWithCustomTsig(t *testing.T, addrstr string, provider TsigProvider) {
tr := new(Transfer)
m := new(Msg)
var err error
tr.Conn, err = Dial("tcp", addrstr)
if err != nil {
t.Fatal("failed to dial", err)
}
tr.TsigProvider = provider
m.SetAxfr("miek.nl.")
m.SetTsig("axfr.", HmacSHA256, 300, time.Now().Unix())

c, err := tr.In(m, addrstr)
if err != nil {
t.Fatal("failed to zone transfer in", err)
}

var records []RR
for msg := range c {
if msg.Error != nil {
t.Fatal(msg.Error)
}
records = append(records, msg.RR...)
}

if len(records) != len(xfrTestData) {
t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData)
}

for i, rr := range records {
if !IsDuplicate(rr, xfrTestData[i]) {
t.Errorf("bad axfr: expected %v, got %v", records, xfrTestData)
}
}
}

func TestCustomTsigProvider(t *testing.T) {
HandleFunc("miek.nl.", SingleEnvelopeXfrServer)
defer HandleRemove("miek.nl.")

s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) {
srv.TsigProvider = tsigSecretProvider(tsigSecret)
})
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
defer s.Shutdown()

axfrTestingSuiteWithCustomTsig(t, addrstr, tsigSecretProvider(tsigSecret))
}

0 comments on commit 33e6400

Please sign in to comment.