Skip to content

Commit

Permalink
Use workers instead spawning goroutines for each incoming DNS request
Browse files Browse the repository at this point in the history
  • Loading branch information
Uladzimir Trehubenka committed Dec 12, 2017
1 parent 9fc4eb2 commit 9e70f9e
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 20 deletions.
105 changes: 105 additions & 0 deletions benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package dns

import (
"fmt"
"net"
"sync"
"testing"
)

func BenchmarkServer(b *testing.B) {
HandleFunc(".", erraticHandler)
defer HandleRemove(".")
benchmark(0, b)
for _, workers := range []int{50, 100, 200, 1000} {
benchmark(workers, b)
}
}

func benchmark(workers int, b *testing.B) {
s, addr, err := runLocalUDPServer(workers)
if err != nil {
b.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()

m := new(Msg)
m.SetQuestion("domain.local.", TypeA)

conn, err := net.Dial("udp", addr)
if err != nil {
b.Fatalf("client Dial() failed: %v", err)
}
defer conn.Close()

test := fmt.Sprintf("%d_workers", workers)
b.Run(test, func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err = exchange(conn, m)
if err != nil {
b.Fatalf("exchange() failed: %v", err)
}
}
})
}

var (
rrA, _ = NewRR(". IN 0 A 192.0.2.53")
rrAAAA, _ = NewRR(". IN 0 AAAA 2001:DB8::53")
)

func erraticHandler(w ResponseWriter, r *Msg) {
r.Response = true

switch r.Question[0].Qtype {
case TypeA:
rr := *(rrA.(*A))
rr.Header().Name = r.Question[0].Name
r.Answer = []RR{&rr}
r.Rcode = RcodeSuccess
case TypeAAAA:
rr := *(rrAAAA.(*AAAA))
rr.Header().Name = r.Question[0].Name
r.Answer = []RR{&rr}
r.Rcode = RcodeSuccess
default:
r.Rcode = RcodeServerFailure
}

w.WriteMsg(r)
}

func runLocalUDPServer(workers int) (*Server, string, error) {
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
return nil, "", err
}

var wg sync.WaitGroup
wg.Add(1)
server := &Server{
PacketConn: pc,
NotifyStartedFunc: wg.Done,
Workers: workers,
}

go func() {
server.ActivateAndServe()
pc.Close()
}()

wg.Wait()
return server, pc.LocalAddr().String(), nil
}

func exchange(conn net.Conn, m *Msg) (r *Msg, err error) {
c := Conn{Conn: conn}
if err = c.WriteMsg(m); err != nil {
return nil, err
}
r, err = c.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
return r, err
}
94 changes: 74 additions & 20 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ type DecorateReader func(Reader) Reader
// Implementations should never return a nil Writer.
type DecorateWriter func(Writer) Writer

// A message defines struct that is passed to worker
type message struct {
m []byte
u *net.UDPConn
s *SessionUDP
t net.Conn
}

// A Server defines parameters for running an DNS server.
type Server struct {
// Address to listen on, ":dns" if empty.
Expand Down Expand Up @@ -296,19 +304,42 @@ type Server struct {
DecorateReader DecorateReader
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
DecorateWriter DecorateWriter

// Messages queue
queue chan *message
// Number of workers, if set to zero - use spawn goroutines instead
Workers int
// Shutdown handling
lock sync.RWMutex
started bool
}

func (srv *Server) startWorkers() {
srv.queue = make(chan *message)
for i := 0; i < srv.Workers; i++ {
go func() {
for msg := range srv.queue {
srv.serve(msg)
}
}()
}
}

// ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error {
srv.lock.Lock()
defer srv.lock.Unlock()
if srv.started {
return &Error{err: "server already started"}
}

if srv.Handler == nil {
srv.Handler = DefaultServeMux
}

if srv.Workers > 0 {
srv.startWorkers()
}

addr := srv.Addr
if addr == "" {
addr = ":domain"
Expand Down Expand Up @@ -380,6 +411,15 @@ func (srv *Server) ActivateAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}

if srv.Handler == nil {
srv.Handler = DefaultServeMux
}

if srv.Workers > 0 {
srv.startWorkers()
}

pConn := srv.PacketConn
l := srv.Listener
if pConn != nil {
Expand Down Expand Up @@ -420,6 +460,10 @@ func (srv *Server) Shutdown() error {
srv.started = false
srv.lock.Unlock()

if srv.Workers > 0 {
close(srv.queue)
}

if srv.PacketConn != nil {
srv.PacketConn.Close()
}
Expand All @@ -439,7 +483,6 @@ func (srv *Server) getReadTimeout() time.Duration {
}

// serveTCP starts a TCP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveTCP(l net.Listener) error {
defer l.Close()

Expand All @@ -452,10 +495,6 @@ func (srv *Server) serveTCP(l net.Listener) error {
reader = srv.DecorateReader(reader)
}

handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
Expand All @@ -476,12 +515,16 @@ func (srv *Server) serveTCP(l net.Listener) error {
if err != nil {
continue
}
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
msg := &message{m, nil, nil, rw}
if srv.Workers > 0 {
srv.queue <- msg
} else {
go srv.serve(msg)
}
}
}

// serveUDP starts a UDP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close()

Expand All @@ -494,10 +537,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
reader = srv.DecorateReader(reader)
}

handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
Expand All @@ -511,13 +550,28 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if err != nil {
continue
}
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
msg := &message{m, l, s, nil}
if srv.Workers > 0 {
srv.queue <- msg
} else {
go srv.serve(msg)
}
}
return nil
}

// Serve a new connection.
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
func (srv *Server) serve(in *message) {
var a net.Addr

if in.s != nil {
a = in.s.RemoteAddr()
}
if in.t != nil {
a = in.t.RemoteAddr()
}

w := &response{tsigSecret: srv.TsigSecret, udp: in.u, tcp: in.t, remoteAddr: a, udpSession: in.s}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
Expand All @@ -532,7 +586,7 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses
}
Redo:
req := new(Msg)
err := req.Unpack(m)
err := req.Unpack(in.m)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
Expand All @@ -550,12 +604,12 @@ Redo:
if _, ok := w.tsigSecret[secret]; !ok {
w.tsigStatus = ErrKeyAlg
}
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigStatus = TsigVerify(in.m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
}
}
h.ServeDNS(w, req) // Writes back to the client
srv.Handler.ServeDNS(w, req) // Writes back to the client

Exit:
if w.tcp == nil {
Expand All @@ -570,15 +624,15 @@ Exit:
if w.hijacked {
return // client calls Close()
}
if u != nil { // UDP, "close" and return
if in.u != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, err = reader.ReadTCP(w.tcp, idleTimeout)
in.m, err = reader.ReadTCP(w.tcp, idleTimeout)
if err == nil {
q++
goto Redo
Expand Down

0 comments on commit 9e70f9e

Please sign in to comment.