Skip to content

Commit

Permalink
Use workers instead spawning goroutines for each incoming DNS request
Browse files Browse the repository at this point in the history
  • Loading branch information
Uladzimir Trehubenka committed Nov 14, 2017
1 parent 9fc4eb2 commit 1f26fee
Showing 1 changed file with 79 additions and 14 deletions.
93 changes: 79 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ import (
"time"
)

// Maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128
const (
// Maximum number of TCP queries before we close the socket.
maxTCPQueries = 128
// Maximum number of incoming DNS messages in queue.
maxQueueSize = 1000000
// Maximum number of workers.
maxWorkers = 100
)

// Handler is implemented by any value that implements ServeDNS.
type Handler interface {
Expand Down Expand Up @@ -97,13 +103,21 @@ func failedHandler() Handler { return HandlerFunc(HandleFailed) }
// ListenAndServe Starts a server on address and network specified Invoke handler
// for incoming queries.
func ListenAndServe(addr string, network string, handler Handler) error {
if handler == nil {
handler = DefaultServeMux
}

server := &Server{Addr: addr, Net: network, Handler: handler}
return server.ListenAndServe()
}

// ListenAndServeTLS acts like http.ListenAndServeTLS, more information in
// http://golang.org/pkg/net/http/#ListenAndServeTLS
func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
if handler == nil {
handler = DefaultServeMux
}

cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
Expand All @@ -128,6 +142,10 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
// If both l and p are not nil only p will be used.
// Invoke handler for incoming queries.
func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
if handler == nil {
handler = DefaultServeMux
}

server := &Server{Listener: l, PacketConn: p, Handler: handler}
return server.ActivateAndServe()
}
Expand Down Expand Up @@ -262,6 +280,14 @@ type DecorateReader func(Reader) Reader
// Implementations should never return a nil Writer.
type DecorateWriter func(Writer) Writer

// A message defines struct that is passed to worker
type message struct {
m []byte
u *net.UDPConn
s *SessionUDP
t net.Conn
}

// A Server defines parameters for running an DNS server.
type Server struct {
// Address to listen on, ":dns" if empty.
Expand Down Expand Up @@ -300,6 +326,28 @@ type Server struct {
// Shutdown handling
lock sync.RWMutex
started bool

// Messages queue
queue chan *message
// Quit chan
quit chan struct{}
}

func (srv *Server) startWorkers() {
srv.queue = make(chan *message, maxQueueSize)
srv.quit = make(chan struct{})
for i := 0; i < maxWorkers; i++ {
go func() {
for {
select {
case msg := <-srv.queue:
srv.serve(msg)
case <-srv.quit:
return
}
}
}()
}
}

// ListenAndServe starts a nameserver on the configured address in *Server.
Expand All @@ -309,6 +357,9 @@ func (srv *Server) ListenAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}

srv.startWorkers()

addr := srv.Addr
if addr == "" {
addr = ":domain"
Expand Down Expand Up @@ -380,6 +431,9 @@ func (srv *Server) ActivateAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}

srv.startWorkers()

pConn := srv.PacketConn
l := srv.Listener
if pConn != nil {
Expand Down Expand Up @@ -426,6 +480,7 @@ func (srv *Server) Shutdown() error {
if srv.Listener != nil {
srv.Listener.Close()
}
close(srv.quit)
return nil
}

Expand Down Expand Up @@ -476,7 +531,7 @@ func (srv *Server) serveTCP(l net.Listener) error {
if err != nil {
continue
}
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
srv.queue <- &message{m, nil, nil, rw}
}
}

Expand Down Expand Up @@ -511,13 +566,20 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if err != nil {
continue
}
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
srv.queue <- &message{m, l, s, nil}
}
}

// Serve a new connection.
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
func (srv *Server) serve(in *message) {
var a net.Addr
if in.t != nil {
a = in.t.RemoteAddr()
} else if in.s != nil {
a = in.s.RemoteAddr()
}

w := &response{tsigSecret: srv.TsigSecret, udp: in.u, tcp: in.t, remoteAddr: a, udpSession: in.s}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
Expand All @@ -526,13 +588,16 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses

q := 0 // counter for the amount of TCP queries we get

reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
var reader Reader
if w.tcp != nil {
reader = Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
}
Redo:
req := new(Msg)
err := req.Unpack(m)
err := req.Unpack(in.m)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
Expand All @@ -550,12 +615,12 @@ Redo:
if _, ok := w.tsigSecret[secret]; !ok {
w.tsigStatus = ErrKeyAlg
}
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigStatus = TsigVerify(in.m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
}
}
h.ServeDNS(w, req) // Writes back to the client
srv.Handler.ServeDNS(w, req) // Writes back to the client

Exit:
if w.tcp == nil {
Expand All @@ -570,15 +635,15 @@ Exit:
if w.hijacked {
return // client calls Close()
}
if u != nil { // UDP, "close" and return
if w.udp != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, err = reader.ReadTCP(w.tcp, idleTimeout)
in.m, err = reader.ReadTCP(w.tcp, idleTimeout)
if err == nil {
q++
goto Redo
Expand Down

0 comments on commit 1f26fee

Please sign in to comment.