Skip to content

Commit

Permalink
[ebpfless] Cancel packet capture loop via Close() (#33056)
Browse files Browse the repository at this point in the history
  • Loading branch information
pimlu authored Jan 23, 2025
1 parent 5d32b26 commit f8867cf
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 41 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ replace github.com/ProtonMail/go-crypto => github.com/ProtonMail/go-crypto v1.0.
// Prevent a false-positive detection by the Google and Ikarus security vendors on VirusTotal
exclude go.opentelemetry.io/proto/otlp v1.1.0

replace github.com/google/gopacket v1.1.19 => github.com/DataDog/gopacket v0.0.0-20240626205202-4ac4cee31f14
replace github.com/google/gopacket v1.1.19 => github.com/DataDog/gopacket v0.0.0-20250121143817-e1e3480abefb

// Remove once https://github.com/kubernetes/kube-state-metrics/pull/2553 is merged
replace k8s.io/kube-state-metrics/v2 v2.13.1-0.20241025121156-110f03d7331f => github.com/L3n41c/kube-state-metrics/v2 v2.13.1-0.20241119155242-07761b9fe9a0
Expand Down
4 changes: 2 additions & 2 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 23 additions & 10 deletions pkg/network/dns/packet_source_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package dns

import (
"sync"
"time"

"github.com/google/gopacket"
Expand All @@ -20,7 +21,9 @@ import (
var _ filter.PacketSource = &windowsPacketSource{}

type windowsPacketSource struct {
di *dnsDriver
di *dnsDriver
exit chan struct{}
mu sync.Mutex
}

// newWindowsPacketSource constructs a new packet source
Expand All @@ -29,25 +32,30 @@ func newWindowsPacketSource(telemetrycomp telemetry.Component) (filter.PacketSou
if err != nil {
return nil, err
}
return &windowsPacketSource{di: di}, nil
return &windowsPacketSource{
di: di,
exit: make(chan struct{}),
}, nil
}

func (p *windowsPacketSource) VisitPackets(exit <-chan struct{}, visit func([]byte, filter.PacketInfo, time.Time) error) error {
func (p *windowsPacketSource) VisitPackets(visit func([]byte, filter.PacketInfo, time.Time) error) error {
p.mu.Lock()
defer p.mu.Unlock()
for {
// break out of loop if exit is closed
select {
case <-p.exit:
return nil
default:
}

didReadPacket, err := p.di.ReadDNSPacket(visit)
if err != nil {
return err
}
if !didReadPacket {
return nil
}

// break out of loop if exit is closed
select {
case <-exit:
return nil
default:
}
}
}

Expand All @@ -56,5 +64,10 @@ func (p *windowsPacketSource) LayerType() gopacket.LayerType {
}

func (p *windowsPacketSource) Close() {
close(p.exit)

// wait for the VisitPackets loop to finish, then close
p.mu.Lock()
defer p.mu.Unlock()
_ = p.di.Close()
}
5 changes: 3 additions & 2 deletions pkg/network/dns/snooper.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ func (s *socketFilterSnooper) Start() error {
func (s *socketFilterSnooper) Close() {
s.once.Do(func() {
close(s.exit)
s.wg.Wait()
// close the packet capture loop and wait for it to finish
s.source.Close()
s.wg.Wait()
s.cache.Close()
if s.statKeeper != nil {
s.statKeeper.Close()
Expand Down Expand Up @@ -170,7 +171,7 @@ func (s *socketFilterSnooper) processPacket(data []byte, _ filter.PacketInfo, ts

func (s *socketFilterSnooper) pollPackets() {
for {
err := s.source.VisitPackets(s.exit, s.processPacket)
err := s.source.VisitPackets(s.processPacket)

if err != nil {
log.Warnf("error reading packet: %s", err)
Expand Down
7 changes: 4 additions & 3 deletions pkg/network/filter/packet_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ type PacketSource interface {
// If no packet is available, VisitPacket blocks until OptPollTimeout and returns.
// The format of the packet is dependent on the implementation of PacketSource -- i.e. it may be an ethernet frame, or a IP frame.
// The data buffer is reused between invocations of VisitPacket and thus should not be pointed to.
// If the cancel channel is closed, VisitPackets will stop reading.
VisitPackets(cancel <-chan struct{}, visitor func(data []byte, info PacketInfo, timestamp time.Time) error) error
// If the PacketSource is closed, VisitPackets will stop reading.
VisitPackets(visitor func(data []byte, info PacketInfo, timestamp time.Time) error) error

// LayerType returns the type of packet this source reads
LayerType() gopacket.LayerType

// Close closes the packet source
// Close closes the packet source. This will cancel VisitPackets if it is currently polling.
// Close() will not return until after VisitPackets has been canceled/returned.
Close()
}
15 changes: 4 additions & 11 deletions pkg/network/filter/packet_source_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,24 +147,17 @@ type zeroCopyPacketReader interface {
// The data buffer is reused between calls, so be careful
type AFPacketVisitor = func(data []byte, info PacketInfo, t time.Time) error

func visitPackets(p zeroCopyPacketReader, exit <-chan struct{}, visit AFPacketVisitor) error {
func visitPackets(p zeroCopyPacketReader, visit AFPacketVisitor) error {
pktInfo := p.GetPacketInfoBuffer()
for {
// allow the read loop to be prematurely interrupted
select {
case <-exit:
return nil
default:
}

data, stats, err := p.ZeroCopyReadPacketData()

// Immediately retry for EAGAIN
if err == syscall.EAGAIN {
continue
}

if err == afpacket.ErrTimeout {
if err == afpacket.ErrTimeout || err == afpacket.ErrCancelled {
return nil
}

Expand All @@ -187,8 +180,8 @@ func visitPackets(p zeroCopyPacketReader, exit <-chan struct{}, visit AFPacketVi
}

// VisitPackets starts reading packets from the source
func (p *AFPacketSource) VisitPackets(exit <-chan struct{}, visit AFPacketVisitor) error {
return visitPackets(p, exit, visit)
func (p *AFPacketSource) VisitPackets(visit AFPacketVisitor) error {
return visitPackets(p, visit)
}

// LayerType is the gopacket.LayerType for this source
Expand Down
26 changes: 18 additions & 8 deletions pkg/network/filter/packet_source_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,22 @@ import (
"golang.org/x/sys/unix"
)

type mockPacketReader struct {
type mockPacketCapture struct {
data []byte
ci gopacket.CaptureInfo
err error
}

type mockPacketReader struct {
packets <-chan mockPacketCapture
}

func (m *mockPacketReader) ZeroCopyReadPacketData() (data []byte, ci gopacket.CaptureInfo, err error) {
return m.data, m.ci, m.err
capture, ok := <-m.packets
if !ok {
return nil, gopacket.CaptureInfo{}, afpacket.ErrCancelled
}
return capture.data, capture.ci, capture.err
}
func (m *mockPacketReader) GetPacketInfoBuffer() *AFPacketInfo {
return &AFPacketInfo{}
Expand All @@ -42,17 +50,21 @@ func mockCaptureInfo(ancillaryData []interface{}) gopacket.CaptureInfo {
}

func expectAncillaryPktType(t *testing.T, ancillaryData []interface{}, pktType uint8) {
exit := make(chan struct{})

p := mockPacketReader{
packets := make(chan mockPacketCapture, 1)
packets <- mockPacketCapture{
data: []byte{},
ci: mockCaptureInfo(ancillaryData),
err: nil,
}
close(packets)

p := mockPacketReader{
packets: packets,
}

visited := false

err := visitPackets(&p, exit, func(_ []byte, info PacketInfo, _ time.Time) error {
err := visitPackets(&p, func(_ []byte, info PacketInfo, _ time.Time) error {
// make sure the callback ran since it's responsible for the require call
visited = true

Expand All @@ -61,8 +73,6 @@ func expectAncillaryPktType(t *testing.T, ancillaryData []interface{}, pktType u
// use assert so that we close the exit channel on failure
assert.Equal(t, pktType, pktInfo.PktType)

// trigger exit so it only reads one packet
close(exit)
return nil
})
require.NoError(t, err)
Expand Down
7 changes: 3 additions & 4 deletions pkg/network/tracer/connection/ebpfless_tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (t *ebpfLessTracer) Start(closeCallback func(*network.ConnectionStats)) err
parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &eth, &ip4, &ip6, &tcp, &udp)
parser.IgnoreUnsupported = true
for {
err := t.packetSrc.VisitPackets(t.exit, func(b []byte, info filter.PacketInfo, _ time.Time) error {
err := t.packetSrc.VisitPackets(func(b []byte, info filter.PacketInfo, _ time.Time) error {
if err := parser.DecodeLayers(b, &decoded); err != nil {
return fmt.Errorf("error decoding packet layers: %w", err)
}
Expand Down Expand Up @@ -357,10 +357,9 @@ func (t *ebpfLessTracer) Stop() {
}

close(t.exit)

// can't close packetSrc while it's still visiting, wait for it to finish
t.packetSrcBusy.Wait()
// close the packet capture loop and wait for it to finish
t.packetSrc.Close()
t.packetSrcBusy.Wait()

t.ns.Close()
t.boundPorts.Stop()
Expand Down

0 comments on commit f8867cf

Please sign in to comment.