From 5100be4d538b6a35869f77b2a86bf89f5dcd1590 Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Sat, 8 Sep 2018 20:37:03 +0930 Subject: [PATCH] Add SO_REUSEPORT implementation Fixes #654 --- server.go | 27 +++++++++++++------------- server_go111.go | 43 ++++++++++++++++++++++++++++++++++++++++++ server_go_not111.go | 23 +++++++++++++++++++++++ server_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 126 insertions(+), 13 deletions(-) create mode 100644 server_go111.go create mode 100644 server_go_not111.go diff --git a/server.go b/server.go index b1981b8e4..2ecf1820e 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "bytes" "crypto/tls" "encoding/binary" + "errors" "io" "net" "strings" @@ -313,6 +314,9 @@ type Server struct { DecorateWriter DecorateWriter // Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1). MaxTCPQueries int + // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address. + // It is only supported on go1.11+ and when using ListenAndServe. + Reuseport bool // UDP packet or TCP connection queue queue chan *response @@ -401,11 +405,7 @@ func (srv *Server) ListenAndServe() error { defer close(srv.queue) switch srv.Net { case "tcp", "tcp4", "tcp6": - a, err := net.ResolveTCPAddr(srv.Net, addr) - if err != nil { - return err - } - l, err := net.ListenTCP(srv.Net, a) + l, err := listenTCP(srv.Net, addr, srv.Reuseport) if err != nil { return err } @@ -414,31 +414,32 @@ func (srv *Server) ListenAndServe() error { unlock() return srv.serveTCP(l) case "tcp-tls", "tcp4-tls", "tcp6-tls": + if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) { + return errors.New("dns: neither Certificates nor GetCertificate set in Config") + } network := strings.TrimSuffix(srv.Net, "-tls") - l, err := tls.Listen(network, addr, srv.TLSConfig) + l, err := listenTCP(network, addr, srv.Reuseport) if err != nil { return err } + l = tls.NewListener(l, srv.TLSConfig) srv.Listener = l srv.started = true unlock() return srv.serveTCP(l) case "udp", "udp4", "udp6": - a, err := net.ResolveUDPAddr(srv.Net, addr) - if err != nil { - return err - } - l, err := net.ListenUDP(srv.Net, a) + l, err := listenUDP(srv.Net, addr, srv.Reuseport) if err != nil { return err } - if e := setUDPSocketOptions(l); e != nil { + u := l.(*net.UDPConn) + if e := setUDPSocketOptions(u); e != nil { return e } srv.PacketConn = l srv.started = true unlock() - return srv.serveUDP(l) + return srv.serveUDP(u) } return &Error{err: "bad network"} } diff --git a/server_go111.go b/server_go111.go new file mode 100644 index 000000000..f51cf92dc --- /dev/null +++ b/server_go111.go @@ -0,0 +1,43 @@ +// +build go1.11,!windows + +package dns + +import ( + "context" + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +const supportsReuseport = true + +func reuseportControl(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }) + if err != nil { + return err + } + + return opErr +} + +func listenTCP(network, addr string, reuseport bool) (net.Listener, error) { + var lc net.ListenConfig + if reuseport { + lc.Control = reuseportControl + } + + return lc.Listen(context.Background(), network, addr) +} + +func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) { + var lc net.ListenConfig + if reuseport { + lc.Control = reuseportControl + } + + return lc.ListenPacket(context.Background(), network, addr) +} diff --git a/server_go_not111.go b/server_go_not111.go new file mode 100644 index 000000000..a2834213e --- /dev/null +++ b/server_go_not111.go @@ -0,0 +1,23 @@ +// +build !go1.11 windows + +package dns + +import "net" + +const supportsReuseport = false + +func listenTCP(network, addr string, reuseport bool) (net.Listener, error) { + if reuseport { + // TODO(tmthrgd): return an error? + } + + return net.Listen(network, addr) +} + +func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) { + if reuseport { + // TODO(tmthrgd): return an error? + } + + return net.ListenPacket(network, addr) +} diff --git a/server_test.go b/server_test.go index 6fc9e88c1..cce0997fd 100644 --- a/server_test.go +++ b/server_test.go @@ -687,6 +687,52 @@ func TestServerStartStopRace(t *testing.T) { } } +func TestServerReuseport(t *testing.T) { + if !supportsReuseport { + t.Skip("reuseport is not supported") + } + + startServer := func(addr string) (*Server, chan error) { + wait := make(chan struct{}) + srv := &Server{ + Net: "udp", + Addr: addr, + NotifyStartedFunc: func() { close(wait) }, + Reuseport: true, + } + + fin := make(chan error, 1) + go func() { + fin <- srv.ListenAndServe() + }() + + select { + case <-wait: + case err := <-fin: + t.Fatalf("failed to start server: %v", err) + } + + return srv, fin + } + + srv1, fin1 := startServer(":0") // :0 is resolved to a random free port by the kernel + srv2, fin2 := startServer(srv1.PacketConn.LocalAddr().String()) + + if err := srv1.Shutdown(); err != nil { + t.Fatalf("failed to shutdown first server: %v", err) + } + if err := srv2.Shutdown(); err != nil { + t.Fatalf("failed to shutdown second server: %v", err) + } + + if err := <-fin1; err != nil { + t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err) + } + if err := <-fin2; err != nil { + t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err) + } +} + type ExampleFrameLengthWriter struct { Writer }