Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mimic ipset C code for determining correct default ipset revision for hash:ip{port,net,etc} #1031

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 97 additions & 16 deletions ipset_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,11 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_TYPENAME, nl.ZeroTerminated(typename)))

cadtFlags := optionsToBitflag(options)

revision := options.Revision
if revision == 0 {
revision = getIpsetDefaultWithTypeName(typename)
revision = getIpsetDefaultRevision(typename, cadtFlags)
}
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_REVISION, nl.Uint8Attr(revision)))

Expand Down Expand Up @@ -181,18 +183,6 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *timeout})
}

var cadtFlags uint32

if options.Comments {
cadtFlags |= nl.IPSET_FLAG_WITH_COMMENT
}
if options.Counters {
cadtFlags |= nl.IPSET_FLAG_WITH_COUNTERS
}
if options.Skbinfo {
cadtFlags |= nl.IPSET_FLAG_WITH_SKBINFO
}

if cadtFlags != 0 {
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER, Value: cadtFlags})
}
Expand Down Expand Up @@ -395,14 +385,89 @@ func (h *Handle) newIpsetRequest(cmd int) *nl.NetlinkRequest {
return req
}

func getIpsetDefaultWithTypeName(typename string) uint8 {
// NOTE: This can't just take typename into account, it also has to take desired
// feature support into account, on a per-set-type basis, to return the correct revision, see e.g.
// https://github.com/Olipro/ipset/blob/9f145b49100104d6570fe5c31a5236816ebb4f8f/kernel/net/netfilter/ipset/ip_set_hash_ipport.c#L30
//
// This means that whenever a new "type" of ipset is added, returning the "correct" default revision
// requires adding a new case here for that type, and consulting the ipset C code to figure out the correct
// combination of type name, feature bit flags, and revision ranges.
//
// Care should be taken as some types share the same revision ranges for the same features, and others do not.
// When in doubt, mimic the C code.
func getIpsetDefaultRevision(typename string, featureFlags uint32) uint8 {
switch typename {
case "hash:ip,port",
"hash:ip,port,ip",
"hash:ip,port,net",
"hash:ip,port,ip":
// Taken from
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipport.c
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipportip.c
if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 {
return 5
}

if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 {
return 4
}

if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 {
return 3
}

if (featureFlags & nl.IPSET_FLAG_WITH_COUNTERS) != 0 {
return 2
}

// the min revision this library supports for this type
return 1

case "hash:ip,port,net",
"hash:net,port":
// Taken from
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipportnet.c
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_netport.c
if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 {
return 7
}

if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 {
return 6
}

if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 {
return 5
}

if (featureFlags & nl.IPSET_FLAG_WITH_COUNTERS) != 0 {
return 4
}

if (featureFlags & nl.IPSET_FLAG_NOMATCH) != 0 {
return 3
}
// the min revision this library supports for this type
return 2

case "hash:ip":
// Taken from
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_ip.c
if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 {
return 4
}

if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 {
return 3
}

if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 {
return 2
}

// the min revision this library supports for this type
return 1
}

// can't map the correct revision for this type.
return 0
}

Expand Down Expand Up @@ -579,3 +644,19 @@ func parseIPSetEntry(data []byte) (entry IPSetEntry) {
}
return
}

func optionsToBitflag(options IpsetCreateOptions) uint32 {
var cadtFlags uint32

if options.Comments {
cadtFlags |= nl.IPSET_FLAG_WITH_COMMENT
}
if options.Counters {
cadtFlags |= nl.IPSET_FLAG_WITH_COUNTERS
}
if options.Skbinfo {
cadtFlags |= nl.IPSET_FLAG_WITH_SKBINFO
}

return cadtFlags
}
79 changes: 76 additions & 3 deletions ipset_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ package netlink

import (
"bytes"
"io/ioutil"
"net"
"os"
"testing"

"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
)

func TestParseIpsetProtocolResult(t *testing.T) {
msgBytes, err := ioutil.ReadFile("testdata/ipset_protocol_result")
msgBytes, err := os.ReadFile("testdata/ipset_protocol_result")
if err != nil {
t.Fatalf("reading test fixture failed: %v", err)
}
Expand All @@ -23,7 +23,7 @@ func TestParseIpsetProtocolResult(t *testing.T) {
}

func TestParseIpsetListResult(t *testing.T) {
msgBytes, err := ioutil.ReadFile("testdata/ipset_list_result")
msgBytes, err := os.ReadFile("testdata/ipset_list_result")
if err != nil {
t.Fatalf("reading test fixture failed: %v", err)
}
Expand Down Expand Up @@ -759,3 +759,76 @@ func TestIpsetMaxElements(t *testing.T) {
t.Fatalf("expected '%d' entry be created, got '%d'", maxElements, len(result.Entries))
}
}

func TestIpsetDefaultRevision(t *testing.T) {
testCases := []struct {
desc string
typename string
options IpsetCreateOptions
expectedRevision uint8
}{
{
desc: "Type-hash:ip,port",
typename: "hash:ip,port",
options: IpsetCreateOptions{
Counters: true,
Comments: true,
Skbinfo: false,
},
expectedRevision: 3,
},
{
desc: "Type-hash:ip,port_nocomment",
typename: "hash:ip,port",
options: IpsetCreateOptions{
Counters: true,
Comments: false,
Skbinfo: false,
},
expectedRevision: 2,
},
{
desc: "Type-hash:ip,port_skbinfo",
typename: "hash:ip,port",
options: IpsetCreateOptions{
Counters: true,
Comments: false,
Skbinfo: true,
},
expectedRevision: 5,
},
{
desc: "Type-hash:ip,port,net",
typename: "hash:ip,port,net",
options: IpsetCreateOptions{
Counters: true,
Comments: false,
Skbinfo: true,
},
expectedRevision: 7,
},
{
desc: "Type-hash:net,port_baseline_revision_no_opts",
typename: "hash:net,port",
options: IpsetCreateOptions{
Counters: false,
Comments: false,
Skbinfo: false,
},
expectedRevision: 2,
},
}

for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {

cadtFlags := optionsToBitflag(tC.options)

defRev := getIpsetDefaultRevision(tC.typename, cadtFlags)

if defRev != tC.expectedRevision {
t.Fatalf("expected default revision of '%d', got '%d'", tC.expectedRevision, defRev)
}
})
}
}
Loading