Skip to content

Commit

Permalink
fix a variety of Copy() problems
Browse files Browse the repository at this point in the history
The are multiple Copy() methods in the code base which are used to
create deep copies of a struct. Any bugs in them can manifest as
data races in concurrent code.

Add a quicktest checker which ensures that two variables are a
deep copy of each other: values match, but all locations in memory
differ.

Fix a variety of problems with the existing Copy() implementations.
They all allow copying a nil struct and copy all elements.

There are exceptions to the deep copy rule:

- MapSpec.Contents is not deep copied because we can't easily make
  copies of interface values. This is documented already.
- MapSpec.Extra is an immutable bytes.Reader and therefore only
  needs a shallow copy.
- ProgramSpec.AttachTarget is a Program, which is (currently) safe
  for concurrent use.

Fixes #1517

Signed-off-by: Lorenz Bauer <[email protected]>
  • Loading branch information
lmb committed Jul 22, 2024
1 parent fbb9ed8 commit a61222d
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 30 deletions.
22 changes: 15 additions & 7 deletions btf/btf.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (s *immutableTypes) typeByID(id TypeID) (Type, bool) {
// mutableTypes is a set of types which may be changed.
type mutableTypes struct {
imm immutableTypes
mu *sync.RWMutex // protects copies below
mu sync.RWMutex // protects copies below
copies map[Type]Type // map[orig]copy
copiedTypeIDs map[Type]TypeID // map[copy]origID
}
Expand Down Expand Up @@ -94,10 +94,14 @@ func (mt *mutableTypes) add(typ Type, typeIDs map[Type]TypeID) Type {
}

// copy a set of mutable types.
func (mt *mutableTypes) copy() mutableTypes {
mtCopy := mutableTypes{
func (mt *mutableTypes) copy() *mutableTypes {
if mt == nil {
return nil
}

mtCopy := &mutableTypes{
mt.imm,
&sync.RWMutex{},
sync.RWMutex{},
make(map[Type]Type, len(mt.copies)),
make(map[Type]TypeID, len(mt.copiedTypeIDs)),
}
Expand Down Expand Up @@ -169,7 +173,7 @@ func (mt *mutableTypes) anyTypesByName(name string) ([]Type, error) {
// Spec allows querying a set of Types and loading the set into the
// kernel.
type Spec struct {
mutableTypes
*mutableTypes

// String table from ELF.
strings *stringTable
Expand Down Expand Up @@ -339,15 +343,15 @@ func loadRawSpec(btf io.ReaderAt, bo binary.ByteOrder, base *Spec) (*Spec, error
typeIDs, typesByName := indexTypes(types, firstTypeID)

return &Spec{
mutableTypes{
&mutableTypes{
immutableTypes{
types,
typeIDs,
firstTypeID,
typesByName,
bo,
},
&sync.RWMutex{},
sync.RWMutex{},
make(map[Type]Type),
make(map[Type]TypeID),
},
Expand Down Expand Up @@ -522,6 +526,10 @@ func fixupDatasecLayout(ds *Datasec) error {

// Copy creates a copy of Spec.
func (s *Spec) Copy() *Spec {
if s == nil {
return nil
}

return &Spec{
s.mutableTypes.copy(),
s.strings,
Expand Down
2 changes: 2 additions & 0 deletions btf/btf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ func TestGuessBTFByteOrder(t *testing.T) {
}

func TestSpecCopy(t *testing.T) {
qt.Check(t, qt.IsNil((*Spec)(nil).Copy()))

spec := parseELFBTF(t, "../testdata/loader-el.elf")
cpy := spec.Copy()

Expand Down
2 changes: 1 addition & 1 deletion collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (cs *CollectionSpec) Copy() *CollectionSpec {
Maps: make(map[string]*MapSpec, len(cs.Maps)),
Programs: make(map[string]*ProgramSpec, len(cs.Programs)),
ByteOrder: cs.ByteOrder,
Types: cs.Types,
Types: cs.Types.Copy(),
}

for name, spec := range cs.Maps {
Expand Down
26 changes: 7 additions & 19 deletions collection_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ebpf

import (
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -57,15 +58,15 @@ func TestCollectionSpecNotModified(t *testing.T) {

func TestCollectionSpecCopy(t *testing.T) {
cs := &CollectionSpec{
Maps: map[string]*MapSpec{
map[string]*MapSpec{
"my-map": {
Type: Array,
KeySize: 4,
ValueSize: 4,
MaxEntries: 1,
},
},
Programs: map[string]*ProgramSpec{
map[string]*ProgramSpec{
"test": {
Type: SocketFilter,
Instructions: asm.Instructions{
Expand All @@ -76,25 +77,12 @@ func TestCollectionSpecCopy(t *testing.T) {
License: "MIT",
},
},
Types: &btf.Spec{},
&btf.Spec{},
binary.LittleEndian,
}
cpy := cs.Copy()

if cpy == cs {
t.Error("Copy returned the same pointer")
}

if cpy.Maps["my-map"] == cs.Maps["my-map"] {
t.Error("Copy returned same Maps")
}

if cpy.Programs["test"] == cs.Programs["test"] {
t.Error("Copy returned same Programs")
}

if cpy.Types != cs.Types {
t.Error("Copy returned different Types")
}
qt.Check(t, qt.IsNil((*CollectionSpec)(nil).Copy()))
qt.Assert(t, testutils.IsDeepCopy(cs.Copy(), cs))
}

func TestCollectionSpecLoadCopy(t *testing.T) {
Expand Down
158 changes: 158 additions & 0 deletions internal/testutils/checkers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package testutils

import (
"bytes"
"fmt"
"reflect"

"github.com/go-quicktest/qt"
)

// IsDeepCopy checks that got is a deep copy of want.
//
// All primitive values must be equal, but pointers must be distinct.
// This is different from [reflect.DeepEqual] which will accept equal pointer values.
// That is, reflect.DeepEqual(a, a) is true, while IsDeepCopy(a, a) is false.
func IsDeepCopy[T any](got, want T) qt.Checker {
return &deepCopyChecker[T]{got, want, make(map[pair]struct{})}
}

type pair struct {
got, want reflect.Value
}

type deepCopyChecker[T any] struct {
got, want T
visited map[pair]struct{}
}

func (dcc *deepCopyChecker[T]) Check(_ func(key string, value any)) error {
return dcc.check(reflect.ValueOf(dcc.got), reflect.ValueOf(dcc.want))
}

func (dcc *deepCopyChecker[T]) check(got, want reflect.Value) error {
switch want.Kind() {
case reflect.Interface:
return dcc.check(got.Elem(), want.Elem())

case reflect.Pointer:
if got.IsNil() && want.IsNil() {
return nil
}

if got.IsNil() {
return fmt.Errorf("expected non-nil pointer")
}

if want.IsNil() {
return fmt.Errorf("expected nil pointer")
}

if got.UnsafePointer() == want.UnsafePointer() {
return fmt.Errorf("equal pointer values")
}

switch want.Type() {
case reflect.TypeOf((*bytes.Reader)(nil)):
// bytes.Reader doesn't allow modifying it's contents, so we
// allow a shallow copy.
return nil
}

if _, ok := dcc.visited[pair{got, want}]; ok {
// Deal with recursive types.
return nil
}

dcc.visited[pair{got, want}] = struct{}{}
return dcc.check(got.Elem(), want.Elem())

case reflect.Slice:
if got.IsNil() && want.IsNil() {
return nil
}

if got.IsNil() {
return fmt.Errorf("expected non-nil slice")
}

if want.IsNil() {
return fmt.Errorf("expected nil slice")
}

if got.Len() != want.Len() {
return fmt.Errorf("expected %d elements, got %d", want.Len(), got.Len())
}

if want.Len() == 0 {
return nil
}

if got.UnsafePointer() == want.UnsafePointer() {
return fmt.Errorf("equal backing memory")
}

fallthrough

case reflect.Array:
for i := 0; i < want.Len(); i++ {
if err := dcc.check(got.Index(i), want.Index(i)); err != nil {
return fmt.Errorf("index %d: %w", i, err)
}
}

return nil

case reflect.Struct:
for i := 0; i < want.NumField(); i++ {
if err := dcc.check(got.Field(i), want.Field(i)); err != nil {
return fmt.Errorf("%q: %w", want.Type().Field(i).Name, err)
}
}

return nil

case reflect.Map:
if got.Len() != want.Len() {
return fmt.Errorf("expected %d items, got %d", want.Len(), got.Len())
}

if got.UnsafePointer() == want.UnsafePointer() {
return fmt.Errorf("maps are equal")
}

iter := want.MapRange()
for iter.Next() {
key := iter.Key()
got := got.MapIndex(iter.Key())
if !got.IsValid() {
return fmt.Errorf("key %v is missing", key)
}

want := iter.Value()
if err := dcc.check(got, want); err != nil {
return fmt.Errorf("key %v: %w", key, err)
}
}

return nil

case reflect.Chan, reflect.UnsafePointer:
return fmt.Errorf("%s is not supported", want.Type())

default:
// Compare by value as usual.
if !got.Equal(want) {
return fmt.Errorf("%#v is not equal to %#v", got, want)
}

return nil
}
}

func (dcc *deepCopyChecker[T]) Args() []qt.Arg {
return []qt.Arg{
{Name: "got", Value: dcc.got},
{Name: "want", Value: dcc.want},
}
}
Loading

0 comments on commit a61222d

Please sign in to comment.