Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DNS forwarder #3024

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,50 @@ import (
)

type DNSForwarder struct {
ListenAddress string
TTL uint32
listenAddress string
ttl uint32
domains []string

dnsServer *dns.Server
mux *dns.ServeMux
}

func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
domains: domains,
}
}
func (f *DNSForwarder) Listen() error {
log.Infof("listen DNS forwarder on: %s", f.ListenAddress)
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
mux := dns.NewServeMux()
mux.HandleFunc(".", f.handleDNSQuery)

for _, d := range f.domains {
mux.HandleFunc(d, f.handleDNSQuery)
}

dnsServer := &dns.Server{
Addr: f.ListenAddress,
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
}
f.dnsServer = dnsServer
f.mux = mux
return dnsServer.ListenAndServe()
}

func (f *DNSForwarder) UpdateDomains(domains []string) {
for _, d := range f.domains {
f.mux.HandleRemove(d)
}

for _, d := range domains {
f.mux.HandleFunc(d, f.handleDNSQuery)
}
f.domains = domains
}

func (f *DNSForwarder) Close(ctx context.Context) error {
if f.dnsServer == nil {
return nil
Expand All @@ -37,7 +61,7 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
}

func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
log.Debugf("received DNS query for DNS forwarder: %v", query)
log.Tracef("received DNS query for DNS forwarder: %v", query)
if len(query.Question) == 0 {
return
}
Expand All @@ -49,8 +73,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {

ips, err := net.LookupIP(domain)
if err != nil {
log.Errorf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeServerFailure
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeRefused
_ = w.WriteMsg(resp)
return
}
Expand All @@ -66,7 +90,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.TTL,
Ttl: f.ttl,
},
}
respRecord = &rr
Expand All @@ -77,7 +101,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.TTL,
Ttl: f.ttl,
},
}
respRecord = &rr
Expand Down
76 changes: 47 additions & 29 deletions client/internal/dnsfwd/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,47 @@ package dnsfwd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"net"

"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"

nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)

const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
ListenPort = 5353
dnsTTL = 60 //seconds
)

type Manager struct {
Firewall firewall.Manager
firewall firewall.Manager

fwRules []firewall.Rule
dnsForwarder *DNSForwarder
}

dnsRules []firewall.Rule
service *DNSForwarder
func NewManager(fw firewall.Manager) *Manager {
return &Manager{
firewall: fw,
}
}

func (m *Manager) Start() error {
func (m *Manager) Start(domains []string) error {
log.Infof("starting DNS forwarder")
if m.service != nil {
if m.dnsForwarder != nil {
return nil
}

if err := m.allowDNSFirewall(); err != nil {
return err
}

m.service = &DNSForwarder{
// todo listen only NetBird interface
ListenAddress: fmt.Sprintf(":%d", ListenPort),
TTL: 300,
}

m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains)
go func() {
if err := m.service.Listen(); err != nil {
if err := m.dnsForwarder.Listen(); err != nil {
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
Expand All @@ -46,43 +52,55 @@ func (m *Manager) Start() error {
return nil
}

func (m *Manager) UpdateDomains(domains []string) {
if m.dnsForwarder == nil {
return
}

m.dnsForwarder.UpdateDomains(domains)
}

func (m *Manager) Stop(ctx context.Context) error {
if m.service == nil {
if m.dnsForwarder == nil {
return nil
}

err := m.service.Close(ctx)
m.service = nil
return err
var mErr *multierror.Error
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}

if err := m.dnsForwarder.Close(ctx); err != nil {
mErr = multierror.Append(mErr, err)
}

m.dnsForwarder = nil
return nberrors.FormatErrorOrNil(mErr)
}

func (h *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []int{ListenPort},
}
dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.dnsRules = dnsRules
h.fwRules = dnsRules

return nil
}

func (h *Manager) dropDNSFirewall() error {
if len(h.dnsRules) == 0 {
return nil
}

for _, rule := range h.dnsRules {
if err := h.Firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete DNS router rules, err: %v", err)
return err
var mErr *multierror.Error
for _, rule := range h.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}

h.dnsRules = nil
return nil
h.fwRules = nil
return nberrors.FormatErrorOrNil(mErr)
}
59 changes: 32 additions & 27 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import (
"sync/atomic"
"time"

"github.com/netbirdio/netbird/client/internal/dnsfwd"

"github.com/pion/ice/v3"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
Expand All @@ -31,6 +29,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
Expand Down Expand Up @@ -789,7 +788,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
}

func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {

// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig())
Expand All @@ -809,31 +807,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.acl.ApplyFiltering(networkMap)
}

isDNSRouter, routes := toRoutes(networkMap.GetRoutes())
routedDomains, routes := toRoutes(networkMap.GetRoutes())

if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}

if isDNSRouter {
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = &dnsfwd.Manager{
Firewall: e.firewall,
}

if err := e.dnsForwardMgr.Start(); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
}
}
} else {
if e.dnsForwardMgr != nil {
// todo: review context
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
}
e.updateDNSForwarder(routedDomains)

log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))

Expand Down Expand Up @@ -895,12 +875,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}

func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}

var isDNSRouter bool
var dnsRoutes []string
routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes {
var prefix netip.Prefix
Expand All @@ -911,7 +891,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
continue
}
}
isDNSRouter = true
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)

convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
Expand All @@ -926,7 +906,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
}
routes = append(routes, convertedRoute)
}
return isDNSRouter, routes
return dnsRoutes, routes
}

func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
Expand Down Expand Up @@ -1574,6 +1554,31 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
return nm, nil
}

func (e *Engine) updateDNSForwarder(domains []string) {
if len(domains) > 0 {
log.Infof("enable domain router service for domains: %v", domains)
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)

if err := e.dnsForwardMgr.Start(domains); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
} else {
log.Infof("update domain router service for domains: %v", domains)
e.dnsForwardMgr.UpdateDomains(domains)
}
} else {
if e.dnsForwardMgr != nil {
log.Infof("disable domain router service")
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
}
}

// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
Expand Down
Loading