Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Persist route selection #2810

Merged
merged 8 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions client/firewall/iptables/rulestore_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
return err
}
s.ips = temp.IPs

if temp.IPs == nil {
pappz marked this conversation as resolved.
Show resolved Hide resolved
temp.IPs = make(map[string]struct{})
}

return nil
}

Expand Down Expand Up @@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
return err
}
s.ipsets = temp.IPSets

pappz marked this conversation as resolved.
Show resolved Hide resolved
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}

return nil
}
13 changes: 11 additions & 2 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,17 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer

e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
e.routeManager = routemanager.NewManager(
e.ctx,
e.config.WgPrivateKey.PublicKey().String(),
e.config.DNSRouteInterval,
e.wgInterface,
e.statusRecorder,
e.relayManager,
initialRoutes,
e.stateManager,
)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
Expand Down
5 changes: 4 additions & 1 deletion client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
nil)

wgIface := &iface.MockWGIface{
NameFunc: func() string { return "utun102" },
RemovePeerFunc: func(peerKey string) error {
return nil
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
_, _, err = engine.routeManager.Init()
require.NoError(t, err)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
Expand Down
38 changes: 34 additions & 4 deletions client/internal/routemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (

// Manager is a route manager interface
type Manager interface {
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
Expand All @@ -59,6 +59,7 @@ type DefaultManager struct {
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
stateManager *statemanager.Manager
}

func NewManager(
Expand All @@ -69,6 +70,7 @@ func NewManager(
statusRecorder *peer.Status,
relayMgr *relayClient.Manager,
initialRoutes []*route.Route,
stateManager *statemanager.Manager,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier()
Expand All @@ -80,12 +82,12 @@ func NewManager(
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
relayMgr: relayMgr,
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: notifier,
stateManager: stateManager,
}

dm.routeRefCounter = refcounter.New(
Expand Down Expand Up @@ -121,7 +123,7 @@ func NewManager(
}

// Init sets up the routing
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
Expand All @@ -137,14 +139,38 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook

ips := resolveURLsToIPs(initialAddresses)

beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
}

m.routeSelector = m.initSelector()

log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil
}

func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
var state *SelectorState
m.stateManager.RegisterState(state)

// restore selector state if it exists
if err := m.stateManager.LoadState(state); err != nil {
log.Warnf("failed to load state: %v", err)
return routeselector.NewRouteSelector()
}

if state := m.stateManager.GetState(state); state != nil {
if selector, ok := state.(*SelectorState); ok {
return (*routeselector.RouteSelector)(selector)
}

log.Warnf("failed to convert state with type %T to SelectorState", state)
}

return routeselector.NewRouteSelector()
}

func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
Expand Down Expand Up @@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}

if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
log.Errorf("failed to update state: %v", err)
}
}

// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
Expand Down
4 changes: 2 additions & 2 deletions client/internal/routemanager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) {

statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)

_, _, err = routeManager.Init(nil)
_, _, err = routeManager.Init()

require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil)
Expand Down
2 changes: 1 addition & 1 deletion client/internal/routemanager/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager)
}

func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil
}

Expand Down
13 changes: 13 additions & 0 deletions client/internal/routemanager/refcounter/refcounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
}

// LoadData loads the data from the existing counter
// The passed counter should not be used any longer after calling this function.
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.mu.Lock()
defer rm.mu.Unlock()
existingCounter.mu.Lock()
defer existingCounter.mu.Unlock()

rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
Expand Down Expand Up @@ -231,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {

// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.mu.Lock()
defer rm.mu.Unlock()

var temp struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
Expand All @@ -241,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.refCountMap = temp.RefCountMap
rm.idMap = temp.IDMap

if temp.RefCountMap == nil {
temp.RefCountMap = map[Key]Ref[O]{}
}
if temp.IDMap == nil {
temp.IDMap = map[string][]Key{}
}

return nil
}

Expand Down
19 changes: 19 additions & 0 deletions client/internal/routemanager/state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package routemanager

import (
"github.com/netbirdio/netbird/client/internal/routeselector"
)

type SelectorState routeselector.RouteSelector

func (s *SelectorState) Name() string {
return "routeselector_state"
}

func (s *SelectorState) MarshalJSON() ([]byte, error) {
return (*routeselector.RouteSelector)(s).MarshalJSON()
}

func (s *SelectorState) UnmarshalJSON(data []byte) error {
return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
}
67 changes: 67 additions & 0 deletions client/internal/routeselector/routeselector.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package routeselector

import (
"encoding/json"
"fmt"
"slices"
"sync"

"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
Expand All @@ -12,6 +14,7 @@ import (
)

type RouteSelector struct {
mu sync.RWMutex
selectedRoutes map[route.NetID]struct{}
selectAll bool
}
Expand All @@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector {

// SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()

if !appendRoute {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
Expand All @@ -46,13 +52,19 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al

// SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()

rs.selectAll = true
rs.selectedRoutes = map[route.NetID]struct{}{}
}

// DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()

if rs.selectAll {
rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{}
Expand All @@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.

// DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()

rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{}
}

// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()

if rs.selectAll {
return true
}
Expand All @@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {

// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
defer rs.mu.RUnlock()

if rs.selectAll {
return maps.Clone(routes)
}
Expand All @@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
return filtered
}

// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
defer rs.mu.RUnlock()

return json.Marshal(struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}{
SelectAll: rs.selectAll,
SelectedRoutes: rs.selectedRoutes,
})
}

// UnmarshalJSON implements the json.Unmarshaler interface
// If the JSON is empty or null, it will initialize like a NewRouteSelector.
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
rs.mu.Lock()
defer rs.mu.Unlock()

// Check for null or empty JSON
if len(data) == 0 || string(data) == "null" {
rs.selectedRoutes = map[route.NetID]struct{}{}
rs.selectAll = true
return nil
}

var temp struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}

if err := json.Unmarshal(data, &temp); err != nil {
return err
}

rs.selectedRoutes = temp.SelectedRoutes
rs.selectAll = temp.SelectAll

if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}

return nil
}
Loading
Loading