Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use workers instead spawning goroutines for each incoming DNS request #565

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package dns

import (
"fmt"
"net"
"sync"
"testing"
)

func BenchmarkServer(b *testing.B) {
HandleFunc(".", erraticHandler)
defer HandleRemove(".")
benchmark(0, b)
for _, workers := range []int{50, 100, 200, 1000} {
benchmark(workers, b)
}
}

func benchmark(workers int, b *testing.B) {
s, addr, err := runLocalUDPServer(workers)
if err != nil {
b.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()

m := new(Msg)
m.SetQuestion("domain.local.", TypeA)

conn, err := net.Dial("udp", addr)
if err != nil {
b.Fatalf("client Dial() failed: %v", err)
}
defer conn.Close()

test := fmt.Sprintf("%d_workers", workers)
b.Run(test, func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err = exchange(conn, m)
if err != nil {
b.Fatalf("exchange() failed: %v", err)
}
}
})
}

var (
rrA, _ = NewRR(". IN 0 A 192.0.2.53")
rrAAAA, _ = NewRR(". IN 0 AAAA 2001:DB8::53")
)

func erraticHandler(w ResponseWriter, r *Msg) {
r.Response = true

switch r.Question[0].Qtype {
case TypeA:
rr := *(rrA.(*A))
rr.Header().Name = r.Question[0].Name
r.Answer = []RR{&rr}
r.Rcode = RcodeSuccess
case TypeAAAA:
rr := *(rrAAAA.(*AAAA))
rr.Header().Name = r.Question[0].Name
r.Answer = []RR{&rr}
r.Rcode = RcodeSuccess
default:
r.Rcode = RcodeServerFailure
}

w.WriteMsg(r)
}

func runLocalUDPServer(workers int) (*Server, string, error) {
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
return nil, "", err
}

var wg sync.WaitGroup
wg.Add(1)
server := &Server{
PacketConn: pc,
NotifyStartedFunc: wg.Done,
Workers: workers,
}

go func() {
server.ActivateAndServe()
pc.Close()
}()

wg.Wait()
return server, pc.LocalAddr().String(), nil
}

func exchange(conn net.Conn, m *Msg) (r *Msg, err error) {
c := Conn{Conn: conn}
if err = c.WriteMsg(m); err != nil {
return nil, err
}
r, err = c.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
return r, err
}
94 changes: 74 additions & 20 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ type DecorateReader func(Reader) Reader
// Implementations should never return a nil Writer.
type DecorateWriter func(Writer) Writer

// A message defines struct that is passed to worker
type message struct {
m []byte
u *net.UDPConn
s *SessionUDP
t net.Conn
}

// A Server defines parameters for running an DNS server.
type Server struct {
// Address to listen on, ":dns" if empty.
Expand Down Expand Up @@ -296,19 +304,42 @@ type Server struct {
DecorateReader DecorateReader
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
DecorateWriter DecorateWriter

// Messages queue
queue chan *message
// Number of workers, if set to zero - use spawn goroutines instead
Workers int
// Shutdown handling
lock sync.RWMutex
started bool
}

func (srv *Server) startWorkers() {
srv.queue = make(chan *message)
for i := 0; i < srv.Workers; i++ {
go func() {
for msg := range srv.queue {
srv.serve(msg)
}
}()
}
}

// ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error {
srv.lock.Lock()
defer srv.lock.Unlock()
if srv.started {
return &Error{err: "server already started"}
}

if srv.Handler == nil {
srv.Handler = DefaultServeMux
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this on needed?

scrolls down

Ah you're pulling it out from the serveX functions; sensible change; can you make that also a new PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serveUDP() and serveTCP() have:

	handler := srv.Handler
	if handler == nil {
		handler = DefaultServeMux
	}
        ...
        go srv.serve(s.RemoteAddr(), handler, ...)

Why do we need to set handler each time on serveUDP() or serveTCP() call and then pass handler into serve() if we can set srv.Handler only once and serve() can just use srv.Handler?


if srv.Workers > 0 {
srv.startWorkers()
}

addr := srv.Addr
if addr == "" {
addr = ":domain"
Expand Down Expand Up @@ -380,6 +411,15 @@ func (srv *Server) ActivateAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}

if srv.Handler == nil {
srv.Handler = DefaultServeMux
}

if srv.Workers > 0 {
srv.startWorkers()
}

pConn := srv.PacketConn
l := srv.Listener
if pConn != nil {
Expand Down Expand Up @@ -420,6 +460,10 @@ func (srv *Server) Shutdown() error {
srv.started = false
srv.lock.Unlock()

if srv.Workers > 0 {
close(srv.queue)
}

if srv.PacketConn != nil {
srv.PacketConn.Close()
}
Expand All @@ -439,7 +483,6 @@ func (srv *Server) getReadTimeout() time.Duration {
}

// serveTCP starts a TCP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveTCP(l net.Listener) error {
defer l.Close()

Expand All @@ -452,10 +495,6 @@ func (srv *Server) serveTCP(l net.Listener) error {
reader = srv.DecorateReader(reader)
}

handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
Expand All @@ -476,12 +515,16 @@ func (srv *Server) serveTCP(l net.Listener) error {
if err != nil {
continue
}
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
msg := &message{m, nil, nil, rw}
if srv.Workers > 0 {
srv.queue <- msg
} else {
go srv.serve(msg)
}
}
}

// serveUDP starts a UDP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close()

Expand All @@ -494,10 +537,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
reader = srv.DecorateReader(reader)
}

handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
Expand All @@ -511,13 +550,28 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if err != nil {
continue
}
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
msg := &message{m, l, s, nil}
if srv.Workers > 0 {
srv.queue <- msg
} else {
go srv.serve(msg)
}
}
return nil
}

// Serve a new connection.
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
func (srv *Server) serve(in *message) {
var a net.Addr

if in.s != nil {
a = in.s.RemoteAddr()
}
if in.t != nil {
a = in.t.RemoteAddr()
}

w := &response{tsigSecret: srv.TsigSecret, udp: in.u, tcp: in.t, remoteAddr: a, udpSession: in.s}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
Expand All @@ -532,7 +586,7 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses
}
Redo:
req := new(Msg)
err := req.Unpack(m)
err := req.Unpack(in.m)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
Expand All @@ -550,12 +604,12 @@ Redo:
if _, ok := w.tsigSecret[secret]; !ok {
w.tsigStatus = ErrKeyAlg
}
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigStatus = TsigVerify(in.m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
}
}
h.ServeDNS(w, req) // Writes back to the client
srv.Handler.ServeDNS(w, req) // Writes back to the client

Exit:
if w.tcp == nil {
Expand All @@ -570,15 +624,15 @@ Exit:
if w.hijacked {
return // client calls Close()
}
if u != nil { // UDP, "close" and return
if in.u != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, err = reader.ReadTCP(w.tcp, idleTimeout)
in.m, err = reader.ReadTCP(w.tcp, idleTimeout)
if err == nil {
q++
goto Redo
Expand Down