diff --git a/acceptfunc_test.go b/acceptfunc_test.go index d40d4e4cd..868154d40 100644 --- a/acceptfunc_test.go +++ b/acceptfunc_test.go @@ -1,6 +1,8 @@ package dns import ( + "encoding/binary" + "net" "testing" ) @@ -33,3 +35,86 @@ func handleNotify(w ResponseWriter, req *Msg) { m.SetReply(req) w.WriteMsg(m) } + +func TestInvalidMsg(t *testing.T) { + HandleFunc("example.org.", func(ResponseWriter, *Msg) { + t.Fatal("the handler must not be called in any of these tests") + }) + s, addrstr, _, err := RunLocalTCPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + s.MsgAcceptFunc = func(dh Header) MsgAcceptAction { + switch dh.Id { + case 0x0001: + return MsgAccept + case 0x0002: + return MsgReject + case 0x0003: + return MsgIgnore + case 0x0004: + return MsgRejectNotImplemented + default: + t.Errorf("unexpected ID %x", dh.Id) + return -1 + } + } + + invalidErrors := make(chan error) + s.MsgInvalidFunc = func(m []byte, err error) { + invalidErrors <- err + } + + c, err := net.Dial("tcp", addrstr) + if err != nil { + t.Fatalf("cannot connect to test server: %v", err) + } + + write := func(m []byte) { + var length [2]byte + binary.BigEndian.PutUint16(length[:], uint16(len(m))) + _, err := c.Write(length[:]) + if err != nil { + t.Fatalf("length write failed: %v", err) + } + _, err = c.Write(m) + if err != nil { + t.Fatalf("content write failed: %v", err) + } + } + + /* Message is too short, so there is no header to accept or reject. */ + + tooShortMessage := make([]byte, 11) + tooShortMessage[1] = 0x3 // ID = 3, would be ignored if it were parsable. + + write(tooShortMessage) + // Expect an error to be reported. + <-invalidErrors + + /* Message is accepted but is actually invalid. */ + + badMessage := make([]byte, 13) + badMessage[1] = 0x1 // ID = 1, Accept. + badMessage[5] = 1 // QDCOUNT = 1 + badMessage[12] = 99 // Bad question section. Invalid! + + write(badMessage) + // Expect an error to be reported. + <-invalidErrors + + /* Message is rejected before it can be determined to be invalid. */ + + close(invalidErrors) // A call to InvalidMsgFunc would panic due to the closed chan. + + badMessage[1] = 0x2 // ID = 2, Reject + write(badMessage) + + badMessage[1] = 0x3 // ID = 3, Ignore + write(badMessage) + + badMessage[1] = 0x4 // ID = 4, RejectNotImplemented + write(badMessage) +} diff --git a/server.go b/server.go index 0207d6da2..2f7655645 100644 --- a/server.go +++ b/server.go @@ -188,6 +188,14 @@ type DecorateReader func(Reader) Reader // Implementations should never return a nil Writer. type DecorateWriter func(Writer) Writer +// InvalidMsgFunc is a listener hook for observing incoming messages that were discarded +// because they could not be parsed. +// Every message that is read by a Reader will eventually be provided to the Handler, +// rejected (or ignored) by the MsgAcceptFunc, or passed to this function. +type InvalidMsgFunc func(m []byte, err error) + +func DefaultMsgInvalidFunc(m []byte, err error) {} + // A Server defines parameters for running an DNS server. type Server struct { // Address to listen on, ":dns" if empty. @@ -233,6 +241,8 @@ type Server struct { // AcceptMsgFunc will check the incoming message and will reject it early in the process. // By default DefaultMsgAcceptFunc will be used. MsgAcceptFunc MsgAcceptFunc + // MsgInvalidFunc is optional, will be called if a message is received but cannot be parsed. + MsgInvalidFunc InvalidMsgFunc // Shutdown handling lock sync.RWMutex @@ -277,6 +287,9 @@ func (srv *Server) init() { if srv.MsgAcceptFunc == nil { srv.MsgAcceptFunc = DefaultMsgAcceptFunc } + if srv.MsgInvalidFunc == nil { + srv.MsgInvalidFunc = DefaultMsgInvalidFunc + } if srv.Handler == nil { srv.Handler = DefaultServeMux } @@ -531,6 +544,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error { if cap(m) == srv.UDPSize { srv.udpPool.Put(m[:srv.UDPSize]) } + srv.MsgInvalidFunc(m, ErrShortRead) continue } wg.Add(1) @@ -611,6 +625,7 @@ func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn func (srv *Server) serveDNS(m []byte, w *response) { dh, off, err := unpackMsgHdr(m, 0) if err != nil { + srv.MsgInvalidFunc(m, err) // Let client hang, they are sending crap; any reply can be used to amplify. return } @@ -620,10 +635,12 @@ func (srv *Server) serveDNS(m []byte, w *response) { switch action := srv.MsgAcceptFunc(dh); action { case MsgAccept: - if req.unpack(dh, m, off) == nil { + err := req.unpack(dh, m, off) + if err == nil { break } + srv.MsgInvalidFunc(m, err) fallthrough case MsgReject, MsgRejectNotImplemented: opcode := req.Opcode