diff --git a/pkg/provider/azure_loadbalancer_accesscontrol_test.go b/pkg/provider/azure_loadbalancer_accesscontrol_test.go index 0547fb17d4..f1b87a15e9 100644 --- a/pkg/provider/azure_loadbalancer_accesscontrol_test.go +++ b/pkg/provider/azure_loadbalancer_accesscontrol_test.go @@ -218,7 +218,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Build(), azureFx. - AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{"0.0.0.0/0", "8.8.8.8/32"}, k8sFx.Service().TCPPorts()). + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{"0.0.0.0/0"}, k8sFx.Service().TCPPorts()). WithPriority(501). WithDestination(azureFx.LoadBalancer().IPv4Addresses()...). Build(), @@ -236,7 +236,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Build(), azureFx. - AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv4, []string{"0.0.0.0/0", "8.8.8.8/32"}, k8sFx.Service().UDPPorts()). + AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv4, []string{"0.0.0.0/0"}, k8sFx.Service().UDPPorts()). WithPriority(504). WithDestination(azureFx.LoadBalancer().IPv4Addresses()...). Build(), @@ -517,7 +517,11 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { allowedIPv6Ranges = []string{"2607:f0d0:1002:51::/64", "fd00::/8"} ) - svc.Annotations[consts.ServiceAnnotationAllowedIPRanges] = strings.Join(append(allowedIPv4Ranges, allowedIPv6Ranges...), ",") + { + ipRanges := append(allowedIPv4Ranges, allowedIPv6Ranges...) + ipRanges = append(ipRanges, "172.30.0.1/32", "2607:f0d0:1002:51::1/128") // with overlapping CIDRs + svc.Annotations[consts.ServiceAnnotationAllowedIPRanges] = strings.Join(ipRanges, ",") + } securityGroupClient.EXPECT(). Get(gomock.Any(), az.ResourceGroup, az.SecurityGroupName, gomock.Any()). @@ -695,7 +699,11 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { allowedIPv6Ranges = []string{"2607:f0d0:1002:51::/64", "fd00::/8"} ) - svc.Spec.LoadBalancerSourceRanges = append(allowedIPv4Ranges, allowedIPv6Ranges...) + { + ipRanges := append(allowedIPv4Ranges, allowedIPv6Ranges...) + ipRanges = append(ipRanges, "172.30.0.1/32", "2607:f0d0:1002:51::1/128") // with overlapping CIDRs + svc.Spec.LoadBalancerSourceRanges = ipRanges + } securityGroupClient.EXPECT(). Get(gomock.Any(), az.ResourceGroup, az.SecurityGroupName, gomock.Any()). diff --git a/pkg/provider/loadbalancer/accesscontrol.go b/pkg/provider/loadbalancer/accesscontrol.go index 37d66279fb..416ac4fd1f 100644 --- a/pkg/provider/loadbalancer/accesscontrol.go +++ b/pkg/provider/loadbalancer/accesscontrol.go @@ -187,14 +187,23 @@ func (ac *AccessControl) PatchSecurityGroup(dstIPv4Addresses, dstIPv6Addresses [ logger := ac.logger.WithName("PatchSecurityGroup") var ( - allowedIPv4Ranges = ac.AllowedIPv4Ranges() - allowedIPv6Ranges = ac.AllowedIPv6Ranges() + allowedIPRanges = append(ac.AllowedIPv4Ranges(), ac.AllowedIPv6Ranges()...) allowedServiceTags = ac.AllowedServiceTags ) if ac.IsAllowFromInternet() { allowedServiceTags = append(allowedServiceTags, securitygroup.ServiceTagInternet) } + { + // Aggregate allowed IP ranges. + ipRanges := iputil.AggregatePrefixes(allowedIPRanges) + if len(ipRanges) != len(allowedIPRanges) { + logger.Info("Overlapping IP ranges detected", "allowed-ip-ranges", allowedIPRanges, "aggregated-ip-ranges", ipRanges) + } + allowedIPRanges = ipRanges + } + var allowedIPv4Ranges, allowedIPv6Ranges = iputil.GroupPrefixesByFamily(allowedIPRanges) + logger.V(10).Info("Start patching", "num-allowed-ipv4-ranges", len(allowedIPv4Ranges), "num-allowed-ipv6-ranges", len(allowedIPv6Ranges), diff --git a/pkg/provider/loadbalancer/iputil/prefix.go b/pkg/provider/loadbalancer/iputil/prefix.go index fdde2d86e8..77395c37a8 100644 --- a/pkg/provider/loadbalancer/iputil/prefix.go +++ b/pkg/provider/loadbalancer/iputil/prefix.go @@ -21,6 +21,8 @@ import ( "net/netip" ) +// IsPrefixesAllowAll returns true if one of the prefixes allows all addresses. +// FIXME: it should return true if the aggregated prefix allows all addresses. Now it only checks one by one. func IsPrefixesAllowAll(prefixes []netip.Prefix) bool { for _, p := range prefixes { if p.Bits() == 0 { @@ -30,6 +32,7 @@ func IsPrefixesAllowAll(prefixes []netip.Prefix) bool { return false } +// ParsePrefix parses a CIDR string and returns a Prefix. func ParsePrefix(v string) (netip.Prefix, error) { prefix, err := netip.ParsePrefix(v) if err != nil { @@ -41,3 +44,38 @@ func ParsePrefix(v string) (netip.Prefix, error) { } return prefix, nil } + +// GroupPrefixesByFamily groups prefixes by IP family. +func GroupPrefixesByFamily(vs []netip.Prefix) ([]netip.Prefix, []netip.Prefix) { + var ( + v4 []netip.Prefix + v6 []netip.Prefix + ) + for _, v := range vs { + if v.Addr().Is4() { + v4 = append(v4, v) + } else { + v6 = append(v6, v) + } + } + return v4, v6 +} + +// AggregatePrefixes aggregates prefixes. +// Overlapping prefixes are merged. +func AggregatePrefixes(prefixes []netip.Prefix) []netip.Prefix { + var ( + v4, v6 = GroupPrefixesByFamily(prefixes) + v4Tree = newPrefixTreeForIPv4() + v6Tree = newPrefixTreeForIPv6() + ) + + for _, p := range v4 { + v4Tree.Add(p) + } + for _, p := range v6 { + v6Tree.Add(p) + } + + return append(v4Tree.List(), v6Tree.List()...) +} diff --git a/pkg/provider/loadbalancer/iputil/prefix_test.go b/pkg/provider/loadbalancer/iputil/prefix_test.go index b341e4a23d..d27fa91c6b 100644 --- a/pkg/provider/loadbalancer/iputil/prefix_test.go +++ b/pkg/provider/loadbalancer/iputil/prefix_test.go @@ -17,7 +17,9 @@ limitations under the License. package iputil import ( + "fmt" "net/netip" + "sort" "testing" "github.com/stretchr/testify/assert" @@ -114,3 +116,191 @@ func TestParsePrefix(t *testing.T) { } }) } + +func TestGroupPrefixesByFamily(t *testing.T) { + tests := []struct { + Name string + Input []netip.Prefix + IPv4 []netip.Prefix + IPv6 []netip.Prefix + }{ + { + Name: "Empty", + Input: []netip.Prefix{}, + }, + { + Name: "IPv4", + Input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.1/32"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + IPv4: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.1/32"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + }, + { + Name: "IPv6", + Input: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + netip.MustParsePrefix("::/0"), + }, + IPv6: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + netip.MustParsePrefix("::/0"), + }, + }, + { + Name: "Mixed", + Input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.1/32"), + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("2001:db8::1/128"), + netip.MustParsePrefix("::/0"), + }, + IPv4: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.1/32"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + IPv6: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + netip.MustParsePrefix("::/0"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + ipv4, ipv6 := GroupPrefixesByFamily(tt.Input) + assert.Equal(t, tt.IPv4, ipv4) + assert.Equal(t, tt.IPv6, ipv6) + }) + } +} + +func TestAggregatePrefixes(t *testing.T) { + tests := []struct { + Name string + Input []netip.Prefix + Output []netip.Prefix + }{ + { + Name: "Empty", + Input: []netip.Prefix{}, + }, + { + Name: "NoOverlap IPv4", + Input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.10.0.1/32"), + }, + Output: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.10.0.1/32"), + }, + }, + { + Name: "Overlap IPv4", + Input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.169.0.0/16"), + netip.MustParsePrefix("10.10.0.1/32"), + + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.1/32"), + }, + Output: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.169.0.0/16"), + netip.MustParsePrefix("10.10.0.1/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + var got = AggregatePrefixes(tt.Input) + sort.Slice(got, func(i, j int) bool { + return got[i].String() < got[j].String() + }) + sort.Slice(tt.Output, func(i, j int) bool { + return tt.Output[i].String() < tt.Output[j].String() + }) + assert.Equal(t, tt.Output, got) + }) + } +} + +func BenchmarkAggregatePrefixes(b *testing.B) { + fixtureIPv4Prefixes := func(n int64) []netip.Prefix { + prefixes := make([]netip.Prefix, 0, n) + for i := int64(0); i < n; i++ { + addr := netip.AddrFrom4([4]byte{ + byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), + }) + prefix, err := addr.Prefix(32) + assert.NoError(b, err) + prefixes = append(prefixes, prefix) + } + + return prefixes + } + + fixtureIPv6Prefixes := func(n int64) []netip.Prefix { + prefixes := make([]netip.Prefix, 0, n) + for i := int64(0); i < n; i++ { + addr := netip.AddrFrom16([16]byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + byte(i >> 56), byte(i >> 48), byte(i >> 40), byte(i >> 32), + byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), + }) + prefix, err := addr.Prefix(128) + assert.NoError(b, err) + prefixes = append(prefixes, prefix) + } + return prefixes + } + + runIPv4Tests := func(b *testing.B, n int64) { + b.Run(fmt.Sprintf("IPv4-%d", n), func(b *testing.B) { + b.StopTimer() + prefixes := fixtureIPv4Prefixes(n) + b.StartTimer() + + for i := 0; i < b.N; i++ { + AggregatePrefixes(prefixes) + } + }) + } + + runIPv6Tests := func(b *testing.B, n int64) { + b.Run(fmt.Sprintf("IPv6-%d", n), func(b *testing.B) { + b.StopTimer() + prefixes := fixtureIPv4Prefixes(n) + b.StartTimer() + + for i := 0; i < b.N; i++ { + AggregatePrefixes(prefixes) + } + }) + } + + runMixedTests := func(b *testing.B, n int64) { + b.Run(fmt.Sprintf("IPv4-IPv6-%d", 2*n), func(b *testing.B) { + b.StopTimer() + prefixes := append(fixtureIPv4Prefixes(n), fixtureIPv6Prefixes(n)...) + b.StartTimer() + + for i := 0; i < b.N; i++ { + AggregatePrefixes(prefixes) + } + }) + } + + for _, n := range []int64{100, 1_000, 10_000} { + runIPv4Tests(b, n) + runIPv6Tests(b, n) + runMixedTests(b, n) + } +} diff --git a/pkg/provider/loadbalancer/iputil/prefix_tree.go b/pkg/provider/loadbalancer/iputil/prefix_tree.go new file mode 100644 index 0000000000..f8cb78c96d --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/prefix_tree.go @@ -0,0 +1,121 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iputil + +import "net/netip" + +type prefixTreeNode struct { + masked bool + prefix netip.Prefix + + l *prefixTreeNode + r *prefixTreeNode +} + +type prefixTree struct { + maxBits int + root *prefixTreeNode +} + +func newPrefixTreeForIPv4() *prefixTree { + return &prefixTree{ + maxBits: 32, + root: &prefixTreeNode{ + prefix: netip.MustParsePrefix("0.0.0.0/0"), + }, + } +} + +func newPrefixTreeForIPv6() *prefixTree { + return &prefixTree{ + maxBits: 128, + root: &prefixTreeNode{ + prefix: netip.MustParsePrefix("::/0"), + }, + } +} + +// Add adds a prefix to the tree. +func (t *prefixTree) Add(prefix netip.Prefix) { + var ( + n = t.root + bits = prefix.Addr().AsSlice() + ) + for i := 0; i < prefix.Bits(); i++ { + if n.masked { + break // It's already masked, the rest of the bits are irrelevant + } + + var bit = bits[i/8] >> (7 - i%8) & 1 + switch bit { + case 0: + if n.l == nil { + next, err := prefix.Addr().Prefix(i + 1) + if err != nil { + panic("unreachable: invalid prefix") + } + n.l = &prefixTreeNode{ + prefix: next, + } + } + n = n.l + case 1: + if n.r == nil { + next, err := prefix.Addr().Prefix(i + 1) + if err != nil { + panic("unreachable: invalid prefix") + } + n.r = &prefixTreeNode{ + prefix: next, + } + } + n = n.r + default: + panic("unreachable: unexpected bit") + } + } + + n.masked = true +} + +// List returns all prefixes in the tree. +// Overlapping prefixes are merged. +func (t *prefixTree) List() []netip.Prefix { + var ( + rv []netip.Prefix + q = []*prefixTreeNode{t.root} + ) + + for len(q) > 0 { + n := q[len(q)-1] + q = q[:len(q)-1] + + if n.masked { + rv = append(rv, n.prefix) + continue + } + + if n.l != nil { + q = append(q, n.l) + } + if n.r != nil { + q = append(q, n.r) + } + } + + return rv +} diff --git a/pkg/provider/loadbalancer/iputil/prefix_tree_test.go b/pkg/provider/loadbalancer/iputil/prefix_tree_test.go new file mode 100644 index 0000000000..5a21a98eff --- /dev/null +++ b/pkg/provider/loadbalancer/iputil/prefix_tree_test.go @@ -0,0 +1,217 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iputil + +import ( + "math" + "net/netip" + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPrefixTreeIPv4(t *testing.T) { + tests := []struct { + Name string + Input []string + Output []string + }{ + { + "Empty", + []string{}, + nil, + }, + { + "NoOverlap", + []string{ + "192.168.0.0/16", + "10.10.0.1/32", + }, + []string{ + "192.168.0.0/16", + "10.10.0.1/32", + }, + }, + { + "Overlap", + []string{ + "192.168.0.0/16", + "192.169.0.0/16", + "10.10.0.1/32", + + "192.168.1.0/24", + "192.168.1.1/32", + }, + []string{ + "192.168.0.0/16", + "192.169.0.0/16", + "10.10.0.1/32", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + var tree = newPrefixTreeForIPv4() + for _, ip := range tt.Input { + p := netip.MustParsePrefix(ip) + tree.Add(p) + } + + var got []string + for _, ip := range tree.List() { + got = append(got, ip.String()) + } + + sort.Strings(got) + sort.Strings(tt.Output) + + assert.Equal(t, tt.Output, got) + }) + } +} + +func TestPrefixTreeIPv6(t *testing.T) { + tests := []struct { + Name string + Input []string + Output []string + }{ + { + "Empty", + []string{}, + nil, + }, + { + "NoOverlap", + []string{ + "2001:db8:0:1::/64", + "2001:db8:0:2::/64", + "2001:db8:0:3::/64", + }, + []string{ + "2001:db8:0:1::/64", + "2001:db8:0:2::/64", + "2001:db8:0:3::/64", + }, + }, + { + "Overlap", + []string{ + "2001:db8::/32", + "2001:db8:0:1::/64", + "2001:db8:0:2::/64", + "2001:db8:0:3::/64", + }, + []string{ + "2001:db8::/32", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + var tree = newPrefixTreeForIPv6() + for _, ip := range tt.Input { + p := netip.MustParsePrefix(ip) + tree.Add(p) + } + + var got []string + for _, ip := range tree.List() { + got = append(got, ip.String()) + } + + sort.Strings(got) + sort.Strings(tt.Output) + + assert.Equal(t, tt.Output, got) + }) + } +} + +func BenchmarkPrefixTree_Add(b *testing.B) { + b.Run("IPv4", func(b *testing.B) { + var tree = newPrefixTreeForIPv4() + for i := 0; i < b.N; i++ { + addr := netip.AddrFrom4([4]byte{ + byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), + }) + prefix, _ := addr.Prefix(32) + + tree.Add(prefix) + } + }) + + b.Run("IPv6", func(b *testing.B) { + var tree = newPrefixTreeForIPv6() + for i := 0; i < b.N; i++ { + addr := netip.AddrFrom16([16]byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + byte(i >> 56), byte(i >> 48), byte(i >> 40), byte(i >> 32), + byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), + }) + prefix, _ := addr.Prefix(128) + + tree.Add(prefix) + } + }) +} + +func BenchmarkPrefixTree_List(b *testing.B) { + + b.Run("IPv4", func(b *testing.B) { + b.StopTimer() + var tree = newPrefixTreeForIPv4() + for i := 0; i < math.MaxInt8; i++ { + addr := netip.AddrFrom4([4]byte{ + byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), + }) + prefix, err := addr.Prefix(32) + assert.NoError(b, err) + + tree.Add(prefix) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + tree.List() + } + }) + + b.Run("IPv6", func(b *testing.B) { + b.StopTimer() + var tree = newPrefixTreeForIPv6() + for i := 0; i < math.MaxInt8; i++ { + addr := netip.AddrFrom16([16]byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + byte(i >> 56), byte(i >> 48), byte(i >> 40), byte(i >> 32), + byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i), + }) + prefix, err := addr.Prefix(128) + assert.NoError(b, err) + + tree.Add(prefix) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + tree.List() + } + }) +} diff --git a/tests/e2e/network/network_security_group.go b/tests/e2e/network/network_security_group.go index d5d940c6c5..acd8dbba02 100644 --- a/tests/e2e/network/network_security_group.go +++ b/tests/e2e/network/network_security_group.go @@ -238,6 +238,21 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( allowedIPv6Ranges = []string{ "2c0f:fe40:8000::/48", "2c0f:feb0::/43", } + + // The overlapping IP ranges will be aggregated after reconciled + overlappingIPv4Ranges = []string{ + "10.20.8.0/24", + "10.20.9.0/25", + "10.20.8.1/32", + "192.168.0.1/32", + } + overlappingIPv6Ranges = []string{ + "2c0f:fe40:8000::/49", + "2c0f:fe40:8000:1111::/64", + "2c0f:feb0::/43", + "2c0f:feb0::/44", + "2c0f:feb0::1/128", + } ) By("Creating a LoadBalancer service", func() { @@ -246,7 +261,13 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( "app": ServiceName, } annotations = map[string]string{ - v1.AnnotationLoadBalancerSourceRangesKey: strings.Join(append(allowedIPv4Ranges, allowedIPv6Ranges...), ","), + v1.AnnotationLoadBalancerSourceRangesKey: strings.Join( + append( + append(allowedIPv4Ranges, overlappingIPv4Ranges...), + append(allowedIPv6Ranges, overlappingIPv6Ranges...)..., + ), + ",", + ), } ports = []v1.ServicePort{{ Port: serverPort, @@ -619,6 +640,21 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( allowedIPv6Ranges = []string{ "2c0f:fe40:8000::/48", "2c0f:feb0::/43", } + + // The overlapping IP ranges will be aggregated after reconciled + overlappingIPv4Ranges = []string{ + "10.20.8.0/24", + "10.20.9.0/25", + "10.20.8.1/32", + "192.168.0.1/32", + } + overlappingIPv6Ranges = []string{ + "2c0f:fe40:8000::/49", + "2c0f:fe40:8000:1111::/64", + "2c0f:feb0::/43", + "2c0f:feb0::/44", + "2c0f:feb0::1/128", + } ) By("Creating a LoadBalancer service", func() { @@ -627,7 +663,13 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( "app": ServiceName, } annotations = map[string]string{ - consts.ServiceAnnotationAllowedIPRanges: strings.Join(append(allowedIPv4Ranges, allowedIPv6Ranges...), ","), + consts.ServiceAnnotationAllowedIPRanges: strings.Join( + append( + append(allowedIPv4Ranges, overlappingIPv4Ranges...), + append(allowedIPv6Ranges, overlappingIPv6Ranges...)..., + ), + ",", + ), } ports = []v1.ServicePort{{ Port: serverPort,