diff --git a/pkg/wekafs/apiclient/nfs.go b/pkg/wekafs/apiclient/nfs.go index ec2f9612f..b0617ee38 100644 --- a/pkg/wekafs/apiclient/nfs.go +++ b/pkg/wekafs/apiclient/nfs.go @@ -495,6 +495,15 @@ func (r *NfsClientGroupRule) EQ(other ApiObject) bool { return ObjectsAreEqual(r, other) } +func (r *NfsClientGroupRule) IsSupersetOf(other *NfsClientGroupRule) bool { + if r.IsIPRule() && other.IsIPRule() { + n1 := r.GetNetwork() + n2 := other.GetNetwork() + return n1.ContainsIPAddress(n2.IP.String()) + } + return false +} + func (r *NfsClientGroupRule) getImmutableFields() []string { return []string{"Rule"} } @@ -566,6 +575,9 @@ func (a *ApiClient) FindNfsClientGroupRulesByFilter(ctx context.Context, query * for _, r := range ret { if r.EQ(query) { *resultSet = append(*resultSet, r) + } else if r.IsSupersetOf(query) { + // if we have a rule that covers the IP address by bigger network segment, also add it + *resultSet = append(*resultSet, r) } } return nil diff --git a/pkg/wekafs/apiclient/nfs_test.go b/pkg/wekafs/apiclient/nfs_test.go index 8dd41e2ea..7b363f455 100644 --- a/pkg/wekafs/apiclient/nfs_test.go +++ b/pkg/wekafs/apiclient/nfs_test.go @@ -312,3 +312,33 @@ func TestInterfaceGroup(t *testing.T) { // } //} } + +func TestIsSupersetOf(t *testing.T) { + // Test case 1: IP rule superset + rule1 := &NfsClientGroupRule{ + Type: NfsClientGroupRuleTypeIP, + Rule: "192.168.1.0/24", + } + rule2 := &NfsClientGroupRule{ + Type: NfsClientGroupRuleTypeIP, + Rule: "192.168.1.1", + } + assert.True(t, rule1.IsSupersetOf(rule2)) + + // Test case 2: IP rule not superset + rule3 := &NfsClientGroupRule{ + Type: NfsClientGroupRuleTypeIP, + Rule: "192.168.2.0/24", + } + assert.False(t, rule1.IsSupersetOf(rule3)) + + // Test case 3: Non-IP rule + rule4 := &NfsClientGroupRule{ + Type: NfsClientGroupRuleTypeDNS, + Rule: "example.com", + } + assert.False(t, rule1.IsSupersetOf(rule4)) + + // Test case 4: Same rule + assert.True(t, rule1.IsSupersetOf(rule1)) +} diff --git a/pkg/wekafs/apiclient/utils.go b/pkg/wekafs/apiclient/utils.go index a96cedf0d..56d43376b 100644 --- a/pkg/wekafs/apiclient/utils.go +++ b/pkg/wekafs/apiclient/utils.go @@ -1,6 +1,7 @@ package apiclient import ( + "encoding/binary" "fmt" "github.com/rs/zerolog/log" "hash/fnv" @@ -73,6 +74,24 @@ func (n *Network) AsNfsRule() string { return fmt.Sprintf("%s/%s", n.IP.String(), n.Subnet.String()) } +func (n *Network) GetMaskBits() int { + ip := n.Subnet.To4() + if ip == nil { + return 0 + } + // Count the number of 1 bits + mask := binary.BigEndian.Uint32(ip) + + // Count the number of set bits + cidrBits := 0 + for mask != 0 { + cidrBits += int(mask & 1) + mask >>= 1 + } + + return cidrBits +} + func parseNetworkString(s string) (*Network, error) { var ip, subnet net.IP if strings.Contains(s, "/") { @@ -107,8 +126,11 @@ func (n *Network) ContainsIPAddress(ipStr string) bool { } _, ipNet, err := net.ParseCIDR(fmt.Sprintf("%s/%s", n.IP.String(), n.Subnet.String())) - if err != nil { - return false + if err != nil || ipNet == nil { + _, ipNet, err = net.ParseCIDR(fmt.Sprintf("%s/%d", n.IP.String(), n.GetMaskBits())) + if err != nil || ipNet == nil { + return false + } } return ipNet.Contains(ip) }