diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index e6a17519b60..9a6d9720794 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -164,9 +164,9 @@ func (a *Anonymizer) AnonymizeString(str string) string { return str } -// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes. +// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes. func (a *Anonymizer) AnonymizeSchemeURI(text string) string { - re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`) + re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`) return re.ReplaceAllStringFunc(text, a.AnonymizeURI) } diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index 445f3dce32d..a3aae1ee982 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -158,8 +158,16 @@ func TestAnonymizeSchemeURI(t *testing.T) { expect string }{ {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`}, + {"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`}, {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`}, + {"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`}, + {"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`}, + {"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, + {"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`}, + {"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`}, + {"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`}, + {"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`}, } for _, tc := range tests { diff --git a/client/cmd/pprof.go b/client/cmd/pprof.go new file mode 100644 index 00000000000..37efd35f0cd --- /dev/null +++ b/client/cmd/pprof.go @@ -0,0 +1,33 @@ +//go:build pprof +// +build pprof + +package cmd + +import ( + "net/http" + _ "net/http/pprof" + "os" + + log "github.com/sirupsen/logrus" +) + +func init() { + addr := pprofAddr() + go pprof(addr) +} + +func pprofAddr() string { + listenAddr := os.Getenv("NB_PPROF_ADDR") + if listenAddr == "" { + return "localhost:6969" + } + + return listenAddr +} + +func pprof(listenAddr string) { + log.Infof("listening pprof on: %s\n", listenAddr) + if err := http.ListenAndServe(listenAddr, nil); err != nil { + log.Fatalf("Failed to start pprof: %v", err) + } +} diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go index bfd08bee27d..004c512a4b4 100644 --- a/client/firewall/iptables/rulestore_linux.go +++ b/client/firewall/iptables/rulestore_linux.go @@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error { return err } s.ips = temp.IPs + + if temp.IPs == nil { + temp.IPs = make(map[string]struct{}) + } + return nil } @@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error { return err } s.ipsets = temp.IPSets + + if temp.IPSets == nil { + temp.IPSets = make(map[string]*ipList) + } + return nil } diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 12f7a812970..00a91f0ecad 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { - case addr.IP.To4() != nil: - networks = []ice.NetworkType{ice.NetworkTypeUDP4} case addr.IP.To16() != nil: networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} + case addr.IP.To4() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4} + default: params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 64410114cee..fc9620d801e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -354,8 +354,17 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer - e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) + e.routeManager = routemanager.NewManager( + e.ctx, + e.config.WgPrivateKey.PublicKey().String(), + e.config.DNSRouteInterval, + e.wgInterface, + e.statusRecorder, + e.relayManager, + initialRoutes, + e.stateManager, + ) + beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { log.Errorf("Failed to initialize route manager: %s", err) } else { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index b6c6186ea83..b58c1f7e93a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { nil) wgIface := &iface.MockWGIface{ + NameFunc: func() string { return "utun102" }, RemovePeerFunc: func(peerKey string) error { return nil }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil) + _, _, err = engine.routeManager.Init() + require.NoError(t, err) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 81c456db747..a8de2fccb73 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -83,7 +83,6 @@ type Conn struct { signaler *Signaler relayManager *relayClient.Manager allowedIP net.IP - allowedNet string handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) @@ -111,7 +110,7 @@ type Conn struct { // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { - allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) + allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) return nil, err @@ -129,7 +128,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu signaler: signaler, relayManager: relayManager, allowedIP: allowedIP, - allowedNet: allowedNet.String(), statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), } @@ -594,7 +592,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr) } } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 0a1c7dc56b8..f1c4ae5ef75 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -32,7 +32,7 @@ import ( // Manager is a route manager interface type Manager interface { - Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector @@ -59,6 +59,7 @@ type DefaultManager struct { routeRefCounter *refcounter.RouteRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration + stateManager *statemanager.Manager } func NewManager( @@ -69,6 +70,7 @@ func NewManager( statusRecorder *peer.Status, relayMgr *relayClient.Manager, initialRoutes []*route.Route, + stateManager *statemanager.Manager, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) notifier := notifier.NewNotifier() @@ -80,12 +82,12 @@ func NewManager( dnsRouteInterval: dnsRouteInterval, clientNetworks: make(map[route.HAUniqueID]*clientNetwork), relayMgr: relayMgr, - routeSelector: routeselector.NewRouteSelector(), sysOps: sysOps, statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, notifier: notifier, + stateManager: stateManager, } dm.routeRefCounter = refcounter.New( @@ -121,7 +123,7 @@ func NewManager( } // Init sets up the routing -func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } @@ -137,14 +139,38 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } + + m.routeSelector = m.initSelector() + log.Info("Routing setup complete") return beforePeerHook, afterPeerHook, nil } +func (m *DefaultManager) initSelector() *routeselector.RouteSelector { + var state *SelectorState + m.stateManager.RegisterState(state) + + // restore selector state if it exists + if err := m.stateManager.LoadState(state); err != nil { + log.Warnf("failed to load state: %v", err) + return routeselector.NewRouteSelector() + } + + if state := m.stateManager.GetState(state); state != nil { + if selector, ok := state.(*SelectorState); ok { + return (*routeselector.RouteSelector)(selector) + } + + log.Warnf("failed to convert state with type %T to SelectorState", state) + } + + return routeselector.NewRouteSelector() +} + func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { go clientNetworkWatcher.peersStateAndUpdateWatcher() clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) } + + if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil { + log.Errorf("failed to update state: %v", err) + } } // stopObsoleteClients stops the client network watcher for the networks that are not in the new list diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index e669bc44a08..07dac21b819 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) - _, _, err = routeManager.Init(nil) + _, _, err = routeManager.Init() require.NoError(t, err, "should init route manager") defer routeManager.Stop(nil) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 503185f0311..556a6235138 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -21,7 +21,7 @@ type MockManager struct { StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { +func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index f2f0a169df0..27a724f5062 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -71,11 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key } // LoadData loads the data from the existing counter +// The passed counter should not be used any longer after calling this function. func (rm *Counter[Key, I, O]) LoadData( existingCounter *Counter[Key, I, O], ) { rm.mu.Lock() defer rm.mu.Unlock() + existingCounter.mu.Lock() + defer existingCounter.mu.Unlock() rm.refCountMap = existingCounter.refCountMap rm.idMap = existingCounter.idMap @@ -231,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the json.Unmarshaler interface for Counter. func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { + rm.mu.Lock() + defer rm.mu.Unlock() + var temp struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` IDMap map[string][]Key `json:"idMap"` @@ -241,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { rm.refCountMap = temp.RefCountMap rm.idMap = temp.IDMap + if temp.RefCountMap == nil { + temp.RefCountMap = map[Key]Ref[O]{} + } + if temp.IDMap == nil { + temp.IDMap = map[string][]Key{} + } + return nil } diff --git a/client/internal/routemanager/state.go b/client/internal/routemanager/state.go new file mode 100644 index 00000000000..a45c32b506d --- /dev/null +++ b/client/internal/routemanager/state.go @@ -0,0 +1,19 @@ +package routemanager + +import ( + "github.com/netbirdio/netbird/client/internal/routeselector" +) + +type SelectorState routeselector.RouteSelector + +func (s *SelectorState) Name() string { + return "routeselector_state" +} + +func (s *SelectorState) MarshalJSON() ([]byte, error) { + return (*routeselector.RouteSelector)(s).MarshalJSON() +} + +func (s *SelectorState) UnmarshalJSON(data []byte) error { + return (*routeselector.RouteSelector)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 00128a27b03..2874604fdd3 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -1,8 +1,10 @@ package routeselector import ( + "encoding/json" "fmt" "slices" + "sync" "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" @@ -12,6 +14,7 @@ import ( ) type RouteSelector struct { + mu sync.RWMutex selectedRoutes map[route.NetID]struct{} selectAll bool } @@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector { // SelectRoutes updates the selected routes based on the provided route IDs. func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if !appendRoute { rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al // SelectAllRoutes sets the selector to select all routes. func (rs *RouteSelector) SelectAllRoutes() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.selectAll = true rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() { // DeselectRoutes removes specific routes from the selection. // If the selector is in "select all" mode, it will transition to "select specific" mode. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.selectAll { rs.selectAll = false rs.selectedRoutes = map[route.NetID]struct{}{} @@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route. // DeselectAllRoutes deselects all routes, effectively disabling route selection. func (rs *RouteSelector) DeselectAllRoutes() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.selectAll = false rs.selectedRoutes = map[route.NetID]struct{}{} } // IsSelected checks if a specific route is selected. func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { + rs.mu.RLock() + defer rs.mu.RUnlock() + if rs.selectAll { return true } @@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { // FilterSelected removes unselected routes from the provided map. func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { + rs.mu.RLock() + defer rs.mu.RUnlock() + if rs.selectAll { return maps.Clone(routes) } @@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { } return filtered } + +// MarshalJSON implements the json.Marshaler interface +func (rs *RouteSelector) MarshalJSON() ([]byte, error) { + rs.mu.RLock() + defer rs.mu.RUnlock() + + return json.Marshal(struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + }{ + SelectAll: rs.selectAll, + SelectedRoutes: rs.selectedRoutes, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +// If the JSON is empty or null, it will initialize like a NewRouteSelector. +func (rs *RouteSelector) UnmarshalJSON(data []byte) error { + rs.mu.Lock() + defer rs.mu.Unlock() + + // Check for null or empty JSON + if len(data) == 0 || string(data) == "null" { + rs.selectedRoutes = map[route.NetID]struct{}{} + rs.selectAll = true + return nil + } + + var temp struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + rs.selectedRoutes = temp.SelectedRoutes + rs.selectAll = temp.SelectAll + + if rs.selectedRoutes == nil { + rs.selectedRoutes = map[route.NetID]struct{}{} + } + + return nil +} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index da6dd022fc2..aae73b79f5e 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -22,9 +22,28 @@ import ( // State interface defines the methods that all state types must implement type State interface { Name() string +} + +// CleanableState interface extends State with cleanup capability +type CleanableState interface { + State Cleanup() error } +// RawState wraps raw JSON data for unregistered states +type RawState struct { + data json.RawMessage +} + +func (r *RawState) Name() string { + return "" // This is a placeholder implementation +} + +// MarshalJSON implements json.Marshaler to preserve the original JSON +func (r *RawState) MarshalJSON() ([]byte, error) { + return r.data, nil +} + // Manager handles the persistence and management of various states type Manager struct { mu sync.Mutex @@ -209,15 +228,15 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } -// loadState loads the existing state from the state file -func (m *Manager) loadState() error { +// loadStateFile reads and unmarshals the state file into a map of raw JSON messages +func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) { data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { log.Debug("state file does not exist") - return nil + return nil, nil // nolint:nilnil } - return fmt.Errorf("read state file: %w", err) + return nil, fmt.Errorf("read state file: %w", err) } var rawStates map[string]json.RawMessage @@ -228,37 +247,69 @@ func (m *Manager) loadState() error { } else { log.Info("State file deleted") } - return fmt.Errorf("unmarshal states: %w", err) + return nil, fmt.Errorf("unmarshal states: %w", err) } - var merr *multierror.Error + return rawStates, nil +} - for name, rawState := range rawStates { - stateType, ok := m.stateTypes[name] - if !ok { - merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) - continue - } +// loadSingleRawState unmarshals a raw state into a concrete state object +func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) { + stateType, ok := m.stateTypes[name] + if !ok { + return nil, fmt.Errorf("state %s not registered", name) + } - if string(rawState) == "null" { - continue - } + if string(rawState) == "null" { + return nil, nil //nolint:nilnil + } - statePtr := reflect.New(stateType).Interface().(State) - if err := json.Unmarshal(rawState, statePtr); err != nil { - merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) - continue - } + statePtr := reflect.New(stateType).Interface().(State) + if err := json.Unmarshal(rawState, statePtr); err != nil { + return nil, fmt.Errorf("unmarshal state %s: %w", name, err) + } + + return statePtr, nil +} + +// LoadState loads a specific state from the state file +func (m *Manager) LoadState(state State) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() - m.states[name] = statePtr + rawStates, err := m.loadStateFile() + if err != nil { + return err + } + if rawStates == nil { + return nil + } + + name := state.Name() + rawState, exists := rawStates[name] + if !exists { + return nil + } + + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + return err + } + + m.states[name] = loadedState + if loadedState != nil { log.Debugf("loaded state: %s", name) } - return nberrors.FormatErrorOrNil(merr) + return nil } -// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. -// If the cleanup is successful, the state is marked for deletion. +// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it. +// Unregistered states are preserved in their original state. func (m *Manager) PerformCleanup() error { if m == nil { return nil @@ -267,22 +318,53 @@ func (m *Manager) PerformCleanup() error { m.mu.Lock() defer m.mu.Unlock() - if err := m.loadState(); err != nil { + // Load raw states from file + rawStates, err := m.loadStateFile() + if err != nil { log.Warnf("Failed to load state during cleanup: %v", err) + return err + } + if rawStates == nil { + return nil } var merr *multierror.Error - for name, state := range m.states { - if state == nil { - // If no state was found in the state file, we don't mark the state dirty nor return an error + + // Process each state in the file + for name, rawState := range rawStates { + // For unregistered states, preserve the raw JSON + if _, registered := m.stateTypes[name]; !registered { + m.states[name] = &RawState{data: rawState} + continue + } + + // Load the registered state + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + merr = multierror.Append(merr, err) + continue + } + + if loadedState == nil { + continue + } + + // Check if state supports cleanup + cleanableState, isCleanable := loadedState.(CleanableState) + if !isCleanable { + // If it doesn't support cleanup, keep it as-is + m.states[name] = loadedState continue } + // Perform cleanup for cleanable states log.Infof("client was not shut down properly, cleaning up %s", name) - if err := state.Cleanup(); err != nil { + if err := cleanableState.Cleanup(); err != nil { merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) + // On cleanup error, preserve the state + m.states[name] = loadedState } else { - // mark for deletion on cleanup success + // Successfully cleaned up - mark for deletion m.states[name] = nil m.dirty[name] = struct{}{} } diff --git a/client/server/debug.go b/client/server/debug.go index 6d4546480f0..bf997fe924c 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -7,12 +7,15 @@ import ( "bufio" "bytes" "context" + "encoding/json" "errors" "fmt" "io" + "io/fs" "net" "net/netip" "os" + "path/filepath" "sort" "strings" "time" @@ -23,6 +26,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/proto" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -31,11 +35,14 @@ const readmeContent = `Netbird debug bundle This debug bundle contains the following files: status.txt: Anonymized status information of the NetBird client. -client.log: Most recent, anonymized log file of the NetBird client. +client.log: Most recent, anonymized client log file of the NetBird client. +netbird.err: Most recent, anonymized stderr log file of the NetBird client. +netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules. +state.json: Anonymized client state dump containing netbird states. Anonymization Process @@ -65,6 +72,19 @@ The network_map.json file contains the following anonymized information: SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above. +State File +The state.json file contains anonymized internal state information of the NetBird client, including: +- DNS settings and configuration +- Firewall rules +- Exclusion routes +- Route selection +- Other internal states that may be present + +The state file follows the same anonymization rules as other files: +- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure +- Domain names are consistently anonymized +- Technical identifiers and non-sensitive data remain unchanged + Routes For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. @@ -88,6 +108,12 @@ The config.txt file contains anonymized configuration information of the NetBird Other non-sensitive configuration options are included without anonymization. ` +const ( + clientLogFile = "client.log" + errorLogFile = "netbird.err" + stdoutLogFile = "netbird.out" +) + // DebugBundle creates a debug bundle and returns the location. func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { s.mutex.Lock() @@ -152,6 +178,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques return fmt.Errorf("add network map: %w", err) } + if err := s.addStateFile(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add state file to debug bundle: %v", err) + } + if err := s.addLogfile(req, anonymizer, archive); err != nil { return fmt.Errorf("add log file: %w", err) } @@ -302,14 +332,73 @@ func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonym return nil } -func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { - logFile, err := os.Open(s.logFile) +func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + data, err := os.ReadFile(path) if err != nil { - return fmt.Errorf("open log file: %w", err) + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return fmt.Errorf("read state file: %w", err) + } + + if req.GetAnonymize() { + var rawStates map[string]json.RawMessage + if err := json.Unmarshal(data, &rawStates); err != nil { + return fmt.Errorf("unmarshal states: %w", err) + } + + if err := anonymizeStateFile(&rawStates, anonymizer); err != nil { + return fmt.Errorf("anonymize state file: %w", err) + } + + bs, err := json.MarshalIndent(rawStates, "", " ") + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + data = bs + } + + if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil { + return fmt.Errorf("add state file to zip: %w", err) + } + + return nil +} + +func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + logDir := filepath.Dir(s.logFile) + + if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil { + return fmt.Errorf("add client log file to zip: %w", err) + } + + errLogPath := filepath.Join(logDir, errorLogFile) + if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil { + log.Warnf("Failed to add %s to zip: %v", errorLogFile, err) + } + + stdoutLogPath := filepath.Join(logDir, stdoutLogFile) + if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil { + log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err) + } + + return nil +} + +// addSingleLogfile adds a single log file to the archive +func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + logFile, err := os.Open(logPath) + if err != nil { + return fmt.Errorf("open log file %s: %w", targetName, err) } defer func() { if err := logFile.Close(); err != nil { - log.Errorf("Failed to close original log file: %v", err) + log.Errorf("Failed to close log file %s: %v", targetName, err) } }() @@ -318,12 +407,13 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize var writer *io.PipeWriter logReader, writer = io.Pipe() - go s.anonymize(logFile, writer, anonymizer) + go anonymizeLog(logFile, writer, anonymizer) } else { logReader = logFile } - if err := addFileToZip(archive, logReader, "client.log"); err != nil { - return fmt.Errorf("add log file to zip: %w", err) + + if err := addFileToZip(archive, logReader, targetName); err != nil { + return fmt.Errorf("add %s to zip: %w", targetName, err) } return nil @@ -555,6 +645,26 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an return builder.String() } +func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { + defer func() { + // always nil + _ = writer.Close() + }() + + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + line := anonymizer.AnonymizeString(scanner.Text()) + if _, err := writer.Write([]byte(line + "\n")); err != nil { + writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)) + return + } + } + if err := scanner.Err(); err != nil { + writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) + return + } +} + func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string { anonymizedIPs := make([]string, len(ips)) for i, ip := range ips { @@ -752,3 +862,77 @@ func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *an rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) } } + +func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error { + for name, rawState := range *rawStates { + if string(rawState) == "null" { + continue + } + + var state map[string]any + if err := json.Unmarshal(rawState, &state); err != nil { + return fmt.Errorf("unmarshal state %s: %w", name, err) + } + + state = anonymizeValue(state, anonymizer).(map[string]any) + + bs, err := json.Marshal(state) + if err != nil { + return fmt.Errorf("marshal state %s: %w", name, err) + } + + (*rawStates)[name] = bs + } + + return nil +} + +func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any { + switch v := value.(type) { + case string: + return anonymizeString(v, anonymizer) + case map[string]any: + return anonymizeMap(v, anonymizer) + case []any: + return anonymizeSlice(v, anonymizer) + } + return value +} + +func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string { + if prefix, err := netip.ParsePrefix(v); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + if ip, err := netip.ParseAddr(v); err == nil { + return anonymizer.AnonymizeIP(ip).String() + } + return anonymizer.AnonymizeString(v) +} + +func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any { + result := make(map[string]any, len(v)) + for key, val := range v { + newKey := anonymizeMapKey(key, anonymizer) + result[newKey] = anonymizeValue(val, anonymizer) + } + return result +} + +func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string { + if prefix, err := netip.ParsePrefix(key); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + if ip, err := netip.ParseAddr(key); err == nil { + return anonymizer.AnonymizeIP(ip).String() + } + return key +} + +func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any { + for i, val := range v { + v[i] = anonymizeValue(val, anonymizer) + } + return v +} diff --git a/client/server/debug_test.go b/client/server/debug_test.go index 1e3fff6ca65..c8f7bae5d38 100644 --- a/client/server/debug_test.go +++ b/client/server/debug_test.go @@ -1,16 +1,271 @@ package server import ( + "encoding/json" "net" "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/anonymize" mgmProto "github.com/netbirdio/netbird/management/proto" ) +func TestAnonymizeStateFile(t *testing.T) { + testState := map[string]json.RawMessage{ + "null_state": json.RawMessage("null"), + "test_state": mustMarshal(map[string]any{ + // Test simple fields + "public_ip": "203.0.113.1", + "private_ip": "192.168.1.1", + "protected_ip": "100.64.0.1", + "well_known_ip": "8.8.8.8", + "ipv6_addr": "2001:db8::1", + "private_ipv6": "fd00::1", + "domain": "test.example.com", + "uri": "stun:stun.example.com:3478", + "uri_with_ip": "turn:203.0.113.1:3478", + "netbird_domain": "device.netbird.cloud", + + // Test CIDR ranges + "public_cidr": "203.0.113.0/24", + "private_cidr": "192.168.0.0/16", + "protected_cidr": "100.64.0.0/10", + "ipv6_cidr": "2001:db8::/32", + "private_ipv6_cidr": "fd00::/8", + + // Test nested structures + "nested": map[string]any{ + "ip": "203.0.113.2", + "domain": "nested.example.com", + "more_nest": map[string]any{ + "ip": "203.0.113.3", + "domain": "deep.example.com", + }, + }, + + // Test arrays + "string_array": []any{ + "203.0.113.4", + "test1.example.com", + "test2.example.com", + }, + "object_array": []any{ + map[string]any{ + "ip": "203.0.113.5", + "domain": "array1.example.com", + }, + map[string]any{ + "ip": "203.0.113.6", + "domain": "array2.example.com", + }, + }, + + // Test multiple occurrences of same value + "duplicate_ip": "203.0.113.1", // Same as public_ip + "duplicate_domain": "test.example.com", // Same as domain + + // Test URIs with various schemes + "stun_uri": "stun:stun.example.com:3478", + "turns_uri": "turns:turns.example.com:5349", + "http_uri": "http://web.example.com:80", + "https_uri": "https://secure.example.com:443", + + // Test strings that might look like IPs but aren't + "not_ip": "300.300.300.300", + "partial_ip": "192.168", + "ip_like_string": "1234.5678", + + // Test mixed content strings + "mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80", + + // Test empty and special values + "empty_string": "", + "null_value": nil, + "numeric_value": 42, + "boolean_value": true, + }), + "route_state": mustMarshal(map[string]any{ + "routes": []any{ + map[string]any{ + "network": "203.0.113.0/24", + "gateway": "203.0.113.1", + "domains": []any{ + "route1.example.com", + "route2.example.com", + }, + }, + map[string]any{ + "network": "2001:db8::/32", + "gateway": "2001:db8::1", + "domains": []any{ + "route3.example.com", + "route4.example.com", + }, + }, + }, + // Test map with IP/CIDR keys + "refCountMap": map[string]any{ + "203.0.113.1/32": map[string]any{ + "Count": 1, + "Out": map[string]any{ + "IP": "192.168.0.1", + "Intf": map[string]any{ + "Name": "eth0", + "Index": 1, + }, + }, + }, + "2001:db8::1/128": map[string]any{ + "Count": 1, + "Out": map[string]any{ + "IP": "fe80::1", + "Intf": map[string]any{ + "Name": "eth0", + "Index": 1, + }, + }, + }, + "10.0.0.1/32": map[string]any{ // private IP should remain unchanged + "Count": 1, + "Out": map[string]any{ + "IP": "192.168.0.1", + }, + }, + }, + }), + } + + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Pre-seed the domains we need to verify in the test assertions + anonymizer.AnonymizeDomain("test.example.com") + anonymizer.AnonymizeDomain("nested.example.com") + anonymizer.AnonymizeDomain("deep.example.com") + anonymizer.AnonymizeDomain("array1.example.com") + + err := anonymizeStateFile(&testState, anonymizer) + require.NoError(t, err) + + // Helper function to unmarshal and get nested values + var state map[string]any + err = json.Unmarshal(testState["test_state"], &state) + require.NoError(t, err) + + // Test null state remains unchanged + require.Equal(t, "null", string(testState["null_state"])) + + // Basic assertions + assert.NotEqual(t, "203.0.113.1", state["public_ip"]) + assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged + assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged + assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged + assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"]) + assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged + assert.NotEqual(t, "test.example.com", state["domain"]) + assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain")) + assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged + + // CIDR ranges + assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"]) + assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved + assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged + assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged + assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"]) + assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved + + // Nested structures + nested := state["nested"].(map[string]any) + assert.NotEqual(t, "203.0.113.2", nested["ip"]) + assert.NotEqual(t, "nested.example.com", nested["domain"]) + moreNest := nested["more_nest"].(map[string]any) + assert.NotEqual(t, "203.0.113.3", moreNest["ip"]) + assert.NotEqual(t, "deep.example.com", moreNest["domain"]) + + // Arrays + strArray := state["string_array"].([]any) + assert.NotEqual(t, "203.0.113.4", strArray[0]) + assert.NotEqual(t, "test1.example.com", strArray[1]) + assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain")) + + objArray := state["object_array"].([]any) + firstObj := objArray[0].(map[string]any) + assert.NotEqual(t, "203.0.113.5", firstObj["ip"]) + assert.NotEqual(t, "array1.example.com", firstObj["domain"]) + + // Duplicate values should be anonymized consistently + assert.Equal(t, state["public_ip"], state["duplicate_ip"]) + assert.Equal(t, state["domain"], state["duplicate_domain"]) + + // URIs + assert.NotContains(t, state["stun_uri"], "stun.example.com") + assert.NotContains(t, state["turns_uri"], "turns.example.com") + assert.NotContains(t, state["http_uri"], "web.example.com") + assert.NotContains(t, state["https_uri"], "secure.example.com") + + // Non-IP strings should remain unchanged + assert.Equal(t, "300.300.300.300", state["not_ip"]) + assert.Equal(t, "192.168", state["partial_ip"]) + assert.Equal(t, "1234.5678", state["ip_like_string"]) + + // Mixed content should have IPs and domains replaced + mixedContent := state["mixed_content"].(string) + assert.NotContains(t, mixedContent, "203.0.113.1") + assert.NotContains(t, mixedContent, "test.example.com") + assert.Contains(t, mixedContent, "Server at ") + assert.Contains(t, mixedContent, " on port 80") + + // Special values should remain unchanged + assert.Equal(t, "", state["empty_string"]) + assert.Nil(t, state["null_value"]) + assert.Equal(t, float64(42), state["numeric_value"]) + assert.Equal(t, true, state["boolean_value"]) + + // Check route state + var routeState map[string]any + err = json.Unmarshal(testState["route_state"], &routeState) + require.NoError(t, err) + + routes := routeState["routes"].([]any) + route1 := routes[0].(map[string]any) + assert.NotEqual(t, "203.0.113.0/24", route1["network"]) + assert.Contains(t, route1["network"], "/24") + assert.NotEqual(t, "203.0.113.1", route1["gateway"]) + domains := route1["domains"].([]any) + assert.True(t, strings.HasSuffix(domains[0].(string), ".domain")) + assert.True(t, strings.HasSuffix(domains[1].(string), ".domain")) + + // Check map keys are anonymized + refCountMap := routeState["refCountMap"].(map[string]any) + hasPublicIPKey := false + hasIPv6Key := false + hasPrivateIPKey := false + for key := range refCountMap { + if strings.Contains(key, "203.0.113.1") { + hasPublicIPKey = true + } + if strings.Contains(key, "2001:db8::1") { + hasIPv6Key = true + } + if key == "10.0.0.1/32" { + hasPrivateIPKey = true + } + } + assert.False(t, hasPublicIPKey, "public IP in key should be anonymized") + assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized") + assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged") +} + +func mustMarshal(v any) json.RawMessage { + data, err := json.Marshal(v) + if err != nil { + panic(err) + } + return data +} + func TestAnonymizeNetworkMap(t *testing.T) { networkMap := &mgmProto.NetworkMap{ PeerConfig: &mgmProto.PeerConfig{ diff --git a/client/system/info.go b/client/system/info.go index 2af2e637b92..200d835df31 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -61,6 +61,14 @@ type Info struct { Files []File // for posture checks } +// StaticInfo is an object that contains machine information that does not change +type StaticInfo struct { + SystemSerialNumber string + SystemProductName string + SystemManufacturer string + Environment Environment +} + // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context func extractUserAgent(ctx context.Context) string { md, hasMeta := metadata.FromOutgoingContext(ctx) diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index 6f4ed173b4b..13b0a446bd3 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -10,13 +10,12 @@ import ( "os/exec" "runtime" "strings" + "time" "golang.org/x/sys/unix" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) @@ -41,11 +40,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, prodName, manufacturer := sysInfo() - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -57,10 +55,10 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), KernelVersion: release, NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } systemHostname, _ := os.Hostname() diff --git a/client/system/info_linux.go b/client/system/info_linux.go index b6a142bce28..bfc77be1915 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -1,5 +1,4 @@ //go:build !android -// +build !android package system @@ -16,30 +15,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/zcalusic/sysinfo" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) -type SysInfoGetter interface { - GetSysInfo() SysInfo -} - -type SysInfoWrapper struct { - si sysinfo.SysInfo -} - -func (s SysInfoWrapper) GetSysInfo() SysInfo { - s.si.GetSysInfo() - return SysInfo{ - ChassisSerial: s.si.Chassis.Serial, - ProductSerial: s.si.Product.Serial, - BoardSerial: s.si.Board.Serial, - ProductName: s.si.Product.Name, - BoardName: s.si.Board.Name, - ProductVendor: s.si.Product.Vendor, - } -} +var ( + // it is override in tests + getSystemInfo = defaultSysInfoImplementation +) // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { @@ -65,12 +47,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - si := SysInfoWrapper{} - serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo()) - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -85,10 +65,10 @@ func GetInfo(ctx context.Context) *Info { UIVersion: extractUserAgent(ctx), KernelVersion: osInfo[1], NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } return gio @@ -108,9 +88,9 @@ func _getInfo() string { return out.String() } -func sysInfo(si SysInfo) (string, string, string) { +func sysInfo() (string, string, string) { isascii := regexp.MustCompile("^[[:ascii:]]+$") - + si := getSystemInfo() serials := []string{si.ChassisSerial, si.ProductSerial} serial := "" @@ -141,3 +121,16 @@ func sysInfo(si SysInfo) (string, string, string) { } return serial, name, manufacturer } + +func defaultSysInfoImplementation() SysInfo { + si := sysinfo.SysInfo{} + si.GetSysInfo() + return SysInfo{ + ChassisSerial: si.Chassis.Serial, + ProductSerial: si.Product.Serial, + BoardSerial: si.Board.Serial, + ProductName: si.Product.Name, + BoardName: si.Board.Name, + ProductVendor: si.Product.Vendor, + } +} diff --git a/client/system/info_windows.go b/client/system/info_windows.go index 68631fe164d..28bd3d3007c 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -6,13 +6,12 @@ import ( "os" "runtime" "strings" + "time" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" "golang.org/x/sys/windows/registry" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) @@ -42,24 +41,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, err := sysNumber() - if err != nil { - log.Warnf("failed to get system serial number: %s", err) - } - - prodName, err := sysProductName() - if err != nil { - log.Warnf("failed to get system product name: %s", err) - } - - manufacturer, err := sysManufacturer() - if err != nil { - log.Warnf("failed to get system manufacturer: %s", err) - } - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -71,10 +56,10 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), KernelVersion: buildVersion, NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } systemHostname, _ := os.Hostname() @@ -85,6 +70,26 @@ func GetInfo(ctx context.Context) *Info { return gio } +func sysInfo() (serialNumber string, productName string, manufacturer string) { + var err error + serialNumber, err = sysNumber() + if err != nil { + log.Warnf("failed to get system serial number: %s", err) + } + + productName, err = sysProductName() + if err != nil { + log.Warnf("failed to get system product name: %s", err) + } + + manufacturer, err = sysManufacturer() + if err != nil { + log.Warnf("failed to get system manufacturer: %s", err) + } + + return serialNumber, productName, manufacturer +} + func getOSNameAndVersion() (string, string) { var dst []Win32_OperatingSystem query := wmi.CreateQuery(&dst, "") diff --git a/client/system/static_info.go b/client/system/static_info.go new file mode 100644 index 00000000000..fabe65a6806 --- /dev/null +++ b/client/system/static_info.go @@ -0,0 +1,46 @@ +//go:build (linux && !android) || windows || (darwin && !ios) + +package system + +import ( + "context" + "sync" + "time" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +var ( + staticInfo StaticInfo + once sync.Once +) + +func init() { + go func() { + _ = updateStaticInfo() + }() +} + +func updateStaticInfo() StaticInfo { + once.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo() + wg.Done() + }() + go func() { + staticInfo.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + go func() { + staticInfo.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Wait() + }) + return staticInfo +} diff --git a/client/system/sysinfo_linux_test.go b/client/system/sysinfo_linux_test.go index f6a0b70587b..ae89bfcf974 100644 --- a/client/system/sysinfo_linux_test.go +++ b/client/system/sysinfo_linux_test.go @@ -183,7 +183,10 @@ func Test_sysInfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo) + getSystemInfo = func() SysInfo { + return tt.sysInfo + } + gotSerialNum, gotProdName, gotManufacturer := sysInfo() if gotSerialNum != tt.wantSerialNum { t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum) } diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 0b2b651429e..7793d1fda1d 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -530,7 +530,7 @@ renderCaddyfile() { { debug servers :80,:443 { - protocols h1 h2c + protocols h1 h2c h3 } } @@ -788,6 +788,7 @@ services: networks: [ netbird ] ports: - '443:443' + - '443:443/udp' - '80:80' - '8080:8080' volumes: diff --git a/management/server/route.go b/management/server/route.go index ecb562645e6..23bea87e3b8 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -417,25 +417,82 @@ func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, continue } - policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) - for _, policy := range policies { - if !policy.Enabled { + distributionPeers := a.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup}) + rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { continue } - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } + rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} - distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) - routesFirewallRules = append(routesFirewallRules, rules...) +func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + if distPeer && valid { + distPeersWithPolicy[pID] = struct{}{} } } } - return routesFirewallRules + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := a.Peers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers } func getDefaultPermit(route *route.Route) []*RouteFirewallRule { diff --git a/management/server/route_test.go b/management/server/route_test.go index 108f791e02c..8bf9a3aebb3 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "sort" "testing" "time" @@ -1486,6 +1487,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { peerBIp = "100.65.80.39" peerCIp = "100.65.254.139" peerHIp = "100.65.29.55" + peerJIp = "100.65.29.65" + peerKIp = "100.65.29.66" ) account := &Account{ @@ -1541,6 +1544,16 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IP: net.ParseIP(peerHIp), Status: &nbpeer.PeerStatus{}, }, + "peerJ": { + ID: "peerJ", + IP: net.ParseIP(peerJIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerK": { + ID: "peerK", + IP: net.ParseIP(peerKIp), + Status: &nbpeer.PeerStatus{}, + }, }, Groups: map[string]*nbgroup.Group{ "routingPeer1": { @@ -1567,6 +1580,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Name: "Route2", Peers: []string{}, }, + "route4": { + ID: "route4", + Name: "route4", + Peers: []string{}, + }, "finance": { ID: "finance", Name: "Finance", @@ -1584,6 +1602,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { "peerB", }, }, + "qa": { + ID: "qa", + Name: "QA", + Peers: []string{ + "peerJ", + "peerK", + }, + }, + "restrictQA": { + ID: "restrictQA", + Name: "restrictQA", + Peers: []string{ + "peerJ", + }, + }, + "unrestrictedQA": { + ID: "unrestrictedQA", + Name: "unrestrictedQA", + Peers: []string{ + "peerK", + }, + }, "contractors": { ID: "contractors", Name: "Contractors", @@ -1631,6 +1671,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Groups: []string{"contractors"}, AccessControlGroups: []string{}, }, + "route4": { + ID: "route4", + Network: netip.MustParsePrefix("192.168.10.0/16"), + NetID: "route4", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1"}, + Description: "Route4", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"qa"}, + AccessControlGroups: []string{"route4"}, + }, }, Policies: []*Policy{ { @@ -1685,6 +1738,49 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, }, }, + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + Ports: []string{"80"}, + Sources: []string{ + "restrictQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + Sources: []string{ + "unrestrictedQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, }, } @@ -1709,7 +1805,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { t.Run("check peer routes firewall rules", func(t *testing.T) { routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) - assert.Len(t, routesFirewallRules, 2) + assert.Len(t, routesFirewallRules, 4) expectedRoutesFirewallRules := []*RouteFirewallRule{ { @@ -1735,12 +1831,32 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Port: 320, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + additionalFirewallRule := []*RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerJIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "tcp", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerKIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "all", + }, + } + + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) @@ -1769,7 +1885,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IsDynamic: true, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerC is part of route1 distribution groups but should not receive the routes firewall rules routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) @@ -1778,6 +1894,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { } +// orderList is a helper function to sort a list of strings +func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule { + for _, rule := range ruleList { + sort.Strings(rule.SourceRanges) + } + return ruleList +} + func TestRouteAccountPeersUpdate(t *testing.T) { manager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager")