diff --git a/go.mod b/go.mod index 330d0763f5d..d48280df02a 100644 --- a/go.mod +++ b/go.mod @@ -79,7 +79,6 @@ require ( github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 - github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 diff --git a/go.sum b/go.sum index ea459783621..540cbf20bb9 100644 --- a/go.sum +++ b/go.sum @@ -698,8 +698,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 h1:S5h7yNKStqF8CqFtgtMNMzk/lUI3p82LrX6h2BhlsTM= -github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907/go.mod h1:/7Fy/4/OyrkguTf2i2pO4erUD/8QAlrlmXSdSJPu678= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7eecdce0fef..162f9037891 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network NetID: route.NetID(n.Name), Description: n.Description, Peer: peer.Key, + PeerID: peer.ID, PeerGroups: nil, Masquerade: router.Masquerade, Metric: router.Metric, diff --git a/management/server/types/account.go b/management/server/types/account.go index 3ef862fa610..f9e1cc9b4e5 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -13,7 +13,6 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" - "github.com/yourbasic/radix" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -304,55 +303,47 @@ func (a *Account) GetPeerNetworkMap( return nm } -func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer { - missingPeers := map[string]struct{}{} +func (a *Account) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, + peer *nbpeer.Peer, + peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, + isRouter bool, + sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) for _, r := range networkResourcesRoutes { - if r.Peer == peer.Key { - continue - } + networkRoutesPeers[r.PeerID] = struct{}{} + } - missing := true - for _, p := range slices.Concat(peersToConnect, expiredPeers) { - if r.Peer == p.Key { - missing = false - break - } - } - if missing { - missingPeers[r.Peer] = struct{}{} - } + delete(sourcePeers, peer.ID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) } + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) if isRouter { - for _, s := range sourcePeers { - if s == peer.ID { - continue - } - - missing := true - for _, p := range slices.Concat(peersToConnect, expiredPeers) { - if s == p.ID { - missing = false - break - } - } - if missing { - p, ok := a.Peers[s] - if ok { - missingPeers[p.Key] = struct{}{} - } - } + for p := range sourcePeers { + missingPeers[p] = struct{}{} } } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } for p := range missingPeers { - for _, p2 := range a.Peers { - if p2.Key == p { - peersToConnect = append(peersToConnect, p2) - break - } + if missingPeer := a.Peers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) } } + return peersToConnect } @@ -1156,7 +1147,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli } func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make([]string, 0) + distPeersWithPolicy := make(map[string]struct{}) for _, id := range rule.Sources { group := a.Groups[id] if group == nil { @@ -1170,16 +1161,13 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID _, distPeer := distributionPeers[pID] _, valid := validatedPeersMap[pID] if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { - distPeersWithPolicy = append(distPeersWithPolicy, pID) + distPeersWithPolicy[pID] = struct{}{} } } } - radix.Sort(distPeersWithPolicy) - uniqueDistributionPeers := slices.Compact(distPeersWithPolicy) - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(uniqueDistributionPeers)) - for _, pID := range uniqueDistributionPeers { + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { peer := a.Peers[pID] if peer == nil { continue @@ -1306,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { } // GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. -func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { +func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { var isRoutingPeer bool var routes []*route.Route - allSourcePeers := make([]string, 0) + allSourcePeers := make(map[string]struct{}, len(a.Peers)) for _, resource := range a.NetworkResources { var addSourcePeers bool @@ -1326,7 +1314,9 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st for _, policy := range resourcePolicies[resource.ID] { peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) if addSourcePeers { - allSourcePeers = append(allSourcePeers, a.getPostureValidPeers(peers, policy.SourcePostureChecks)...) + for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { // add routes for the resource if the peer is in the distribution group for peerId, router := range networkRoutingPeers { @@ -1340,8 +1330,7 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st } } - radix.Sort(allSourcePeers) - return isRoutingPeer, routes, slices.Compact(allSourcePeers) + return isRoutingPeer, routes, allSourcePeers } func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { @@ -1355,8 +1344,7 @@ func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []s } func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { - gObjs := make([]*Group, 0, len(groups)) - tp := 0 + peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity for _, groupID := range groups { group := a.GetGroup(groupID) if group == nil { @@ -1368,17 +1356,17 @@ func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []st return group.Peers } - gObjs = append(gObjs, group) - tp += len(group.Peers) + for _, peerID := range group.Peers { + peerIDs[peerID] = struct{}{} + } } - ids := make([]string, 0, tp) - for _, group := range gObjs { - ids = append(ids, group.Peers...) + ids := make([]string, 0, len(peerIDs)) + for peerID := range peerIDs { + ids = append(ids, peerID) } - radix.Sort(ids) - return slices.Compact(ids) + return ids } // getNetworkResources filters and returns a list of network resources associated with the given network ID. diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index efe9301082e..367baef4ff8 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -316,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) { func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { account := setupTestAccount() - peer := &nbpeer.Peer{Key: "peer1"} + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} networkResourcesRoutes := []*route.Route{ - {Peer: "peer2Key"}, - {Peer: "peer3Key"}, + {Peer: "peer2Key", PeerID: "peer2"}, + {Peer: "peer3Key", PeerID: "peer3"}, } peersToConnect := []*nbpeer.Peer{ - {Key: "peer2Key"}, + {Key: "peer2Key", ID: "peer2"}, } expiredPeers := []*nbpeer.Peer{ - {Key: "peer4Key"}, + {Key: "peer4Key", ID: "peer4"}, } - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 2) require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer3Key", result[1].Key) @@ -345,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) { } expiredPeers := []*nbpeer.Peer{} - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 1) require.Equal(t, "peer2Key", result[0].Key) } @@ -364,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) { {Key: "peer3Key"}, } - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 1) require.Equal(t, "peer2Key", result[0].Key) } @@ -376,7 +376,7 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { peersToConnect := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{} - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 0) } @@ -559,8 +559,8 @@ func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") - assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -599,7 +599,7 @@ func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -692,8 +692,8 @@ func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") - assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -741,7 +741,7 @@ func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 5b2cf06a032..17964ed1f34 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,11 +1,5 @@ package types -import ( - "slices" - - "github.com/yourbasic/radix" -) - const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -126,10 +120,17 @@ func (p *Policy) SourceGroups() []string { if len(p.Rules) == 1 { return p.Rules[0].Sources } - groups := make([]string, 0) + groups := make(map[string]struct{}, len(p.Rules)) for _, rule := range p.Rules { - groups = append(groups, rule.Sources...) + for _, source := range rule.Sources { + groups[source] = struct{}{} + } + } + + groupIDs := make([]string, 0, len(groups)) + for groupID := range groups { + groupIDs = append(groupIDs, groupID) } - radix.Sort(groups) - return slices.Compact(groups) + + return groupIDs } diff --git a/route/route.go b/route/route.go index 8f3c99b4c1d..ad2aaba8953 100644 --- a/route/route.go +++ b/route/route.go @@ -95,6 +95,7 @@ type Route struct { NetID NetID Description string Peer string + PeerID string `gorm:"-"` PeerGroups []string `gorm:"serializer:json"` NetworkType NetworkType Masquerade bool @@ -120,6 +121,7 @@ func (r *Route) Copy() *Route { KeepRoute: r.KeepRoute, NetworkType: r.NetworkType, Peer: r.Peer, + PeerID: r.PeerID, PeerGroups: slices.Clone(r.PeerGroups), Metric: r.Metric, Masquerade: r.Masquerade, @@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool { other.KeepRoute == r.KeepRoute && other.NetworkType == r.NetworkType && other.Peer == r.Peer && + other.PeerID == r.PeerID && other.Metric == r.Metric && other.Masquerade == r.Masquerade && other.Enabled == r.Enabled &&