From a0f98b3650c823666cecd9033acf3d97f75b1e3a Mon Sep 17 00:00:00 2001 From: Denise Li Date: Wed, 5 Jun 2024 17:53:27 -0400 Subject: [PATCH] add map key validation --- go-runtime/ftl/ftltest/fake.go | 18 ++++++++++++++++-- .../ftl/testdata/go/mapper/mapper_test.go | 4 ++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/go-runtime/ftl/ftltest/fake.go b/go-runtime/ftl/ftltest/fake.go index fe525da50b..3f4c16d74e 100644 --- a/go-runtime/ftl/ftltest/fake.go +++ b/go-runtime/ftl/ftltest/fake.go @@ -2,7 +2,9 @@ package ftltest import ( "context" + "fmt" "reflect" + "strings" "github.com/TBD54566975/ftl/go-runtime/internal" "github.com/TBD54566975/ftl/internal/modulecontext" @@ -31,12 +33,12 @@ func (f *fakeFTL) PublishEvent(ctx context.Context, topic string, event any) err } func (f *fakeFTL) addMapMock(ctx context.Context, mapper any, mockMap func(context.Context) (any, error)) { - key := reflect.ValueOf(mapper).Pointer() + key := makeMapKey(mapper) f.mockMaps[key] = mockMap } func (f *fakeFTL) CallMap(ctx context.Context, mapper any, mapImpl func(context.Context) (any, error)) any { - key := reflect.ValueOf(mapper).Pointer() + key := makeMapKey(mapper) mockMap, ok := f.mockMaps[key] if ok { return actuallyCallMap(ctx, mockMap) @@ -48,6 +50,18 @@ func (f *fakeFTL) CallMap(ctx context.Context, mapper any, mapImpl func(context. return nil } +func makeMapKey(mapper any) uintptr { + v := reflect.ValueOf(mapper) + if v.Kind() != reflect.Pointer { + panic("fakeFTL received object that was not a pointer, expected *MapHandle") + } + underlying := v.Elem().Type().Name() + if !strings.HasPrefix(underlying, "MapHandle[") { + panic(fmt.Sprintf("fakeFTL received *%s, expected *MapHandle", underlying)) + } + return v.Pointer() +} + func actuallyCallMap(ctx context.Context, impl modulecontext.Map) any { out, err := impl(ctx) if err != nil { diff --git a/go-runtime/ftl/testdata/go/mapper/mapper_test.go b/go-runtime/ftl/testdata/go/mapper/mapper_test.go index e439b852d4..09d4c04409 100644 --- a/go-runtime/ftl/testdata/go/mapper/mapper_test.go +++ b/go-runtime/ftl/testdata/go/mapper/mapper_test.go @@ -8,7 +8,7 @@ import ( "github.com/alecthomas/assert/v2" ) -/*func TestGet(t *testing.T) { +func TestGet(t *testing.T) { ctx := ftltest.Context(ftltest.WithMapsAllowed()) got := m.Get(ctx) assert.Equal(t, got, 9) @@ -19,7 +19,7 @@ func TestPanicsWithoutExplicitlyAllowingMaps(t *testing.T) { assert.Panics(t, func() { m.Get(ctx) }) - }*/ +} func TestMockGet(t *testing.T) { mockOut := 123