From 1f26fee0151b2044d37b49ec657914b583e3b54d Mon Sep 17 00:00:00 2001 From: Uladzimir Trehubenka Date: Tue, 14 Nov 2017 18:42:21 +0300 Subject: [PATCH] Use workers instead spawning goroutines for each incoming DNS request --- server.go | 93 ++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 14 deletions(-) diff --git a/server.go b/server.go index ee7e256fbc..78a656e28e 100644 --- a/server.go +++ b/server.go @@ -12,8 +12,14 @@ import ( "time" ) -// Maximum number of TCP queries before we close the socket. -const maxTCPQueries = 128 +const ( + // Maximum number of TCP queries before we close the socket. + maxTCPQueries = 128 + // Maximum number of incoming DNS messages in queue. + maxQueueSize = 1000000 + // Maximum number of workers. + maxWorkers = 100 +) // Handler is implemented by any value that implements ServeDNS. type Handler interface { @@ -97,6 +103,10 @@ func failedHandler() Handler { return HandlerFunc(HandleFailed) } // ListenAndServe Starts a server on address and network specified Invoke handler // for incoming queries. func ListenAndServe(addr string, network string, handler Handler) error { + if handler == nil { + handler = DefaultServeMux + } + server := &Server{Addr: addr, Net: network, Handler: handler} return server.ListenAndServe() } @@ -104,6 +114,10 @@ func ListenAndServe(addr string, network string, handler Handler) error { // ListenAndServeTLS acts like http.ListenAndServeTLS, more information in // http://golang.org/pkg/net/http/#ListenAndServeTLS func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { + if handler == nil { + handler = DefaultServeMux + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return err @@ -128,6 +142,10 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { // If both l and p are not nil only p will be used. // Invoke handler for incoming queries. func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error { + if handler == nil { + handler = DefaultServeMux + } + server := &Server{Listener: l, PacketConn: p, Handler: handler} return server.ActivateAndServe() } @@ -262,6 +280,14 @@ type DecorateReader func(Reader) Reader // Implementations should never return a nil Writer. type DecorateWriter func(Writer) Writer +// A message defines struct that is passed to worker +type message struct { + m []byte + u *net.UDPConn + s *SessionUDP + t net.Conn +} + // A Server defines parameters for running an DNS server. type Server struct { // Address to listen on, ":dns" if empty. @@ -300,6 +326,28 @@ type Server struct { // Shutdown handling lock sync.RWMutex started bool + + // Messages queue + queue chan *message + // Quit chan + quit chan struct{} +} + +func (srv *Server) startWorkers() { + srv.queue = make(chan *message, maxQueueSize) + srv.quit = make(chan struct{}) + for i := 0; i < maxWorkers; i++ { + go func() { + for { + select { + case msg := <-srv.queue: + srv.serve(msg) + case <-srv.quit: + return + } + } + }() + } } // ListenAndServe starts a nameserver on the configured address in *Server. @@ -309,6 +357,9 @@ func (srv *Server) ListenAndServe() error { if srv.started { return &Error{err: "server already started"} } + + srv.startWorkers() + addr := srv.Addr if addr == "" { addr = ":domain" @@ -380,6 +431,9 @@ func (srv *Server) ActivateAndServe() error { if srv.started { return &Error{err: "server already started"} } + + srv.startWorkers() + pConn := srv.PacketConn l := srv.Listener if pConn != nil { @@ -426,6 +480,7 @@ func (srv *Server) Shutdown() error { if srv.Listener != nil { srv.Listener.Close() } + close(srv.quit) return nil } @@ -476,7 +531,7 @@ func (srv *Server) serveTCP(l net.Listener) error { if err != nil { continue } - go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) + srv.queue <- &message{m, nil, nil, rw} } } @@ -511,13 +566,20 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { if err != nil { continue } - go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) + srv.queue <- &message{m, l, s, nil} } } // Serve a new connection. -func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} +func (srv *Server) serve(in *message) { + var a net.Addr + if in.t != nil { + a = in.t.RemoteAddr() + } else if in.s != nil { + a = in.s.RemoteAddr() + } + + w := &response{tsigSecret: srv.TsigSecret, udp: in.u, tcp: in.t, remoteAddr: a, udpSession: in.s} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { @@ -526,13 +588,16 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses q := 0 // counter for the amount of TCP queries we get - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) + var reader Reader + if w.tcp != nil { + reader = Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } } Redo: req := new(Msg) - err := req.Unpack(m) + err := req.Unpack(in.m) if err != nil { // Send a FormatError back x := new(Msg) x.SetRcodeFormatError(req) @@ -550,12 +615,12 @@ Redo: if _, ok := w.tsigSecret[secret]; !ok { w.tsigStatus = ErrKeyAlg } - w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) + w.tsigStatus = TsigVerify(in.m, w.tsigSecret[secret], "", false) w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC } } - h.ServeDNS(w, req) // Writes back to the client + srv.Handler.ServeDNS(w, req) // Writes back to the client Exit: if w.tcp == nil { @@ -570,7 +635,7 @@ Exit: if w.hijacked { return // client calls Close() } - if u != nil { // UDP, "close" and return + if w.udp != nil { // UDP, "close" and return w.Close() return } @@ -578,7 +643,7 @@ Exit: if srv.IdleTimeout != nil { idleTimeout = srv.IdleTimeout() } - m, err = reader.ReadTCP(w.tcp, idleTimeout) + in.m, err = reader.ReadTCP(w.tcp, idleTimeout) if err == nil { q++ goto Redo