Skip to content

Commit

Permalink
Aggregate overlapping allowed source ranges
Browse files Browse the repository at this point in the history
  • Loading branch information
zarvd committed Aug 15, 2024
1 parent d3f83c6 commit f875524
Show file tree
Hide file tree
Showing 6 changed files with 589 additions and 6 deletions.
16 changes: 12 additions & 4 deletions pkg/provider/azure_loadbalancer_accesscontrol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) {
Build(),

azureFx.
AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{"0.0.0.0/0", "8.8.8.8/32"}, k8sFx.Service().TCPPorts()).
AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{"0.0.0.0/0"}, k8sFx.Service().TCPPorts()).
WithPriority(501).
WithDestination(azureFx.LoadBalancer().IPv4Addresses()...).
Build(),
Expand All @@ -236,7 +236,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) {
Build(),

azureFx.
AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv4, []string{"0.0.0.0/0", "8.8.8.8/32"}, k8sFx.Service().UDPPorts()).
AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv4, []string{"0.0.0.0/0"}, k8sFx.Service().UDPPorts()).
WithPriority(504).
WithDestination(azureFx.LoadBalancer().IPv4Addresses()...).
Build(),
Expand Down Expand Up @@ -517,7 +517,11 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) {
allowedIPv6Ranges = []string{"2607:f0d0:1002:51::/64", "fd00::/8"}
)

svc.Annotations[consts.ServiceAnnotationAllowedIPRanges] = strings.Join(append(allowedIPv4Ranges, allowedIPv6Ranges...), ",")
{
ipRanges := append(allowedIPv4Ranges, allowedIPv6Ranges...)
ipRanges = append(ipRanges, "172.30.0.1/32", "2607:f0d0:1002:51::1/128") // with overlapping CIDRs
svc.Annotations[consts.ServiceAnnotationAllowedIPRanges] = strings.Join(ipRanges, ",")
}

securityGroupClient.EXPECT().
Get(gomock.Any(), az.ResourceGroup, az.SecurityGroupName, gomock.Any()).
Expand Down Expand Up @@ -695,7 +699,11 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) {
allowedIPv6Ranges = []string{"2607:f0d0:1002:51::/64", "fd00::/8"}
)

svc.Spec.LoadBalancerSourceRanges = append(allowedIPv4Ranges, allowedIPv6Ranges...)
{
ipRanges := append(allowedIPv4Ranges, allowedIPv6Ranges...)
ipRanges = append(ipRanges, "172.30.0.1/32", "2607:f0d0:1002:51::1/128") // with overlapping CIDRs
svc.Spec.LoadBalancerSourceRanges = ipRanges
}

securityGroupClient.EXPECT().
Get(gomock.Any(), az.ResourceGroup, az.SecurityGroupName, gomock.Any()).
Expand Down
13 changes: 11 additions & 2 deletions pkg/provider/loadbalancer/accesscontrol.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,23 @@ func (ac *AccessControl) PatchSecurityGroup(dstIPv4Addresses, dstIPv6Addresses [
logger := ac.logger.WithName("PatchSecurityGroup")

var (
allowedIPv4Ranges = ac.AllowedIPv4Ranges()
allowedIPv6Ranges = ac.AllowedIPv6Ranges()
allowedIPRanges = append(ac.AllowedIPv4Ranges(), ac.AllowedIPv6Ranges()...)
allowedServiceTags = ac.AllowedServiceTags
)
if ac.IsAllowFromInternet() {
allowedServiceTags = append(allowedServiceTags, securitygroup.ServiceTagInternet)
}

{
// Aggregate allowed IP ranges.
ipRanges := iputil.AggregatePrefixes(allowedIPRanges)
if len(ipRanges) != len(allowedIPRanges) {
logger.Info("Overlapping IP ranges detected", "allowed-ip-ranges", allowedIPRanges, "aggregated-ip-ranges", ipRanges)
}
allowedIPRanges = ipRanges
}
var allowedIPv4Ranges, allowedIPv6Ranges = iputil.GroupPrefixesByFamily(allowedIPRanges)

logger.V(10).Info("Start patching",
"num-allowed-ipv4-ranges", len(allowedIPv4Ranges),
"num-allowed-ipv6-ranges", len(allowedIPv6Ranges),
Expand Down
38 changes: 38 additions & 0 deletions pkg/provider/loadbalancer/iputil/prefix.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"net/netip"
)

// IsPrefixesAllowAll returns true if one of the prefixes allows all addresses.
// FIXME: it should return true if the aggregated prefix allows all addresses. Now it only checks one by one.
func IsPrefixesAllowAll(prefixes []netip.Prefix) bool {
for _, p := range prefixes {
if p.Bits() == 0 {
Expand All @@ -30,6 +32,7 @@ func IsPrefixesAllowAll(prefixes []netip.Prefix) bool {
return false
}

// ParsePrefix parses a CIDR string and returns a Prefix.
func ParsePrefix(v string) (netip.Prefix, error) {
prefix, err := netip.ParsePrefix(v)
if err != nil {
Expand All @@ -41,3 +44,38 @@ func ParsePrefix(v string) (netip.Prefix, error) {
}
return prefix, nil
}

// GroupPrefixesByFamily groups prefixes by IP family.
func GroupPrefixesByFamily(vs []netip.Prefix) ([]netip.Prefix, []netip.Prefix) {
var (
v4 []netip.Prefix
v6 []netip.Prefix
)
for _, v := range vs {
if v.Addr().Is4() {
v4 = append(v4, v)
} else {
v6 = append(v6, v)
}
}
return v4, v6
}

// AggregatePrefixes aggregates prefixes.
// Overlapping prefixes are merged.
func AggregatePrefixes(prefixes []netip.Prefix) []netip.Prefix {
var (
v4, v6 = GroupPrefixesByFamily(prefixes)
v4Tree = newPrefixTreeForIPv4()
v6Tree = newPrefixTreeForIPv6()
)

for _, p := range v4 {
v4Tree.Add(p)
}
for _, p := range v6 {
v6Tree.Add(p)
}

return append(v4Tree.List(), v6Tree.List()...)
}
190 changes: 190 additions & 0 deletions pkg/provider/loadbalancer/iputil/prefix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ limitations under the License.
package iputil

import (
"fmt"
"net/netip"
"sort"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -114,3 +116,191 @@ func TestParsePrefix(t *testing.T) {
}
})
}

func TestGroupPrefixesByFamily(t *testing.T) {
tests := []struct {
Name string
Input []netip.Prefix
IPv4 []netip.Prefix
IPv6 []netip.Prefix
}{
{
Name: "Empty",
Input: []netip.Prefix{},
},
{
Name: "IPv4",
Input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.1/32"),
netip.MustParsePrefix("10.0.0.0/8"),
},
IPv4: []netip.Prefix{
netip.MustParsePrefix("192.168.0.1/32"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
Name: "IPv6",
Input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::1/128"),
netip.MustParsePrefix("::/0"),
},
IPv6: []netip.Prefix{
netip.MustParsePrefix("2001:db8::1/128"),
netip.MustParsePrefix("::/0"),
},
},
{
Name: "Mixed",
Input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.1/32"),
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("2001:db8::1/128"),
netip.MustParsePrefix("::/0"),
},
IPv4: []netip.Prefix{
netip.MustParsePrefix("192.168.0.1/32"),
netip.MustParsePrefix("10.0.0.0/8"),
},
IPv6: []netip.Prefix{
netip.MustParsePrefix("2001:db8::1/128"),
netip.MustParsePrefix("::/0"),
},
},
}

for _, tt := range tests {
t.Run(tt.Name, func(t *testing.T) {
ipv4, ipv6 := GroupPrefixesByFamily(tt.Input)
assert.Equal(t, tt.IPv4, ipv4)
assert.Equal(t, tt.IPv6, ipv6)
})
}
}

func TestAggregatePrefixes(t *testing.T) {
tests := []struct {
Name string
Input []netip.Prefix
Output []netip.Prefix
}{
{
Name: "Empty",
Input: []netip.Prefix{},
},
{
Name: "NoOverlap IPv4",
Input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("10.10.0.1/32"),
},
Output: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("10.10.0.1/32"),
},
},
{
Name: "Overlap IPv4",
Input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.169.0.0/16"),
netip.MustParsePrefix("10.10.0.1/32"),

netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
},
Output: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.169.0.0/16"),
netip.MustParsePrefix("10.10.0.1/32"),
},
},
}

for _, tt := range tests {
t.Run(tt.Name, func(t *testing.T) {
var got = AggregatePrefixes(tt.Input)
sort.Slice(got, func(i, j int) bool {
return got[i].String() < got[j].String()
})
sort.Slice(tt.Output, func(i, j int) bool {
return tt.Output[i].String() < tt.Output[j].String()
})
assert.Equal(t, tt.Output, got)
})
}
}

func BenchmarkAggregatePrefixes(b *testing.B) {
fixtureIPv4Prefixes := func(n int64) []netip.Prefix {
prefixes := make([]netip.Prefix, 0, n)
for i := int64(0); i < n; i++ {
addr := netip.AddrFrom4([4]byte{
byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i),
})
prefix, err := addr.Prefix(32)
assert.NoError(b, err)
prefixes = append(prefixes, prefix)
}

return prefixes
}

fixtureIPv6Prefixes := func(n int64) []netip.Prefix {
prefixes := make([]netip.Prefix, 0, n)
for i := int64(0); i < n; i++ {
addr := netip.AddrFrom16([16]byte{
0, 0, 0, 0,
0, 0, 0, 0,
byte(i >> 56), byte(i >> 48), byte(i >> 40), byte(i >> 32),
byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i),
})
prefix, err := addr.Prefix(128)
assert.NoError(b, err)
prefixes = append(prefixes, prefix)
}
return prefixes
}

runIPv4Tests := func(b *testing.B, n int64) {
b.Run(fmt.Sprintf("IPv4-%d", n), func(b *testing.B) {
b.StopTimer()
prefixes := fixtureIPv4Prefixes(n)
b.StartTimer()

for i := 0; i < b.N; i++ {
AggregatePrefixes(prefixes)
}
})
}

runIPv6Tests := func(b *testing.B, n int64) {
b.Run(fmt.Sprintf("IPv6-%d", n), func(b *testing.B) {
b.StopTimer()
prefixes := fixtureIPv4Prefixes(n)
b.StartTimer()

for i := 0; i < b.N; i++ {
AggregatePrefixes(prefixes)
}
})
}

runMixedTests := func(b *testing.B, n int64) {
b.Run(fmt.Sprintf("IPv4-IPv6-%d", 2*n), func(b *testing.B) {
b.StopTimer()
prefixes := append(fixtureIPv4Prefixes(n), fixtureIPv6Prefixes(n)...)
b.StartTimer()

for i := 0; i < b.N; i++ {
AggregatePrefixes(prefixes)
}
})
}

for _, n := range []int64{100, 1_000, 10_000} {
runIPv4Tests(b, n)
runIPv6Tests(b, n)
runMixedTests(b, n)
}
}
Loading

0 comments on commit f875524

Please sign in to comment.