diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 000000000..6988f9b7a --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,105 @@ +package dns + +import ( + "fmt" + "net" + "sync" + "testing" +) + +func BenchmarkServer(b *testing.B) { + HandleFunc(".", erraticHandler) + defer HandleRemove(".") + benchmark(0, b) + for _, workers := range []int{50, 100, 200, 1000} { + benchmark(workers, b) + } +} + +func benchmark(workers int, b *testing.B) { + s, addr, err := runLocalUDPServer(workers) + if err != nil { + b.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("domain.local.", TypeA) + + conn, err := net.Dial("udp", addr) + if err != nil { + b.Fatalf("client Dial() failed: %v", err) + } + defer conn.Close() + + test := fmt.Sprintf("%d_workers", workers) + b.Run(test, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err = exchange(conn, m) + if err != nil { + b.Fatalf("exchange() failed: %v", err) + } + } + }) +} + +var ( + rrA, _ = NewRR(". IN 0 A 192.0.2.53") + rrAAAA, _ = NewRR(". IN 0 AAAA 2001:DB8::53") +) + +func erraticHandler(w ResponseWriter, r *Msg) { + r.Response = true + + switch r.Question[0].Qtype { + case TypeA: + rr := *(rrA.(*A)) + rr.Header().Name = r.Question[0].Name + r.Answer = []RR{&rr} + r.Rcode = RcodeSuccess + case TypeAAAA: + rr := *(rrAAAA.(*AAAA)) + rr.Header().Name = r.Question[0].Name + r.Answer = []RR{&rr} + r.Rcode = RcodeSuccess + default: + r.Rcode = RcodeServerFailure + } + + w.WriteMsg(r) +} + +func runLocalUDPServer(workers int) (*Server, string, error) { + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + return nil, "", err + } + + var wg sync.WaitGroup + wg.Add(1) + server := &Server{ + PacketConn: pc, + NotifyStartedFunc: wg.Done, + Workers: workers, + } + + go func() { + server.ActivateAndServe() + pc.Close() + }() + + wg.Wait() + return server, pc.LocalAddr().String(), nil +} + +func exchange(conn net.Conn, m *Msg) (r *Msg, err error) { + c := Conn{Conn: conn} + if err = c.WriteMsg(m); err != nil { + return nil, err + } + r, err = c.ReadMsg() + if err == nil && r.Id != m.Id { + err = ErrId + } + return r, err +} diff --git a/server.go b/server.go index ee7e256fb..6fa6696b8 100644 --- a/server.go +++ b/server.go @@ -262,6 +262,14 @@ type DecorateReader func(Reader) Reader // Implementations should never return a nil Writer. type DecorateWriter func(Writer) Writer +// A message defines struct that is passed to worker +type message struct { + m []byte + u *net.UDPConn + s *SessionUDP + t net.Conn +} + // A Server defines parameters for running an DNS server. type Server struct { // Address to listen on, ":dns" if empty. @@ -296,12 +304,26 @@ type Server struct { DecorateReader DecorateReader // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. DecorateWriter DecorateWriter - + // Messages queue + queue chan *message + // Number of workers, if set to zero - use spawn goroutines instead + Workers int // Shutdown handling lock sync.RWMutex started bool } +func (srv *Server) startWorkers() { + srv.queue = make(chan *message) + for i := 0; i < srv.Workers; i++ { + go func() { + for msg := range srv.queue { + srv.serve(msg) + } + }() + } +} + // ListenAndServe starts a nameserver on the configured address in *Server. func (srv *Server) ListenAndServe() error { srv.lock.Lock() @@ -309,6 +331,15 @@ func (srv *Server) ListenAndServe() error { if srv.started { return &Error{err: "server already started"} } + + if srv.Handler == nil { + srv.Handler = DefaultServeMux + } + + if srv.Workers > 0 { + srv.startWorkers() + } + addr := srv.Addr if addr == "" { addr = ":domain" @@ -380,6 +411,15 @@ func (srv *Server) ActivateAndServe() error { if srv.started { return &Error{err: "server already started"} } + + if srv.Handler == nil { + srv.Handler = DefaultServeMux + } + + if srv.Workers > 0 { + srv.startWorkers() + } + pConn := srv.PacketConn l := srv.Listener if pConn != nil { @@ -420,6 +460,10 @@ func (srv *Server) Shutdown() error { srv.started = false srv.lock.Unlock() + if srv.Workers > 0 { + close(srv.queue) + } + if srv.PacketConn != nil { srv.PacketConn.Close() } @@ -439,7 +483,6 @@ func (srv *Server) getReadTimeout() time.Duration { } // serveTCP starts a TCP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveTCP(l net.Listener) error { defer l.Close() @@ -452,10 +495,6 @@ func (srv *Server) serveTCP(l net.Listener) error { reader = srv.DecorateReader(reader) } - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } rtimeout := srv.getReadTimeout() // deadline is not used here for { @@ -476,12 +515,16 @@ func (srv *Server) serveTCP(l net.Listener) error { if err != nil { continue } - go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) + msg := &message{m, nil, nil, rw} + if srv.Workers > 0 { + srv.queue <- msg + } else { + go srv.serve(msg) + } } } // serveUDP starts a UDP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveUDP(l *net.UDPConn) error { defer l.Close() @@ -494,10 +537,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { reader = srv.DecorateReader(reader) } - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } rtimeout := srv.getReadTimeout() // deadline is not used here for { @@ -511,13 +550,28 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { if err != nil { continue } - go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) + msg := &message{m, l, s, nil} + if srv.Workers > 0 { + srv.queue <- msg + } else { + go srv.serve(msg) + } } + return nil } // Serve a new connection. -func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} +func (srv *Server) serve(in *message) { + var a net.Addr + + if in.s != nil { + a = in.s.RemoteAddr() + } + if in.t != nil { + a = in.t.RemoteAddr() + } + + w := &response{tsigSecret: srv.TsigSecret, udp: in.u, tcp: in.t, remoteAddr: a, udpSession: in.s} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { @@ -532,7 +586,7 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses } Redo: req := new(Msg) - err := req.Unpack(m) + err := req.Unpack(in.m) if err != nil { // Send a FormatError back x := new(Msg) x.SetRcodeFormatError(req) @@ -550,12 +604,12 @@ Redo: if _, ok := w.tsigSecret[secret]; !ok { w.tsigStatus = ErrKeyAlg } - w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) + w.tsigStatus = TsigVerify(in.m, w.tsigSecret[secret], "", false) w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC } } - h.ServeDNS(w, req) // Writes back to the client + srv.Handler.ServeDNS(w, req) // Writes back to the client Exit: if w.tcp == nil { @@ -570,7 +624,7 @@ Exit: if w.hijacked { return // client calls Close() } - if u != nil { // UDP, "close" and return + if in.u != nil { // UDP, "close" and return w.Close() return } @@ -578,7 +632,7 @@ Exit: if srv.IdleTimeout != nil { idleTimeout = srv.IdleTimeout() } - m, err = reader.ReadTCP(w.tcp, idleTimeout) + in.m, err = reader.ReadTCP(w.tcp, idleTimeout) if err == nil { q++ goto Redo