diff --git a/raw_socket_listener/listener.go b/raw_socket_listener/listener.go index f639254b..d225762b 100644 --- a/raw_socket_listener/listener.go +++ b/raw_socket_listener/listener.go @@ -254,6 +254,28 @@ func (e *DeviceNotFoundError) Error() string { return msg } +func isLoopback(device pcap.Interface) bool { + if len(device.Addresses) == 0 { + return false + } + + switch device.Addresses[0].IP.String() { + case "127.0.0.1", "::1": + return true + } + + return false +} + +func listenAllInterfaces(addr string) bool { + switch addr { + case "", "0.0.0.0", "[::]", "::": + return true + default: + return false + } +} + func findPcapDevices(addr string) (interfaces []pcap.Interface, err error) { devices, err := pcap.FindAllDevs() if err != nil { @@ -261,7 +283,7 @@ func findPcapDevices(addr string) (interfaces []pcap.Interface, err error) { } for _, device := range devices { - if (addr == "" || addr == "0.0.0.0" || addr == "[::]" || addr == "::") && len(device.Addresses) > 0 { + if listenAllInterfaces(addr) && len(device.Addresses) > 0 || isLoopback(device) { interfaces = append(interfaces, device) continue } @@ -309,12 +331,26 @@ func (t *Listener) readPcap() { t.pcapHandles = append(t.pcapHandles, handle) var bpfDstHost, bpfSrcHost string - for i, addr := range device.Addresses { - bpfDstHost += "dst host " + addr.IP.String() - bpfSrcHost += "src host " + addr.IP.String() - if i != len(device.Addresses)-1 { - bpfDstHost += " or " - bpfSrcHost += " or " + var loopback = isLoopback(device) + + if loopback { + var allAddr []string + for _, dc := range devices { + for _, addr := range dc.Addresses { + allAddr = append(allAddr, "(dst host " + addr.IP.String() + " and src host " + addr.IP.String() + ")") + } + } + + bpfDstHost = strings.Join(allAddr, " or ") + bpfSrcHost = bpfDstHost + } else { + for i, addr := range device.Addresses { + bpfDstHost += "dst host " + addr.IP.String() + bpfSrcHost += "src host " + addr.IP.String() + if i != len(device.Addresses)-1 { + bpfDstHost += " or " + bpfSrcHost += " or " + } } } @@ -439,10 +475,26 @@ func (t *Listener) readPcap() { } addrMatched := false - for _, a := range device.Addresses { - if a.IP.Equal(net.IP(addrCheck)) { - addrMatched = true - break + + if loopback { + for _, dc := range devices { + if addrMatched { + break + } + for _, a := range dc.Addresses { + if a.IP.Equal(net.IP(addrCheck)) { + addrMatched = true + break + } + } + } + addrMatched = true + } else { + for _, a := range device.Addresses { + if a.IP.Equal(net.IP(addrCheck)) { + addrMatched = true + break + } } }