diff --git a/main.go b/main.go index d7c8e34..076bbac 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "crypto/tls" + "errors" "flag" "log" "net" @@ -16,11 +17,6 @@ import ( "github.com/miekg/dns" ) -type cachedResponse struct { - ResponseTime time.Time - Resp *dns.Msg -} - type debounceInfo struct { tm time.Time cnt int @@ -28,8 +24,9 @@ type debounceInfo struct { type handler struct { dns.Handler - Upstream string - Client *dns.Client + ClientTimeout time.Duration + clients map[string]*dns.Client + servers map[string]string ptrMap map[string]string lastResultSent map[string]debounceInfo lastResultSentLock sync.Mutex @@ -76,6 +73,44 @@ func (h *handler) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { return nil } +type errNoUpstream error + +func (h *handler) SetUpstream(upstream []string) { + h.clients = map[string]*dns.Client{} + h.servers = map[string]string{} + for _, up := range upstream { + domain := "" + proto := "udp" + if eqIdx := strings.Index(up, "="); eqIdx >= 0 { + domain, up = up[:eqIdx], up[eqIdx+1:] + } + if protoIdx := strings.Index(up, "://"); protoIdx >= 0 { + proto, up = up[:protoIdx], up[protoIdx+3:] + } + h.clients[domain] = &dns.Client{Net: proto, Timeout: h.ClientTimeout} + h.servers[domain] = up + } +} + +func (h *handler) upstreamExchange(r *dns.Msg) (*dns.Msg, time.Duration, error) { + qName := strings.TrimSuffix(r.Question[0].Name, ".") + for dotIdx, domain := 0, qName; dotIdx >= 0; dotIdx = strings.Index(domain, ".") { + domain = domain[dotIdx:] + if cl, have := h.clients[domain]; have { + log.Printf("[%02x] will use %s for %s", r.Id, h.servers[domain], domain) + return cl.Exchange(r, h.servers[domain]) + } + if dotIdx > 0 { + domain = domain[1:] + } + } + + if cl, have := h.clients[""]; have { + return cl.Exchange(r, h.servers[""]) + } + return nil, 0, errNoUpstream(errors.New("no upstream server set")) +} + func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if !h.checkNoDoS(w) { log.Printf("Dropping, DoS check failed to %s", w.RemoteAddr()) @@ -110,19 +145,17 @@ func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { log.Printf("PTR not in cache: %#v", q.Name) } if !r.RecursionDesired { - log.Printf("Client %s doesn't want recursion", w.RemoteAddr()) + log.Printf("[%02x] Client %s doesn't want recursion", r.Id, w.RemoteAddr()) _ = h.writeMsg(w, &dns.Msg{ MsgHdr: dns.MsgHdr{Id: r.MsgHdr.Id, Response: true, Rcode: dns.RcodeServerFailure}}) return } - if h.Client == nil { - h.Client = &dns.Client{Net: "udp"} - } - resp, rtt, err := h.Client.Exchange(r, h.Upstream) + resp, rtt, err := h.upstreamExchange(r) if err != nil { - log.Printf("Error getting response from %#v: %s", h.Upstream, err) - if _, ok := err.(*net.OpError); ok { - _ = h.writeMsg(w, + log.Printf("[%02x] Error getting response: %s", r.Id, err) + switch err.(type) { + case *net.OpError, errNoUpstream: + h.writeMsg(w, &dns.Msg{ MsgHdr: dns.MsgHdr{ Id: r.MsgHdr.Id, Response: true, @@ -134,15 +167,15 @@ func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } if !h.IsSilent { - log.Printf("Got response from %#v (rtt=%s):\n%s", h.Upstream, rtt, resp) + log.Printf("[%02x] Got response (rtt=%s)\n%s", r.Id, rtt, resp) } for _, answ := range resp.Answer { addrString := "" - switch answ.(type) { + switch a := answ.(type) { case *dns.AAAA: - addrString = answ.(*dns.AAAA).AAAA.String() + addrString = a.AAAA.String() case *dns.A: - addrString = answ.(*dns.A).A.String() + addrString = a.A.String() default: continue } @@ -207,39 +240,40 @@ const ( tlsListenHelp = "-tlslisten []: with port>1024" ) +type ArrayFlag []string + +func (f *ArrayFlag) String() string { + return strings.Join(*f, ", ") +} + +func (f *ArrayFlag) Set(value string) error { + *f = append(*f, value) + return nil +} + func main() { - upstreamServerFlag := flag.String("upstream", "1.1.1.1:53", "upstream DNS server (tcp-tls:// prefix for DoT)") + h := &handler{} + + var upstream ArrayFlag + flag.Var(&upstream, "upstream", "upstream DNS server (tcp-tls:// prefix for DoT), multi-val, can prefix with .domain=") listenAddrFlag := flag.String("listen", ":53", "listen address") tlsListenFlag := flag.String("tlslisten", ":853", "TCP-TLS listener address") certFlag := flag.String("cert", "", "TCP-TLS listener certificate (required for tls listener)") keyFlag := flag.String("key", "", "TCP-TLS certificate key (default same as -cert value)") debounceDelayFlag := flag.String("debounce", "200ms", "Required time duration between UDP replies to single IP to prevent DoS") - debounceCountFlag := flag.Int("count", 100, + flag.IntVar(&h.DebounceCount, "count", 100, "Count of replies allowed before debounce delay is applied") storeFlag := flag.String("store", "", "Store PTR data to specified file") chrootFlag := flag.String("chroot", DefaultChroot, "chroot to directory after start") - silentFlag := flag.Bool("silent", false, "Don't report normal data") - clientTimeoutFlag := flag.String("ctmout", "", "Client timeout for upstream queries") + flag.BoolVar(&h.IsSilent, "silent", false, "Don't report normal data") + flag.DurationVar(&h.ClientTimeout, "ctmout", 0, "Client timeout for upstream queries") flag.Parse() - h := &handler{Upstream: *upstreamServerFlag, DebounceCount: *debounceCountFlag, IsSilent: *silentFlag} - if *clientTimeoutFlag != "" { - cltmout, err := time.ParseDuration(*clientTimeoutFlag) - if err != nil { - log.Fatal("Cannot parse client timeout: ", err) - } - h.Client = &dns.Client{Net: "udp", Timeout: cltmout} - } - - if strings.Contains(*upstreamServerFlag, "://") { - idx := strings.Index(*upstreamServerFlag, "://") - if h.Client == nil { - h.Client = &dns.Client{} - } - h.Client.Net = (*upstreamServerFlag)[:idx] - h.Upstream = (*upstreamServerFlag)[idx+len("://"):] + if len(upstream) == 0 { + upstream.Set("tcp-tls://1.1.1.1:853") } + h.SetUpstream(upstream) var tlsServer *dns.Server var srv *dns.Server @@ -293,7 +327,7 @@ func main() { signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) if *listenAddrFlag != "" { - srv = &dns.Server{Addr: *listenAddrFlag, Net: "udp", Handler: h} + srv = &dns.Server{Addr: *listenAddrFlag, Net: "udp", Handler: h, ReusePort: true} go func() { if !h.IsSilent { log.Printf("ListenAndServe on %s", *listenAddrFlag)