From 2fa143306320b1469c9e3249c20dced5e1bbe5f6 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 14 Dec 2024 16:46:49 +0100 Subject: [PATCH] Use DNS route feature flag (#3048) Co-authored-by: Viktor Liu --- client/internal/engine.go | 9 ++-- client/internal/routemanager/client.go | 45 ++++++++++++++------ client/internal/routemanager/manager.go | 26 ++++++++--- client/internal/routemanager/manager_test.go | 4 +- client/internal/routemanager/mock.go | 2 +- 5 files changed, 61 insertions(+), 25 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index caf39a34f4f..b6fae5b2bfe 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -802,14 +802,17 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } + var dnsRouteFeatureFlag bool + if networkMap.PeerConfig != nil { + dnsRouteFeatureFlag = networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled + } routedDomains, routes := toRoutes(networkMap.GetRoutes()) - if err := e.routeManager.UpdateRoutes(serial, routes); err != nil { + if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - // todo: useRoutingPeerDnsResolutionEnabled from network map proto - e.updateDNSForwarder(true, routedDomains) + e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index b7fc5b15d00..6265736a157 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -21,7 +21,11 @@ import ( "github.com/netbirdio/netbird/route" ) -const useNewDNSRoute = true +const ( + handlerTypeDynamic = iota + handlerTypeDomain + handlerTypeStatic +) type routerPeerStatus struct { connected bool @@ -67,6 +71,7 @@ func newClientNetworkWatcher( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsServer nbdns.Server, peerStore *peerstore.Store, + useNewDNSRoute bool, ) *clientNetwork { ctx, cancel := context.WithCancel(ctx) @@ -88,6 +93,7 @@ func newClientNetworkWatcher( wgInterface, dnsServer, peerStore, + useNewDNSRoute, ), } return client @@ -400,18 +406,19 @@ func handlerFromRoute( wgInterface iface.IWGIface, dnsServer nbdns.Server, peerStore *peerstore.Store, + useNewDNSRoute bool, ) RouteHandler { - if rt.IsDynamic() { - if useNewDNSRoute { - return dnsinterceptor.New( - rt, - routeRefCounter, - allowedIPsRefCounter, - statusRecorder, - dnsServer, - peerStore, - ) - } + switch handlerType(rt, useNewDNSRoute) { + case handlerTypeDomain: + return dnsinterceptor.New( + rt, + routeRefCounter, + allowedIPsRefCounter, + statusRecorder, + dnsServer, + peerStore, + ) + case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(wgInterface) return dynamic.NewRoute( rt, @@ -422,6 +429,18 @@ func handlerFromRoute( wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), ) + default: + return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) + } +} + +func handlerType(rt *route.Route, useNewDNSRoute bool) int { + if !rt.IsDynamic() { + return handlerTypeStatic + } + + if useNewDNSRoute { + return handlerTypeDomain } - return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) + return handlerTypeStatic } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 30899bc1d06..389e97e2dcc 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,7 +36,7 @@ import ( // Manager is a route manager interface type Manager interface { Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector GetClientRoutes() route.HAMap @@ -66,9 +66,10 @@ type DefaultManager struct { dnsRouteInterval time.Duration stateManager *statemanager.Manager // clientRoutes is the most recent list of clientRoutes received from the Management Service - clientRoutes route.HAMap - dnsServer dns.Server - peerStore *peerstore.Store + clientRoutes route.HAMap + dnsServer dns.Server + peerStore *peerstore.Store + useNewDNSRoute bool } func NewManager( @@ -227,7 +228,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") @@ -237,6 +238,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro m.mux.Lock() defer m.mux.Unlock() + m.useNewDNSRoute = useNewDNSRoute newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) @@ -318,6 +320,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { m.allowedIPsRefCounter, m.dnsServer, m.peerStore, + m.useNewDNSRoute, ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() @@ -347,7 +350,18 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore) + clientNetworkWatcher = newClientNetworkWatcher( + m.ctx, + m.dnsRouteInterval, + m.wgInterface, + m.statusRecorder, + routes[0], + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsServer, + m.peerStore, + m.useNewDNSRoute, + ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() } diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 71b951593a6..4b7c984e5a0 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) require.NoError(t, err, "should update routes with init routes") } - _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 0219b17c89f..64fdffceb3e 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -32,7 +32,7 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error { if m.UpdateRoutesFunc != nil { return m.UpdateRoutesFunc(updateSerial, newRoutes) }