diff --git a/management/server/route.go b/management/server/route.go index dcf2cb0d32c..5b0c50c9cd6 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -417,25 +417,82 @@ func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, continue } - policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) - for _, policy := range policies { - if !policy.Enabled { + distributionPeers := a.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup}) + rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { continue } - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } + rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} - distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) - routesFirewallRules = append(routesFirewallRules, rules...) +func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + if distPeer && valid { + distPeersWithPolicy[pID] = struct{}{} } } } - return routesFirewallRules + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := a.Peers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers } func getDefaultPermit(route *route.Route) []*RouteFirewallRule { diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c848f68c7b..0afbebbce89 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "sort" "testing" "time" @@ -1487,6 +1488,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { peerBIp = "100.65.80.39" peerCIp = "100.65.254.139" peerHIp = "100.65.29.55" + peerJIp = "100.65.29.65" + peerKIp = "100.65.29.66" ) account := &Account{ @@ -1542,6 +1545,16 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IP: net.ParseIP(peerHIp), Status: &nbpeer.PeerStatus{}, }, + "peerJ": { + ID: "peerJ", + IP: net.ParseIP(peerJIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerK": { + ID: "peerK", + IP: net.ParseIP(peerKIp), + Status: &nbpeer.PeerStatus{}, + }, }, Groups: map[string]*nbgroup.Group{ "routingPeer1": { @@ -1568,6 +1581,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Name: "Route2", Peers: []string{}, }, + "route4": { + ID: "route4", + Name: "route4", + Peers: []string{}, + }, "finance": { ID: "finance", Name: "Finance", @@ -1585,6 +1603,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { "peerB", }, }, + "qa": { + ID: "qa", + Name: "QA", + Peers: []string{ + "peerJ", + "peerK", + }, + }, + "restrictQA": { + ID: "restrictQA", + Name: "restrictQA", + Peers: []string{ + "peerJ", + }, + }, + "unrestrictedQA": { + ID: "unrestrictedQA", + Name: "unrestrictedQA", + Peers: []string{ + "peerK", + }, + }, "contractors": { ID: "contractors", Name: "Contractors", @@ -1632,6 +1672,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Groups: []string{"contractors"}, AccessControlGroups: []string{}, }, + "route4": { + ID: "route4", + Network: netip.MustParsePrefix("192.168.10.0/16"), + NetID: "route4", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1"}, + Description: "Route4", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"qa"}, + AccessControlGroups: []string{"route4"}, + }, }, Policies: []*Policy{ { @@ -1686,6 +1739,49 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, }, }, + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + Ports: []string{"80"}, + Sources: []string{ + "restrictQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + Sources: []string{ + "unrestrictedQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, }, } @@ -1710,7 +1806,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { t.Run("check peer routes firewall rules", func(t *testing.T) { routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) - assert.Len(t, routesFirewallRules, 2) + assert.Len(t, routesFirewallRules, 4) expectedRoutesFirewallRules := []*RouteFirewallRule{ { @@ -1736,12 +1832,32 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Port: 320, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + additionalFirewallRule := []*RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerJIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "tcp", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerKIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "all", + }, + } + + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) @@ -1770,7 +1886,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IsDynamic: true, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerC is part of route1 distribution groups but should not receive the routes firewall rules routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) @@ -1779,6 +1895,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { } +// orderList is a helper function to sort a list of strings +func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule { + for _, rule := range ruleList { + sort.Strings(rule.SourceRanges) + } + return ruleList +} + func TestRouteAccountPeersUpdate(t *testing.T) { manager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager")