From c2b1d5489bde91d854ae5616500ad7b1c5748bce Mon Sep 17 00:00:00 2001 From: Uladzimir Trehubenka Date: Tue, 14 Nov 2017 18:42:21 +0300 Subject: [PATCH] Use workers instead spawning goroutines for each incoming DNS request --- benchmark_test.go | 90 +++++++++++++++++ server.go | 247 +++++++++++++++++++++++++--------------------- 2 files changed, 225 insertions(+), 112 deletions(-) create mode 100644 benchmark_test.go diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000000..63afbeede5 --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,90 @@ +package dns + +import ( + "fmt" + "net" + "sync" + "testing" + "time" +) + +func BenchmarkServer(b *testing.B) { + HandleFunc(".", erraticHandler) + defer HandleRemove(".") + for _, i := range []int{0, 50, 100, 200, 1000} { + benchmark(i, b) + } +} + +func benchmark(workers int, b *testing.B) { + s, addr, err := runLocalUDPServer(workers) + if err != nil { + b.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("domain.local.", TypeA) + + test := fmt.Sprintf("%d_workers", workers) + b.Run(test, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := Exchange(m, addr) + if err != nil { + b.Fatalf("Exchange() failed: %v", err) + } + } + }) +} + +var ( + rrA, _ = NewRR(". IN 0 A 192.0.2.53") + rrAAAA, _ = NewRR(". IN 0 AAAA 2001:DB8::53") +) + +func erraticHandler(w ResponseWriter, r *Msg) { + r.Response = true + + switch r.Question[0].Qtype { + case TypeA: + rr := *(rrA.(*A)) + rr.Header().Name = r.Question[0].Name + r.Answer = []RR{&rr} + r.Rcode = RcodeSuccess + case TypeAAAA: + rr := *(rrAAAA.(*AAAA)) + rr.Header().Name = r.Question[0].Name + r.Answer = []RR{&rr} + r.Rcode = RcodeSuccess + default: + r.Rcode = RcodeServerFailure + } + + w.WriteMsg(r) +} + +func runLocalUDPServer(workers int) (*Server, string, error) { + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + return nil, "", err + } + + var wg sync.WaitGroup + wg.Add(1) + server := &Server{ + PacketConn: pc, + ReadTimeout: time.Hour, + WriteTimeout: time.Hour, + NotifyStartedFunc: wg.Done, + Handler: DefaultServeMux, + Workers: workers, + } + + go func() { + server.ActivateAndServe() + pc.Close() + }() + + wg.Wait() + return server, pc.LocalAddr().String(), nil +} diff --git a/server.go b/server.go index ee7e256fbc..aa9928139a 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" ) @@ -98,6 +99,7 @@ func failedHandler() Handler { return HandlerFunc(HandleFailed) } // for incoming queries. func ListenAndServe(addr string, network string, handler Handler) error { server := &Server{Addr: addr, Net: network, Handler: handler} + return server.ListenAndServe() } @@ -129,6 +131,7 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { // Invoke handler for incoming queries. func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error { server := &Server{Listener: l, PacketConn: p, Handler: handler} + return server.ActivateAndServe() } @@ -262,6 +265,14 @@ type DecorateReader func(Reader) Reader // Implementations should never return a nil Writer. type DecorateWriter func(Writer) Writer +// A message defines struct that is passed to worker +type message struct { + m []byte + u *net.UDPConn + s *SessionUDP + t net.Conn +} + // A Server defines parameters for running an DNS server. type Server struct { // Address to listen on, ":dns" if empty. @@ -296,19 +307,51 @@ type Server struct { DecorateReader DecorateReader // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. DecorateWriter DecorateWriter + // Graceful shutdown handling + running int32 + // Messages queue + queue chan *message + // Number of workers, if set to zero - use spawn goroutines instead + Workers int +} + +func (srv *Server) start() { + atomic.StoreInt32(&(srv.running), int32(1)) +} - // Shutdown handling - lock sync.RWMutex - started bool +func (srv *Server) stop() { + atomic.StoreInt32(&(srv.running), int32(0)) +} + +func (srv *Server) isRunning() bool { + return atomic.LoadInt32(&(srv.running)) == 1 +} + +func (srv *Server) startWorkers() { + srv.queue = make(chan *message) + for i := 0; i < srv.Workers; i++ { + go func() { + for msg := range srv.queue { + srv.serve(msg) + } + }() + } } // ListenAndServe starts a nameserver on the configured address in *Server. func (srv *Server) ListenAndServe() error { - srv.lock.Lock() - defer srv.lock.Unlock() - if srv.started { + if srv.isRunning() { return &Error{err: "server already started"} } + + if srv.Handler == nil { + srv.Handler = DefaultServeMux + } + + if srv.Workers > 0 { + srv.startWorkers() + } + addr := srv.Addr if addr == "" { addr = ":domain" @@ -327,10 +370,7 @@ func (srv *Server) ListenAndServe() error { return err } srv.Listener = l - srv.started = true - srv.lock.Unlock() err = srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top return err case "tcp-tls", "tcp4-tls", "tcp6-tls": network := "tcp" @@ -345,10 +385,7 @@ func (srv *Server) ListenAndServe() error { return err } srv.Listener = l - srv.started = true - srv.lock.Unlock() err = srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top return err case "udp", "udp4", "udp6": a, err := net.ResolveUDPAddr(srv.Net, addr) @@ -363,10 +400,7 @@ func (srv *Server) ListenAndServe() error { return e } srv.PacketConn = l - srv.started = true - srv.lock.Unlock() err = srv.serveUDP(l) - srv.lock.Lock() // to satisfy the defer at the top return err } return &Error{err: "bad network"} @@ -375,11 +409,18 @@ func (srv *Server) ListenAndServe() error { // ActivateAndServe starts a nameserver with the PacketConn or Listener // configured in *Server. Its main use is to start a server from systemd. func (srv *Server) ActivateAndServe() error { - srv.lock.Lock() - defer srv.lock.Unlock() - if srv.started { + if srv.isRunning() { return &Error{err: "server already started"} } + + if srv.Handler == nil { + srv.Handler = DefaultServeMux + } + + if srv.Workers > 0 { + srv.startWorkers() + } + pConn := srv.PacketConn l := srv.Listener if pConn != nil { @@ -392,18 +433,12 @@ func (srv *Server) ActivateAndServe() error { if e := setUDPSocketOptions(t); e != nil { return e } - srv.started = true - srv.lock.Unlock() e := srv.serveUDP(t) - srv.lock.Lock() // to satisfy the defer at the top return e } } if l != nil { - srv.started = true - srv.lock.Unlock() e := srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top return e } return &Error{err: "bad listeners"} @@ -412,14 +447,13 @@ func (srv *Server) ActivateAndServe() error { // Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and // ActivateAndServe will return. func (srv *Server) Shutdown() error { - srv.lock.Lock() - if !srv.started { - srv.lock.Unlock() + if !srv.isRunning() { return &Error{err: "server not started"} } - srv.started = false - srv.lock.Unlock() - + srv.stop() + if srv.Workers > 0 { + close(srv.queue) + } if srv.PacketConn != nil { srv.PacketConn.Close() } @@ -439,8 +473,8 @@ func (srv *Server) getReadTimeout() time.Duration { } // serveTCP starts a TCP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveTCP(l net.Listener) error { + srv.start() defer l.Close() if srv.NotifyStartedFunc != nil { @@ -452,13 +486,9 @@ func (srv *Server) serveTCP(l net.Listener) error { reader = srv.DecorateReader(reader) } - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } rtimeout := srv.getReadTimeout() // deadline is not used here - for { + for srv.isRunning() { rw, err := l.Accept() if err != nil { if neterr, ok := err.(net.Error); ok && neterr.Temporary() { @@ -467,22 +497,22 @@ func (srv *Server) serveTCP(l net.Listener) error { return err } m, err := reader.ReadTCP(rw, rtimeout) - srv.lock.RLock() - if !srv.started { - srv.lock.RUnlock() - return nil - } - srv.lock.RUnlock() if err != nil { continue } - go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) + msg := &message{m, nil, nil, rw} + if srv.Workers > 0 { + srv.queue <- msg + } else { + go srv.serve(msg) + } } + return nil } // serveUDP starts a UDP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveUDP(l *net.UDPConn) error { + srv.start() defer l.Close() if srv.NotifyStartedFunc != nil { @@ -494,94 +524,87 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { reader = srv.DecorateReader(reader) } - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } rtimeout := srv.getReadTimeout() // deadline is not used here - for { + for srv.isRunning() { m, s, err := reader.ReadUDP(l, rtimeout) - srv.lock.RLock() - if !srv.started { - srv.lock.RUnlock() - return nil - } - srv.lock.RUnlock() if err != nil { continue } - go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) + msg := &message{m, l, s, nil} + if srv.Workers > 0 { + srv.queue <- msg + } else { + go srv.serve(msg) + } } + return nil } // Serve a new connection. -func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} +func (srv *Server) serve(in *message) { + var a net.Addr + var reader Reader + var idleTimeout time.Duration + + if in.s != nil { + a = in.s.RemoteAddr() + } else if in.t != nil { + a = in.t.RemoteAddr() + reader = Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } + idleTimeout = tcpIdleTimeout + if srv.IdleTimeout != nil { + idleTimeout = srv.IdleTimeout() + } + } + + w := &response{tsigSecret: srv.TsigSecret, udp: in.u, tcp: in.t, remoteAddr: a, udpSession: in.s} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { w.writer = w } - q := 0 // counter for the amount of TCP queries we get - - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) - } -Redo: - req := new(Msg) - err := req.Unpack(m) - if err != nil { // Send a FormatError back - x := new(Msg) - x.SetRcodeFormatError(req) - w.WriteMsg(x) - goto Exit - } - if !srv.Unsafe && req.Response { - goto Exit - } - - w.tsigStatus = nil - if w.tsigSecret != nil { - if t := req.IsTsig(); t != nil { - secret := t.Hdr.Name - if _, ok := w.tsigSecret[secret]; !ok { - w.tsigStatus = ErrKeyAlg + for q := 0; q < maxTCPQueries; q++ { // TODO(miek): make this number configurable? + req := new(Msg) + err := req.Unpack(in.m) + if err != nil { // Send a FormatError back + x := new(Msg) + x.SetRcodeFormatError(req) + w.WriteMsg(x) + } else if !srv.Unsafe && req.Response { + // skip call handler + } else { + w.tsigStatus = nil + if w.tsigSecret != nil { + if t := req.IsTsig(); t != nil { + secret := t.Hdr.Name + if _, ok := w.tsigSecret[secret]; !ok { + w.tsigStatus = ErrKeyAlg + } + w.tsigStatus = TsigVerify(in.m, w.tsigSecret[secret], "", false) + w.tsigTimersOnly = false + w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC + } } - w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) - w.tsigTimersOnly = false - w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC + srv.Handler.ServeDNS(w, req) // Writes back to the client } - } - h.ServeDNS(w, req) // Writes back to the client -Exit: - if w.tcp == nil { - return - } - // TODO(miek): make this number configurable? - if q > maxTCPQueries { // close socket after this many queries - w.Close() - return - } + if in.u != nil { + break // no loop for udp + } - if w.hijacked { - return // client calls Close() - } - if u != nil { // UDP, "close" and return - w.Close() - return - } - idleTimeout := tcpIdleTimeout - if srv.IdleTimeout != nil { - idleTimeout = srv.IdleTimeout() - } - m, err = reader.ReadTCP(w.tcp, idleTimeout) - if err == nil { - q++ - goto Redo + if w.tcp == nil || w.hijacked { + return // client calls Close() + } + + in.m, err = reader.ReadTCP(w.tcp, idleTimeout) + if err != nil { + break // tcp read next packet failed + } } w.Close() return