From 53f7f2f6d03188ed67fc4812e601e1d58410132c Mon Sep 17 00:00:00 2001 From: Francesco Cheinasso Date: Thu, 24 Aug 2023 10:51:18 +0200 Subject: [PATCH] Iptables mode selection --- iptables/iptables.go | 59 +++++++++++----- iptables/iptables_test.go | 138 +++++++++++++++++++++++++------------- 2 files changed, 134 insertions(+), 63 deletions(-) diff --git a/iptables/iptables.go b/iptables/iptables.go index e95929c..1b45c94 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -64,6 +64,18 @@ const ( ProtocolIPv6 ) +// Mode to differentiate between legacy and nf_tables +type ModeType string + +const ( + // ModeTypeAuto is the default mode, which uses the system default + ModeTypeAuto ModeType = "auto" + // ModeTypeLegacy forces the use of the legacy iptables mode + ModeTypeLegacy ModeType = "legacy" + // ModeTypeNFTables forces the use of the nf_tables iptables mode + ModeTypeNFTables ModeType = "nf_tables" +) + type IPTables struct { path string proto Protocol @@ -74,8 +86,8 @@ type IPTables struct { v1 int v2 int v3 int - mode string // the underlying iptables operating mode, e.g. nf_tables - timeout int // time to wait for the iptables lock, default waits forever + mode ModeType // the underlying iptables operating mode, e.g. nf_tables + timeout int // time to wait for the iptables lock, default waits forever } // Stat represents a structured statistic entry. @@ -106,6 +118,12 @@ func Timeout(timeout int) option { } } +func Mode(mode ModeType) option { + return func(ipt *IPTables) { + ipt.mode = mode + } +} + // New creates a new IPTables configured with the options passed as parameter. // For backwards compatibility, by default always uses IPv4 and timeout 0. // i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing @@ -116,6 +134,7 @@ func New(opts ...option) (*IPTables, error) { ipt := &IPTables{ proto: ProtocolIPv4, + mode: ModeTypeAuto, timeout: 0, } @@ -123,7 +142,7 @@ func New(opts ...option) (*IPTables, error) { opt(ipt) } - path, err := exec.LookPath(getIptablesCommand(ipt.proto)) + path, err := exec.LookPath(getIptablesCommand(ipt.proto, ipt.mode)) if err != nil { return nil, err } @@ -133,14 +152,13 @@ func New(opts ...option) (*IPTables, error) { if err != nil { return nil, fmt.Errorf("could not get iptables version: %v", err) } - v1, v2, v3, mode, err := extractIptablesVersion(vstring) + v1, v2, v3, _, err := extractIptablesVersion(vstring) if err != nil { return nil, fmt.Errorf("failed to extract iptables version from [%s]: %v", vstring, err) } ipt.v1 = v1 ipt.v2 = v2 ipt.v3 = v3 - ipt.mode = mode checkPresent, waitPresent, waitSupportSecond, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3) ipt.hasCheck = checkPresent @@ -518,8 +536,8 @@ func (ipt *IPTables) HasRandomFully() bool { } // Return version components of the underlying iptables command -func (ipt *IPTables) GetIptablesVersion() (int, int, int) { - return ipt.v1, ipt.v2, ipt.v3 +func (ipt *IPTables) GetIptablesVersion() (int, int, int, ModeType) { + return ipt.v1, ipt.v2, ipt.v3, ipt.mode } // run runs an iptables command with the given arguments, ignoring @@ -573,12 +591,23 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { } // getIptablesCommand returns the correct command for the given protocol, either "iptables" or "ip6tables". -func getIptablesCommand(proto Protocol) string { - if proto == ProtocolIPv6 { - return "ip6tables" - } else { - return "iptables" +func getIptablesCommand(proto Protocol, mode ModeType) string { + var cmd string + switch proto { + case ProtocolIPv4: + cmd = "iptables" + case ProtocolIPv6: + cmd = "ip6tables" + } + // Append a suffix to the command to get the correct binary, + // If the mode is auto (default), the suffix is not applied and the system default is used. + switch mode { + case ModeTypeNFTables: + cmd = fmt.Sprintf("%s-nftdsad", cmd) + case ModeTypeLegacy: + cmd = fmt.Sprintf("%s-legacy", cmd) } + return cmd } // Checks if iptables has the "-C" and "--wait" flag @@ -589,7 +618,7 @@ func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, bool) // getIptablesVersion returns the first three components of the iptables version // and the operating mode (e.g. nf_tables or legacy) // e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil) -func extractIptablesVersion(str string) (int, int, int, string, error) { +func extractIptablesVersion(str string) (int, int, int, ModeType, error) { versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`) result := versionMatcher.FindStringSubmatch(str) if result == nil { @@ -611,9 +640,9 @@ func extractIptablesVersion(str string) (int, int, int, string, error) { return 0, 0, 0, "", err } - mode := "legacy" + mode := ModeTypeLegacy if result[4] != "" { - mode = result[4] + mode = ModeType(result[4]) } return v1, v2, v3, mode, nil } diff --git a/iptables/iptables_test.go b/iptables/iptables_test.go index cc2de33..f18ae43 100644 --- a/iptables/iptables_test.go +++ b/iptables/iptables_test.go @@ -25,6 +25,23 @@ import ( "testing" ) +var ( + protos = []Protocol{ProtocolIPv4, ProtocolIPv6} + modes = []ModeType{ModeTypeAuto, ModeTypeLegacy, ModeTypeNFTables} +) + +// getProtoName returns the name of the protocol, for use in test names. +func getProtoName(proto Protocol) string { + switch proto { + case ProtocolIPv4: + return "IPv4" + case ProtocolIPv6: + return "IPv6" + default: + panic("unknown protocol") + } +} + func TestProto(t *testing.T) { ipt, err := New() if err != nil { @@ -34,40 +51,72 @@ func TestProto(t *testing.T) { t.Fatalf("Expected default protocol IPv4, got %v", ipt.Proto()) } - ip4t, err := NewWithProtocol(ProtocolIPv4) - if err != nil { - t.Fatalf("NewWithProtocol(ProtocolIPv4) failed: %v", err) - } - if ip4t.Proto() != ProtocolIPv4 { - t.Fatalf("Expected protocol IPv4, got %v", ip4t.Proto()) + for _, proto := range protos { + protoName := getProtoName(proto) + ipt, err := New(IPFamily(proto)) + if err != nil { + t.Fatalf("NewWithProtocol(%s) failed: %v", protoName, err) + } + if ipt.Proto() != proto { + t.Fatalf("Expected protocol %s, got %v", protoName, ipt.Proto()) + } + if ipt.mode != ModeTypeAuto { + t.Fatalf("Expected mode auto, got %v", ipt.mode) + } } - ip6t, err := NewWithProtocol(ProtocolIPv6) - if err != nil { - t.Fatalf("NewWithProtocol(ProtocolIPv6) failed: %v", err) - } - if ip6t.Proto() != ProtocolIPv6 { - t.Fatalf("Expected protocol IPv6, got %v", ip6t.Proto()) + for _, proto := range protos { + for _, mode := range modes { + protoName := getProtoName(proto) + ipt, err := New(Mode(mode), IPFamily(proto)) + if err != nil { + t.Fatalf("New(Mode(%v), IPFamily(%v)) failed: %v", mode, protoName, err) + } + if ipt.Proto() != proto { + t.Fatalf("Expected protocol %v, got %v", protoName, ipt.Proto()) + } + if ipt.mode != mode { + t.Fatalf("Expected mode %v, got %v", mode, ipt.mode) + } + } } } func TestTimeout(t *testing.T) { - ipt, err := New() - if err != nil { - t.Fatalf("New failed: %v", err) - } - if ipt.timeout != 0 { - t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout) - } + for _, proto := range protos { + for _, mode := range modes { + ipt, err := New(IPFamily(proto), Mode(mode)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + if ipt.timeout != 0 { + t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout) + } - ipt2, err := New(Timeout(5)) - if err != nil { - t.Fatalf("New failed: %v", err) - } - if ipt2.timeout != 5 { - t.Fatalf("Expected timeout 5, got %v", ipt.timeout) + ipt2, err := New(Timeout(5)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + if ipt2.timeout != 5 { + t.Fatalf("Expected timeout 5, got %v", ipt.timeout) + } + } } +} +func TestGetIptablesVersionMode(t *testing.T) { + for _, proto := range protos { + for _, mode := range modes { + ipt, err := New(IPFamily(proto), Mode(mode)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + _, _, _, getmode := ipt.GetIptablesVersion() + if getmode != mode { + t.Fatalf("Expected mode %v, got %v", mode, mode) + } + } + } } func randChain(t *testing.T) string { @@ -92,27 +141,20 @@ func contains(list []string, value string) bool { // features enabled & disabled, to test compatibility. // We used to test noWait as well, but that was removed as of iptables v1.6.0 func mustTestableIptables() []*IPTables { - ipt, err := New() - if err != nil { - panic(fmt.Sprintf("New failed: %v", err)) - } - ip6t, err := NewWithProtocol(ProtocolIPv6) - if err != nil { - panic(fmt.Sprintf("NewWithProtocol(ProtocolIPv6) failed: %v", err)) - } - ipts := []*IPTables{ipt, ip6t} - - // ensure we check one variant without built-in checking - if ipt.hasCheck { - i := *ipt - i.hasCheck = false - ipts = append(ipts, &i) - - i6 := *ip6t - i6.hasCheck = false - ipts = append(ipts, &i6) - } else { - panic("iptables on this machine is too old -- missing -C") + ipts := []*IPTables{} + for _, proto := range protos { + for _, mode := range modes { + ipt, err := New(IPFamily(proto), Mode(mode)) + if err != nil { + panic(fmt.Sprintf("New(IPFamily(%v), Mode(%v)) failed: %v", proto, mode, err)) + } + if ipt.hasCheck { + ipt.hasCheck = false + ipts = append(ipts, ipt) + } else { + panic("iptables on this machine is too old -- missing -C") + } + } } return ipts } @@ -251,7 +293,7 @@ func TestRules(t *testing.T) { } func runRulesTests(t *testing.T, ipt *IPTables) { - t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto()), ipt.hasWait, ipt.hasCheck) + t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto(), ModeTypeAuto), ipt.hasWait, ipt.hasCheck) var address1, address2, subnet1, subnet2 string if ipt.Proto() == ProtocolIPv6 { @@ -689,7 +731,7 @@ func TestExtractIptablesVersion(t *testing.T) { t.Fatalf("unexpected err %s", err) } - if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode { + if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != ModeType(tt.mode) { t.Fatalf("expected %d %d %d %s, got %d %d %d %s", tt.v1, tt.v2, tt.v3, tt.mode, v1, v2, v3, mode)