From a61222d2f07f3aab0ad336e08b946a8c4c534d06 Mon Sep 17 00:00:00 2001 From: Lorenz Bauer Date: Mon, 22 Jul 2024 09:28:45 +0100 Subject: [PATCH] fix a variety of Copy() problems The are multiple Copy() methods in the code base which are used to create deep copies of a struct. Any bugs in them can manifest as data races in concurrent code. Add a quicktest checker which ensures that two variables are a deep copy of each other: values match, but all locations in memory differ. Fix a variety of problems with the existing Copy() implementations. They all allow copying a nil struct and copy all elements. There are exceptions to the deep copy rule: - MapSpec.Contents is not deep copied because we can't easily make copies of interface values. This is documented already. - MapSpec.Extra is an immutable bytes.Reader and therefore only needs a shallow copy. - ProgramSpec.AttachTarget is a Program, which is (currently) safe for concurrent use. Fixes #1517 Signed-off-by: Lorenz Bauer --- btf/btf.go | 22 ++-- btf/btf_test.go | 2 + collection.go | 2 +- collection_test.go | 26 ++--- internal/testutils/checkers.go | 158 ++++++++++++++++++++++++++++ internal/testutils/checkers_test.go | 106 +++++++++++++++++++ map.go | 22 +++- map_test.go | 24 +++++ prog_test.go | 21 ++++ 9 files changed, 353 insertions(+), 30 deletions(-) create mode 100644 internal/testutils/checkers.go create mode 100644 internal/testutils/checkers_test.go diff --git a/btf/btf.go b/btf/btf.go index 204757dbf..671f680b2 100644 --- a/btf/btf.go +++ b/btf/btf.go @@ -66,7 +66,7 @@ func (s *immutableTypes) typeByID(id TypeID) (Type, bool) { // mutableTypes is a set of types which may be changed. type mutableTypes struct { imm immutableTypes - mu *sync.RWMutex // protects copies below + mu sync.RWMutex // protects copies below copies map[Type]Type // map[orig]copy copiedTypeIDs map[Type]TypeID // map[copy]origID } @@ -94,10 +94,14 @@ func (mt *mutableTypes) add(typ Type, typeIDs map[Type]TypeID) Type { } // copy a set of mutable types. -func (mt *mutableTypes) copy() mutableTypes { - mtCopy := mutableTypes{ +func (mt *mutableTypes) copy() *mutableTypes { + if mt == nil { + return nil + } + + mtCopy := &mutableTypes{ mt.imm, - &sync.RWMutex{}, + sync.RWMutex{}, make(map[Type]Type, len(mt.copies)), make(map[Type]TypeID, len(mt.copiedTypeIDs)), } @@ -169,7 +173,7 @@ func (mt *mutableTypes) anyTypesByName(name string) ([]Type, error) { // Spec allows querying a set of Types and loading the set into the // kernel. type Spec struct { - mutableTypes + *mutableTypes // String table from ELF. strings *stringTable @@ -339,7 +343,7 @@ func loadRawSpec(btf io.ReaderAt, bo binary.ByteOrder, base *Spec) (*Spec, error typeIDs, typesByName := indexTypes(types, firstTypeID) return &Spec{ - mutableTypes{ + &mutableTypes{ immutableTypes{ types, typeIDs, @@ -347,7 +351,7 @@ func loadRawSpec(btf io.ReaderAt, bo binary.ByteOrder, base *Spec) (*Spec, error typesByName, bo, }, - &sync.RWMutex{}, + sync.RWMutex{}, make(map[Type]Type), make(map[Type]TypeID), }, @@ -522,6 +526,10 @@ func fixupDatasecLayout(ds *Datasec) error { // Copy creates a copy of Spec. func (s *Spec) Copy() *Spec { + if s == nil { + return nil + } + return &Spec{ s.mutableTypes.copy(), s.strings, diff --git a/btf/btf_test.go b/btf/btf_test.go index 1b758f212..37fa8d033 100644 --- a/btf/btf_test.go +++ b/btf/btf_test.go @@ -326,6 +326,8 @@ func TestGuessBTFByteOrder(t *testing.T) { } func TestSpecCopy(t *testing.T) { + qt.Check(t, qt.IsNil((*Spec)(nil).Copy())) + spec := parseELFBTF(t, "../testdata/loader-el.elf") cpy := spec.Copy() diff --git a/collection.go b/collection.go index a5532220f..b2cb214ad 100644 --- a/collection.go +++ b/collection.go @@ -57,7 +57,7 @@ func (cs *CollectionSpec) Copy() *CollectionSpec { Maps: make(map[string]*MapSpec, len(cs.Maps)), Programs: make(map[string]*ProgramSpec, len(cs.Programs)), ByteOrder: cs.ByteOrder, - Types: cs.Types, + Types: cs.Types.Copy(), } for name, spec := range cs.Maps { diff --git a/collection_test.go b/collection_test.go index 4d6ea9589..c82c3acc9 100644 --- a/collection_test.go +++ b/collection_test.go @@ -1,6 +1,7 @@ package ebpf import ( + "encoding/binary" "errors" "fmt" "io" @@ -57,7 +58,7 @@ func TestCollectionSpecNotModified(t *testing.T) { func TestCollectionSpecCopy(t *testing.T) { cs := &CollectionSpec{ - Maps: map[string]*MapSpec{ + map[string]*MapSpec{ "my-map": { Type: Array, KeySize: 4, @@ -65,7 +66,7 @@ func TestCollectionSpecCopy(t *testing.T) { MaxEntries: 1, }, }, - Programs: map[string]*ProgramSpec{ + map[string]*ProgramSpec{ "test": { Type: SocketFilter, Instructions: asm.Instructions{ @@ -76,25 +77,12 @@ func TestCollectionSpecCopy(t *testing.T) { License: "MIT", }, }, - Types: &btf.Spec{}, + &btf.Spec{}, + binary.LittleEndian, } - cpy := cs.Copy() - if cpy == cs { - t.Error("Copy returned the same pointer") - } - - if cpy.Maps["my-map"] == cs.Maps["my-map"] { - t.Error("Copy returned same Maps") - } - - if cpy.Programs["test"] == cs.Programs["test"] { - t.Error("Copy returned same Programs") - } - - if cpy.Types != cs.Types { - t.Error("Copy returned different Types") - } + qt.Check(t, qt.IsNil((*CollectionSpec)(nil).Copy())) + qt.Assert(t, testutils.IsDeepCopy(cs.Copy(), cs)) } func TestCollectionSpecLoadCopy(t *testing.T) { diff --git a/internal/testutils/checkers.go b/internal/testutils/checkers.go new file mode 100644 index 000000000..312003fe5 --- /dev/null +++ b/internal/testutils/checkers.go @@ -0,0 +1,158 @@ +package testutils + +import ( + "bytes" + "fmt" + "reflect" + + "github.com/go-quicktest/qt" +) + +// IsDeepCopy checks that got is a deep copy of want. +// +// All primitive values must be equal, but pointers must be distinct. +// This is different from [reflect.DeepEqual] which will accept equal pointer values. +// That is, reflect.DeepEqual(a, a) is true, while IsDeepCopy(a, a) is false. +func IsDeepCopy[T any](got, want T) qt.Checker { + return &deepCopyChecker[T]{got, want, make(map[pair]struct{})} +} + +type pair struct { + got, want reflect.Value +} + +type deepCopyChecker[T any] struct { + got, want T + visited map[pair]struct{} +} + +func (dcc *deepCopyChecker[T]) Check(_ func(key string, value any)) error { + return dcc.check(reflect.ValueOf(dcc.got), reflect.ValueOf(dcc.want)) +} + +func (dcc *deepCopyChecker[T]) check(got, want reflect.Value) error { + switch want.Kind() { + case reflect.Interface: + return dcc.check(got.Elem(), want.Elem()) + + case reflect.Pointer: + if got.IsNil() && want.IsNil() { + return nil + } + + if got.IsNil() { + return fmt.Errorf("expected non-nil pointer") + } + + if want.IsNil() { + return fmt.Errorf("expected nil pointer") + } + + if got.UnsafePointer() == want.UnsafePointer() { + return fmt.Errorf("equal pointer values") + } + + switch want.Type() { + case reflect.TypeOf((*bytes.Reader)(nil)): + // bytes.Reader doesn't allow modifying it's contents, so we + // allow a shallow copy. + return nil + } + + if _, ok := dcc.visited[pair{got, want}]; ok { + // Deal with recursive types. + return nil + } + + dcc.visited[pair{got, want}] = struct{}{} + return dcc.check(got.Elem(), want.Elem()) + + case reflect.Slice: + if got.IsNil() && want.IsNil() { + return nil + } + + if got.IsNil() { + return fmt.Errorf("expected non-nil slice") + } + + if want.IsNil() { + return fmt.Errorf("expected nil slice") + } + + if got.Len() != want.Len() { + return fmt.Errorf("expected %d elements, got %d", want.Len(), got.Len()) + } + + if want.Len() == 0 { + return nil + } + + if got.UnsafePointer() == want.UnsafePointer() { + return fmt.Errorf("equal backing memory") + } + + fallthrough + + case reflect.Array: + for i := 0; i < want.Len(); i++ { + if err := dcc.check(got.Index(i), want.Index(i)); err != nil { + return fmt.Errorf("index %d: %w", i, err) + } + } + + return nil + + case reflect.Struct: + for i := 0; i < want.NumField(); i++ { + if err := dcc.check(got.Field(i), want.Field(i)); err != nil { + return fmt.Errorf("%q: %w", want.Type().Field(i).Name, err) + } + } + + return nil + + case reflect.Map: + if got.Len() != want.Len() { + return fmt.Errorf("expected %d items, got %d", want.Len(), got.Len()) + } + + if got.UnsafePointer() == want.UnsafePointer() { + return fmt.Errorf("maps are equal") + } + + iter := want.MapRange() + for iter.Next() { + key := iter.Key() + got := got.MapIndex(iter.Key()) + if !got.IsValid() { + return fmt.Errorf("key %v is missing", key) + } + + want := iter.Value() + if err := dcc.check(got, want); err != nil { + return fmt.Errorf("key %v: %w", key, err) + } + } + + return nil + + case reflect.Chan, reflect.UnsafePointer: + return fmt.Errorf("%s is not supported", want.Type()) + + default: + // Compare by value as usual. + if !got.Equal(want) { + return fmt.Errorf("%#v is not equal to %#v", got, want) + } + + return nil + } +} + +func (dcc *deepCopyChecker[T]) Args() []qt.Arg { + return []qt.Arg{ + {Name: "got", Value: dcc.got}, + {Name: "want", Value: dcc.want}, + } +} diff --git a/internal/testutils/checkers_test.go b/internal/testutils/checkers_test.go new file mode 100644 index 000000000..45d675cfb --- /dev/null +++ b/internal/testutils/checkers_test.go @@ -0,0 +1,106 @@ +package testutils + +import ( + "testing" + + "github.com/go-quicktest/qt" +) + +func TestIsDeepCopy(t *testing.T) { + type s struct { + basic int + array [1]*int + array0 [0]int + ptr *int + slice []*int + ifc any + m map[*int]*int + rec *s + } + + key := 1 + copy := func() *s { + v := &s{ + 0, + [...]*int{new(int)}, + [...]int{}, + new(int), + []*int{new(int)}, + new(int), + map[*int]*int{&key: new(int)}, + nil, + } + v.rec = v + return v + } + + a, b := copy(), copy() + qt.Check(t, qt.IsNil(IsDeepCopy(a, b).Check(nil))) + + a.basic++ + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"basic": .*`)) + + a = copy() + (*a.array[0])++ + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"array": index 0: .*`)) + + a = copy() + a.array[0] = nil + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"array": index 0: .*`)) + + a = copy() + a.array = b.array + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"array": index 0: .*`)) + + a = copy() + (*a.ptr)++ + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"ptr": .*`)) + + a = copy() + a.ptr = b.ptr + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"ptr": .*`)) + + a = copy() + (*a.slice[0])++ + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"slice": .*`)) + + a = copy() + a.slice[0] = nil + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"slice": .*`)) + + a = copy() + a.slice = nil + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"slice": .*`)) + + a = copy() + a.slice = b.slice + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"slice": .*`)) + + a = copy() + *(a.ifc.(*int))++ + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"ifc": .*`)) + + a = copy() + a.ifc = b.ifc + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"ifc": .*`)) + + a = copy() + a.rec = b.rec + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"rec": .*`)) + + a = copy() + a.m = b.m + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"m": .*`)) + + a = copy() + (*a.m[&key])++ + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"m": .*`)) + + a = copy() + a.m[new(int)] = new(int) + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"m": .*`)) + + a = copy() + delete(a.m, &key) + qt.Check(t, qt.ErrorMatches(IsDeepCopy(a, b).Check(nil), `"m": .*`)) +} diff --git a/map.go b/map.go index e48412cbf..7412e596c 100644 --- a/map.go +++ b/map.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "reflect" + "slices" "strings" "sync" "time" @@ -96,11 +97,26 @@ func (ms *MapSpec) Copy() *MapSpec { } cpy := *ms + cpy.Contents = slices.Clone(cpy.Contents) - cpy.Contents = make([]MapKV, len(ms.Contents)) - copy(cpy.Contents, ms.Contents) + if cpy.InnerMap == ms { + cpy.InnerMap = &cpy + } else { + cpy.InnerMap = ms.InnerMap.Copy() + } + + if cpy.Extra != nil { + extra := *cpy.Extra + cpy.Extra = &extra + } - cpy.InnerMap = ms.InnerMap.Copy() + if cpy.Key != nil { + cpy.Key = btf.Copy(cpy.Key) + } + + if cpy.Value != nil { + cpy.Value = btf.Copy(cpy.Value) + } return &cpy } diff --git a/map_test.go b/map_test.go index 8866ac40f..39a7ae28d 100644 --- a/map_test.go +++ b/map_test.go @@ -1,6 +1,7 @@ package ebpf import ( + "bytes" "errors" "fmt" "math" @@ -92,6 +93,29 @@ func TestMap(t *testing.T) { } } +func TestMapSpecCopy(t *testing.T) { + a := &MapSpec{ + "foo", + Hash, + 4, + 4, + 1, + 1, + PinByName, + 1, + []MapKV{{1, 2}}, // Can't copy Contents, use value types + true, + nil, // InnerMap + bytes.NewReader(nil), + &btf.Int{}, + &btf.Int{}, + } + a.InnerMap = a + + qt.Check(t, qt.IsNil((*MapSpec)(nil).Copy())) + qt.Assert(t, testutils.IsDeepCopy(a.Copy(), a)) +} + func TestMapBatch(t *testing.T) { if err := haveBatchAPI(); err != nil { t.Skipf("batch api not available: %v", err) diff --git a/prog_test.go b/prog_test.go index 588ace880..dd540e5aa 100644 --- a/prog_test.go +++ b/prog_test.go @@ -711,6 +711,27 @@ func TestProgramRejectIncorrectByteOrder(t *testing.T) { } } +func TestProgramSpecCopy(t *testing.T) { + a := &ProgramSpec{ + "test", + 1, + 1, + "attach", + nil, // Can't copy Program + "section", + asm.Instructions{ + asm.Return(), + }, + 1, + "license", + 1, + binary.LittleEndian, + } + + qt.Check(t, qt.IsNil((*ProgramSpec)(nil).Copy())) + qt.Assert(t, testutils.IsDeepCopy(a.Copy(), a)) +} + func TestProgramSpecTag(t *testing.T) { arr := createArray(t)