Skip to content

Commit

Permalink
[client] Add stateful userspace firewall and remove egress filters (#…
Browse files Browse the repository at this point in the history
…3093)

- Add stateful firewall functionality for UDP/TCP/ICMP in userspace firewalll
- Removes all egress drop rules/filters, still needs refactoring so we don't add output rules to any chains/filters.
- on Linux, if the OUTPUT policy is DROP  then we don't do anything about it (no extra allow rules). This is up to the user, if they don't want anything leaving their machine they'll have to manage these rules explicitly.
  • Loading branch information
lixmal authored Dec 23, 2024
1 parent 05930ee commit ad9f044
Show file tree
Hide file tree
Showing 16 changed files with 3,099 additions and 115 deletions.
6 changes: 0 additions & 6 deletions client/firewall/iptables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error {

// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {

established := getConntrackEstablished()

m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))

m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))

m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
Expand Down
14 changes: 2 additions & 12 deletions client/firewall/iptables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,9 @@ func (m *Manager) AllowNetbird() error {
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
return fmt.Errorf("allow netbird interface traffic: %w", err)
}
_, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
return nil
}

// Flush doesn't need to be implemented for this manager
Expand Down
53 changes: 0 additions & 53 deletions client/firewall/nftables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"time"
Expand All @@ -28,7 +27,6 @@ const (

// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"

Expand Down Expand Up @@ -441,18 +439,6 @@ func (m *AclManager) createDefaultChains() (err error) {
return err
}

// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
return err
}

// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
Expand Down Expand Up @@ -619,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil
}

func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
dstOp := expr.CmpOpNeq
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}

func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
Expand Down
20 changes: 19 additions & 1 deletion client/firewall/uspfilter/allow_netbird.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

package uspfilter

import "github.com/netbirdio/netbird/client/internal/statemanager"
import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
Expand All @@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}

if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
}
Expand Down
16 changes: 16 additions & 0 deletions client/firewall/uspfilter/allow_netbird_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

Expand All @@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}

if !isWindowsFirewallReachable() {
return nil
}
Expand Down
138 changes: 138 additions & 0 deletions client/firewall/uspfilter/conntrack/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// common.go
package conntrack

import (
"net"
"sync"
"sync/atomic"
"time"
)

// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
sync.RWMutex
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
}

// these small methods will be inlined by the compiler

// UpdateLastSeen safely updates the last seen timestamp
func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}

// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}

// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
}

// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
}

// timeoutExceeded checks if the connection has exceeded the given timeout
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
lastSeen := time.Unix(0, b.lastSeen.Load())
return time.Since(lastSeen) > timeout
}

// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte

// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}

// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP IPAddr
DstIP IPAddr
SrcPort uint16
DstPort uint16
}

// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}

// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}

// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}

// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}

// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}

// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}

// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
}
Loading

0 comments on commit ad9f044

Please sign in to comment.