Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iptmodes #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ const (
ProtocolIPv6
)

// Mode to differentiate between legacy and nf_tables
type ModeType string

const (
// ModeTypeAuto is the default mode, which uses the system default
ModeTypeAuto ModeType = "auto"
// ModeTypeLegacy forces the use of the legacy iptables mode
ModeTypeLegacy ModeType = "legacy"
// ModeTypeNFTables forces the use of the nf_tables iptables mode
ModeTypeNFTables ModeType = "nf_tables"
)

type IPTables struct {
path string
proto Protocol
Expand All @@ -74,8 +86,8 @@ type IPTables struct {
v1 int
v2 int
v3 int
mode string // the underlying iptables operating mode, e.g. nf_tables
timeout int // time to wait for the iptables lock, default waits forever
mode ModeType // the underlying iptables operating mode, e.g. nf_tables
timeout int // time to wait for the iptables lock, default waits forever
}

// Stat represents a structured statistic entry.
Expand Down Expand Up @@ -106,6 +118,12 @@ func Timeout(timeout int) option {
}
}

func Mode(mode ModeType) option {
return func(ipt *IPTables) {
ipt.mode = mode
}
}

// New creates a new IPTables configured with the options passed as parameter.
// For backwards compatibility, by default always uses IPv4 and timeout 0.
// i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing
Expand All @@ -116,14 +134,15 @@ func New(opts ...option) (*IPTables, error) {

ipt := &IPTables{
proto: ProtocolIPv4,
mode: ModeTypeAuto,
timeout: 0,
}

for _, opt := range opts {
opt(ipt)
}

path, err := exec.LookPath(getIptablesCommand(ipt.proto))
path, err := exec.LookPath(getIptablesCommand(ipt.proto, ipt.mode))
if err != nil {
return nil, err
}
Expand All @@ -133,14 +152,13 @@ func New(opts ...option) (*IPTables, error) {
if err != nil {
return nil, fmt.Errorf("could not get iptables version: %v", err)
}
v1, v2, v3, mode, err := extractIptablesVersion(vstring)
v1, v2, v3, _, err := extractIptablesVersion(vstring)
if err != nil {
return nil, fmt.Errorf("failed to extract iptables version from [%s]: %v", vstring, err)
}
ipt.v1 = v1
ipt.v2 = v2
ipt.v3 = v3
ipt.mode = mode

checkPresent, waitPresent, waitSupportSecond, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
ipt.hasCheck = checkPresent
Expand Down Expand Up @@ -518,8 +536,8 @@ func (ipt *IPTables) HasRandomFully() bool {
}

// Return version components of the underlying iptables command
func (ipt *IPTables) GetIptablesVersion() (int, int, int) {
return ipt.v1, ipt.v2, ipt.v3
func (ipt *IPTables) GetIptablesVersion() (int, int, int, ModeType) {
return ipt.v1, ipt.v2, ipt.v3, ipt.mode
}

// run runs an iptables command with the given arguments, ignoring
Expand Down Expand Up @@ -573,12 +591,23 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
}

// getIptablesCommand returns the correct command for the given protocol, either "iptables" or "ip6tables".
func getIptablesCommand(proto Protocol) string {
if proto == ProtocolIPv6 {
return "ip6tables"
} else {
return "iptables"
func getIptablesCommand(proto Protocol, mode ModeType) string {
var cmd string
switch proto {
case ProtocolIPv4:
cmd = "iptables"
case ProtocolIPv6:
cmd = "ip6tables"
}
// Append a suffix to the command to get the correct binary,
// If the mode is auto (default), the suffix is not applied and the system default is used.
switch mode {
case ModeTypeNFTables:
cmd = fmt.Sprintf("%s-nft", cmd)
case ModeTypeLegacy:
cmd = fmt.Sprintf("%s-legacy", cmd)
}
return cmd
}

// Checks if iptables has the "-C" and "--wait" flag
Expand All @@ -589,7 +618,7 @@ func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, bool)
// getIptablesVersion returns the first three components of the iptables version
// and the operating mode (e.g. nf_tables or legacy)
// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
func extractIptablesVersion(str string) (int, int, int, string, error) {
func extractIptablesVersion(str string) (int, int, int, ModeType, error) {
versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
result := versionMatcher.FindStringSubmatch(str)
if result == nil {
Expand All @@ -611,9 +640,9 @@ func extractIptablesVersion(str string) (int, int, int, string, error) {
return 0, 0, 0, "", err
}

mode := "legacy"
mode := ModeTypeLegacy
if result[4] != "" {
mode = result[4]
mode = ModeType(result[4])
}
return v1, v2, v3, mode, nil
}
Expand Down
138 changes: 90 additions & 48 deletions iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,23 @@ import (
"testing"
)

var (
protos = []Protocol{ProtocolIPv4, ProtocolIPv6}
modes = []ModeType{ModeTypeAuto, ModeTypeLegacy, ModeTypeNFTables}
)

// getProtoName returns the name of the protocol, for use in test names.
func getProtoName(proto Protocol) string {
switch proto {
case ProtocolIPv4:
return "IPv4"
case ProtocolIPv6:
return "IPv6"
default:
panic("unknown protocol")
}
}

func TestProto(t *testing.T) {
ipt, err := New()
if err != nil {
Expand All @@ -34,40 +51,72 @@ func TestProto(t *testing.T) {
t.Fatalf("Expected default protocol IPv4, got %v", ipt.Proto())
}

ip4t, err := NewWithProtocol(ProtocolIPv4)
if err != nil {
t.Fatalf("NewWithProtocol(ProtocolIPv4) failed: %v", err)
}
if ip4t.Proto() != ProtocolIPv4 {
t.Fatalf("Expected protocol IPv4, got %v", ip4t.Proto())
for _, proto := range protos {
protoName := getProtoName(proto)
ipt, err := New(IPFamily(proto))
if err != nil {
t.Fatalf("NewWithProtocol(%s) failed: %v", protoName, err)
}
if ipt.Proto() != proto {
t.Fatalf("Expected protocol %s, got %v", protoName, ipt.Proto())
}
if ipt.mode != ModeTypeAuto {
t.Fatalf("Expected mode auto, got %v", ipt.mode)
}
}

ip6t, err := NewWithProtocol(ProtocolIPv6)
if err != nil {
t.Fatalf("NewWithProtocol(ProtocolIPv6) failed: %v", err)
}
if ip6t.Proto() != ProtocolIPv6 {
t.Fatalf("Expected protocol IPv6, got %v", ip6t.Proto())
for _, proto := range protos {
for _, mode := range modes {
protoName := getProtoName(proto)
ipt, err := New(Mode(mode), IPFamily(proto))
if err != nil {
t.Fatalf("New(Mode(%v), IPFamily(%v)) failed: %v", mode, protoName, err)
}
if ipt.Proto() != proto {
t.Fatalf("Expected protocol %v, got %v", protoName, ipt.Proto())
}
if ipt.mode != mode {
t.Fatalf("Expected mode %v, got %v", mode, ipt.mode)
}
}
}
}

func TestTimeout(t *testing.T) {
ipt, err := New()
if err != nil {
t.Fatalf("New failed: %v", err)
}
if ipt.timeout != 0 {
t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout)
}
for _, proto := range protos {
for _, mode := range modes {
ipt, err := New(IPFamily(proto), Mode(mode))
if err != nil {
t.Fatalf("New failed: %v", err)
}
if ipt.timeout != 0 {
t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout)
}

ipt2, err := New(Timeout(5))
if err != nil {
t.Fatalf("New failed: %v", err)
}
if ipt2.timeout != 5 {
t.Fatalf("Expected timeout 5, got %v", ipt.timeout)
ipt2, err := New(Timeout(5))
if err != nil {
t.Fatalf("New failed: %v", err)
}
if ipt2.timeout != 5 {
t.Fatalf("Expected timeout 5, got %v", ipt.timeout)
}
}
}
}

func TestGetIptablesVersionMode(t *testing.T) {
for _, proto := range protos {
for _, mode := range modes {
ipt, err := New(IPFamily(proto), Mode(mode))
if err != nil {
t.Fatalf("New failed: %v", err)
}
_, _, _, getmode := ipt.GetIptablesVersion()
if getmode != mode {
t.Fatalf("Expected mode %v, got %v", mode, mode)
}
}
}
}

func randChain(t *testing.T) string {
Expand All @@ -92,27 +141,20 @@ func contains(list []string, value string) bool {
// features enabled & disabled, to test compatibility.
// We used to test noWait as well, but that was removed as of iptables v1.6.0
func mustTestableIptables() []*IPTables {
ipt, err := New()
if err != nil {
panic(fmt.Sprintf("New failed: %v", err))
}
ip6t, err := NewWithProtocol(ProtocolIPv6)
if err != nil {
panic(fmt.Sprintf("NewWithProtocol(ProtocolIPv6) failed: %v", err))
}
ipts := []*IPTables{ipt, ip6t}

// ensure we check one variant without built-in checking
if ipt.hasCheck {
i := *ipt
i.hasCheck = false
ipts = append(ipts, &i)

i6 := *ip6t
i6.hasCheck = false
ipts = append(ipts, &i6)
} else {
panic("iptables on this machine is too old -- missing -C")
ipts := []*IPTables{}
for _, proto := range protos {
for _, mode := range modes {
ipt, err := New(IPFamily(proto), Mode(mode))
if err != nil {
panic(fmt.Sprintf("New(IPFamily(%v), Mode(%v)) failed: %v", proto, mode, err))
}
if ipt.hasCheck {
ipt.hasCheck = false
ipts = append(ipts, ipt)
} else {
panic("iptables on this machine is too old -- missing -C")
}
}
}
return ipts
}
Expand Down Expand Up @@ -251,7 +293,7 @@ func TestRules(t *testing.T) {
}

func runRulesTests(t *testing.T, ipt *IPTables) {
t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto()), ipt.hasWait, ipt.hasCheck)
t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto(), ModeTypeAuto), ipt.hasWait, ipt.hasCheck)

var address1, address2, subnet1, subnet2 string
if ipt.Proto() == ProtocolIPv6 {
Expand Down Expand Up @@ -689,7 +731,7 @@ func TestExtractIptablesVersion(t *testing.T) {
t.Fatalf("unexpected err %s", err)
}

if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != ModeType(tt.mode) {
t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
tt.v1, tt.v2, tt.v3, tt.mode,
v1, v2, v3, mode)
Expand Down