diff --git a/nftables.go b/nftables.go index c8722f3d..a261b8e5 100644 --- a/nftables.go +++ b/nftables.go @@ -3,6 +3,7 @@ package main import ( + "fmt" "net" "strings" "time" @@ -10,22 +11,22 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/google/nftables" "github.com/google/nftables/expr" - "golang.org/x/sys/unix" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" ) const defaultTimeout = 4 * time.Hour type nft struct { - conn *nftables.Conn - conn6 *nftables.Conn - set *nftables.Set - set6 *nftables.Set - table *nftables.Table - table6 *nftables.Table - DenyAction string - DenyLog bool - DenyLogPrefix string + conn *nftables.Conn + conn6 *nftables.Conn + set *nftables.Set + set6 *nftables.Set + table *nftables.Table + table6 *nftables.Table + DenyAction string + DenyLog bool + DenyLogPrefix string BlacklistsIpv4 string BlacklistsIpv6 string } @@ -61,9 +62,11 @@ func (n *nft) Init() error { Priority: nftables.ChainPriorityFilter, }) set := &nftables.Set{ - Name: n.BlacklistsIpv4, - Table: n.table, - KeyType: nftables.TypeIPAddr, + Name: n.BlacklistsIpv4, + Table: n.table, + KeyType: nftables.TypeIPAddr, + HasTimeout: true, + Interval: true, } if err := n.conn.AddSet(set, []nftables.SetElement{}); err != nil { @@ -129,9 +132,11 @@ func (n *nft) Init() error { Priority: nftables.ChainPriorityFilter, }) set := &nftables.Set{ - Name: n.BlacklistsIpv6, - Table: n.table6, - KeyType: nftables.TypeIP6Addr, + Name: n.BlacklistsIpv6, + Table: n.table6, + KeyType: nftables.TypeIP6Addr, + HasTimeout: true, + Interval: true, } if err := n.conn6.AddSet(set, []nftables.SetElement{}); err != nil { @@ -192,9 +197,27 @@ func (n *nft) Add(decision *models.Decision) error { log.Errorf("unable to parse timeout '%s' for '%s' : %s", *decision.Duration, *decision.Value, err) timeout = defaultTimeout } + var cidr string if strings.Contains(*decision.Value, ":") { // ipv6 if n.conn6 != nil { - if err := n.conn.SetAddElements(n.set6, []nftables.SetElement{{Key: []byte(net.ParseIP(*decision.Value).To16()), Timeout: timeout}}); err != nil { + if !strings.Contains(*decision.Value, "/") { + cidr = fmt.Sprintf("%s/128", *decision.Value) + } else { + cidr = *decision.Value + } + _, cidrNet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + bca, err := BroadcastAddr(cidrNet) + if err != nil { + return err + } + if err := n.conn6.SetAddElements(n.set6, + []nftables.SetElement{ + {Key: []byte(cidrNet.IP.To16()), Timeout: timeout}, + {Key: []byte(incrementIP(bca).To16()), IntervalEnd: true}, + }); err != nil { return err } if err := n.conn6.Flush(); err != nil { @@ -205,13 +228,24 @@ func (n *nft) Add(decision *models.Decision) error { return nil } } else { // ipv4 - var ipAddr string - if strings.Contains(*decision.Value, "/") { - ipAddr = strings.Split(*decision.Value, "/")[0] + if !strings.Contains(*decision.Value, "/") { + cidr = fmt.Sprintf("%s/32", *decision.Value) } else { - ipAddr = *decision.Value + cidr = *decision.Value + } + _, cidrNet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + bca, err := BroadcastAddr(cidrNet) + if err != nil { + return err } - if err := n.conn.SetAddElements(n.set, []nftables.SetElement{{Key: []byte(net.ParseIP(ipAddr).To4())}}); err != nil { + if err := n.conn.SetAddElements(n.set, + []nftables.SetElement{ + {Key: cidrNet.IP, Timeout: timeout}, + {Key: incrementIP(bca), IntervalEnd: true}, + }); err != nil { return err } if err := n.conn.Flush(); err != nil { @@ -223,26 +257,57 @@ func (n *nft) Add(decision *models.Decision) error { } func (n *nft) Delete(decision *models.Decision) error { + var cidr string if strings.Contains(*decision.Value, ":") { // ipv6 if n.conn6 != nil { - if err := n.conn.SetDeleteElements(n.set, []nftables.SetElement{{Key: net.ParseIP(*decision.Value).To16()}}); err != nil { + if !strings.Contains(*decision.Value, "/") { + cidr = fmt.Sprintf("%s/128", *decision.Value) + } else { + cidr = *decision.Value + } + _, cidrNet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + bca, err := BroadcastAddr(cidrNet) + if err != nil { + return err + } + if err := n.conn6.SetDeleteElements(n.set6, + []nftables.SetElement{ + {Key: []byte(cidrNet.IP.To16())}, + {Key: []byte(incrementIP(bca).To16()), IntervalEnd: true}, + }); err != nil { return err } - if err := n.conn.Flush(); err != nil { + if err := n.conn6.Flush(); err != nil { return err } + } else { - log.Debugf("not adding '%s' because ipv6 is disabled", *decision.Value) + log.Debugf("not removing '%s' because ipv6 is disabled", *decision.Value) return nil } } else { // ipv4 - var ipAddr string - if strings.Contains(*decision.Value, "/") { - ipAddr = strings.Split(*decision.Value, "/")[0] + var cidr string + if !strings.Contains(*decision.Value, "/") { + cidr = fmt.Sprintf("%s/32", *decision.Value) } else { - ipAddr = *decision.Value + cidr = *decision.Value } - if err := n.conn.SetDeleteElements(n.set, []nftables.SetElement{{Key: net.ParseIP(ipAddr).To4()}}); err != nil { + _, cidrNet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + bca, err := BroadcastAddr(cidrNet) + if err != nil { + return err + } + if err := n.conn.SetDeleteElements(n.set, + []nftables.SetElement{ + {Key: cidrNet.IP}, + {Key: incrementIP(bca), IntervalEnd: true}, + }); err != nil { return err } if err := n.conn.Flush(); err != nil { @@ -269,3 +334,47 @@ func (n *nft) ShutDown() error { } return nil } + +// Utilites from https://github.com/IBM/netaddr/blob/master/net_utils.go + +// NewIP returns a new IP with the given size. The size must be 4 for IPv4 and +// 16 for IPv6. +func NewIP(size int) (net.IP, error) { + if size == 4 { + return net.ParseIP("0.0.0.0").To4(), nil + } + if size == 16 { + return net.ParseIP("::"), nil + } + return net.IP{}, fmt.Errorf("invalid size %d", size) +} + +// BroadcastAddr returns the last address in the given network, or the broadcast address. +func BroadcastAddr(n *net.IPNet) (net.IP, error) { + // The golang net package doesn't make it easy to calculate the broadcast address. :( + broadcast, err := NewIP(len(n.IP)) + if err != nil { + return net.IP{}, err + } + for i := 0; i < len(n.IP); i++ { + broadcast[i] = n.IP[i] | ^n.Mask[i] + } + return broadcast, nil +} + +// incrementIP returns the given IP + 1 +func incrementIP(ip net.IP) (result net.IP) { + result = make([]byte, len(ip)) // start off with a nice empty ip of proper length + + carry := true + for i := len(ip) - 1; i >= 0; i-- { + result[i] = ip[i] + if carry { + result[i]++ + if result[i] != 0 { + carry = false + } + } + } + return +}