Skip to content

Commit

Permalink
add multiple upstream support
Browse files Browse the repository at this point in the history
  • Loading branch information
korc committed May 14, 2023
1 parent 1b97123 commit 4ce44d4
Showing 1 changed file with 74 additions and 40 deletions.
114 changes: 74 additions & 40 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"crypto/tls"
"errors"
"flag"
"log"
"net"
Expand All @@ -16,20 +17,16 @@ import (
"github.com/miekg/dns"
)

type cachedResponse struct {
ResponseTime time.Time
Resp *dns.Msg
}

type debounceInfo struct {
tm time.Time
cnt int
}

type handler struct {
dns.Handler
Upstream string
Client *dns.Client
ClientTimeout time.Duration
clients map[string]*dns.Client
servers map[string]string
ptrMap map[string]string
lastResultSent map[string]debounceInfo
lastResultSentLock sync.Mutex
Expand Down Expand Up @@ -76,6 +73,44 @@ func (h *handler) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
return nil
}

type errNoUpstream error

func (h *handler) SetUpstream(upstream []string) {
h.clients = map[string]*dns.Client{}
h.servers = map[string]string{}
for _, up := range upstream {
domain := ""
proto := "udp"
if eqIdx := strings.Index(up, "="); eqIdx >= 0 {
domain, up = up[:eqIdx], up[eqIdx+1:]
}
if protoIdx := strings.Index(up, "://"); protoIdx >= 0 {
proto, up = up[:protoIdx], up[protoIdx+3:]
}
h.clients[domain] = &dns.Client{Net: proto, Timeout: h.ClientTimeout}
h.servers[domain] = up
}
}

func (h *handler) upstreamExchange(r *dns.Msg) (*dns.Msg, time.Duration, error) {
qName := strings.TrimSuffix(r.Question[0].Name, ".")
for dotIdx, domain := 0, qName; dotIdx >= 0; dotIdx = strings.Index(domain, ".") {
domain = domain[dotIdx:]
if cl, have := h.clients[domain]; have {
log.Printf("[%02x] will use %s for %s", r.Id, h.servers[domain], domain)
return cl.Exchange(r, h.servers[domain])
}
if dotIdx > 0 {
domain = domain[1:]
}
}

if cl, have := h.clients[""]; have {
return cl.Exchange(r, h.servers[""])
}
return nil, 0, errNoUpstream(errors.New("no upstream server set"))
}

func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if !h.checkNoDoS(w) {
log.Printf("Dropping, DoS check failed to %s", w.RemoteAddr())
Expand Down Expand Up @@ -110,19 +145,17 @@ func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Printf("PTR not in cache: %#v", q.Name)
}
if !r.RecursionDesired {
log.Printf("Client %s doesn't want recursion", w.RemoteAddr())
log.Printf("[%02x] Client %s doesn't want recursion", r.Id, w.RemoteAddr())
_ = h.writeMsg(w, &dns.Msg{
MsgHdr: dns.MsgHdr{Id: r.MsgHdr.Id, Response: true, Rcode: dns.RcodeServerFailure}})
return
}
if h.Client == nil {
h.Client = &dns.Client{Net: "udp"}
}
resp, rtt, err := h.Client.Exchange(r, h.Upstream)
resp, rtt, err := h.upstreamExchange(r)
if err != nil {
log.Printf("Error getting response from %#v: %s", h.Upstream, err)
if _, ok := err.(*net.OpError); ok {
_ = h.writeMsg(w,
log.Printf("[%02x] Error getting response: %s", r.Id, err)
switch err.(type) {
case *net.OpError, errNoUpstream:
h.writeMsg(w,
&dns.Msg{
MsgHdr: dns.MsgHdr{
Id: r.MsgHdr.Id, Response: true,
Expand All @@ -134,15 +167,15 @@ func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if !h.IsSilent {
log.Printf("Got response from %#v (rtt=%s):\n%s", h.Upstream, rtt, resp)
log.Printf("[%02x] Got response (rtt=%s)\n%s", r.Id, rtt, resp)
}
for _, answ := range resp.Answer {
addrString := ""
switch answ.(type) {
switch a := answ.(type) {
case *dns.AAAA:
addrString = answ.(*dns.AAAA).AAAA.String()
addrString = a.AAAA.String()
case *dns.A:
addrString = answ.(*dns.A).A.String()
addrString = a.A.String()
default:
continue
}
Expand Down Expand Up @@ -207,39 +240,40 @@ const (
tlsListenHelp = "-tlslisten [<ip>]:<port> with port>1024"
)

type ArrayFlag []string

func (f *ArrayFlag) String() string {
return strings.Join(*f, ", ")
}

func (f *ArrayFlag) Set(value string) error {
*f = append(*f, value)
return nil
}

func main() {
upstreamServerFlag := flag.String("upstream", "1.1.1.1:53", "upstream DNS server (tcp-tls:// prefix for DoT)")
h := &handler{}

var upstream ArrayFlag
flag.Var(&upstream, "upstream", "upstream DNS server (tcp-tls:// prefix for DoT), multi-val, can prefix with .domain=")
listenAddrFlag := flag.String("listen", ":53", "listen address")
tlsListenFlag := flag.String("tlslisten", ":853", "TCP-TLS listener address")
certFlag := flag.String("cert", "", "TCP-TLS listener certificate (required for tls listener)")
keyFlag := flag.String("key", "", "TCP-TLS certificate key (default same as -cert value)")
debounceDelayFlag := flag.String("debounce", "200ms",
"Required time duration between UDP replies to single IP to prevent DoS")
debounceCountFlag := flag.Int("count", 100,
flag.IntVar(&h.DebounceCount, "count", 100,
"Count of replies allowed before debounce delay is applied")
storeFlag := flag.String("store", "", "Store PTR data to specified file")
chrootFlag := flag.String("chroot", DefaultChroot, "chroot to directory after start")
silentFlag := flag.Bool("silent", false, "Don't report normal data")
clientTimeoutFlag := flag.String("ctmout", "", "Client timeout for upstream queries")
flag.BoolVar(&h.IsSilent, "silent", false, "Don't report normal data")
flag.DurationVar(&h.ClientTimeout, "ctmout", 0, "Client timeout for upstream queries")
flag.Parse()

h := &handler{Upstream: *upstreamServerFlag, DebounceCount: *debounceCountFlag, IsSilent: *silentFlag}
if *clientTimeoutFlag != "" {
cltmout, err := time.ParseDuration(*clientTimeoutFlag)
if err != nil {
log.Fatal("Cannot parse client timeout: ", err)
}
h.Client = &dns.Client{Net: "udp", Timeout: cltmout}
}

if strings.Contains(*upstreamServerFlag, "://") {
idx := strings.Index(*upstreamServerFlag, "://")
if h.Client == nil {
h.Client = &dns.Client{}
}
h.Client.Net = (*upstreamServerFlag)[:idx]
h.Upstream = (*upstreamServerFlag)[idx+len("://"):]
if len(upstream) == 0 {
upstream.Set("tcp-tls://1.1.1.1:853")
}
h.SetUpstream(upstream)

var tlsServer *dns.Server
var srv *dns.Server
Expand Down Expand Up @@ -293,7 +327,7 @@ func main() {
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)

if *listenAddrFlag != "" {
srv = &dns.Server{Addr: *listenAddrFlag, Net: "udp", Handler: h}
srv = &dns.Server{Addr: *listenAddrFlag, Net: "udp", Handler: h, ReusePort: true}
go func() {
if !h.IsSilent {
log.Printf("ListenAndServe on %s", *listenAddrFlag)
Expand Down

0 comments on commit 4ce44d4

Please sign in to comment.