From c6c3fd0490ad73af1f4775d84de18c4ba8fb7af0 Mon Sep 17 00:00:00 2001 From: Damon Date: Thu, 20 Jun 2024 20:21:02 -0700 Subject: [PATCH] Handle MultimapKeysSideInput in State GetRequests (#31632) * Handle MultimapKeysSideInput in State GetRequests * Assign data to keys * Fix test name * Fix import sort --- .../runners/prism/internal/worker/worker.go | 15 +++ .../prism/internal/worker/worker_test.go | 95 +++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index 47fc2cccfc54..d8eb4c961493 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -468,6 +468,21 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { data = winMap[w] + case *fnpb.StateKey_MultimapKeysSideInput_: + mmkey := key.GetMultimapKeysSideInput() + wKey := mmkey.GetWindow() + var w typex.Window = window.GlobalWindow{} + if len(wKey) > 0 { + w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey)) + if err != nil { + panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey, err)) + } + } + winMap := b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), Local: mmkey.GetSideInputId()}] + for k := range winMap[w] { + data = append(data, []byte(k)) + } + case *fnpb.StateKey_MultimapSideInput_: mmkey := key.GetMultimapSideInput() wKey := mmkey.GetWindow() diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go index b87667eef387..e5b03214ae0f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go @@ -18,12 +18,16 @@ package worker import ( "bytes" "context" + "github.com/google/go-cmp/cmp" "net" + "sort" "sync" "testing" "time" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" @@ -97,6 +101,23 @@ func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) { return ctx, w, clientConn } +type closeSend func() + +func serveTestWorkerStateStream(t *testing.T) (*W, fnpb.BeamFnState_StateClient, closeSend) { + ctx, wk, clientConn := serveTestWorker(t) + + stateCli := fnpb.NewBeamFnStateClient(clientConn) + stateStream, err := stateCli.State(ctx) + if err != nil { + t.Fatal("couldn't create state client:", err) + } + return wk, stateStream, func() { + if err := stateStream.CloseSend(); err != nil { + t.Errorf("stateStream.CloseSend() = %v", err) + } + } +} + func TestWorker_Logging(t *testing.T) { ctx, _, clientConn := serveTestWorker(t) @@ -291,3 +312,77 @@ func TestWorker_State_Iterable(t *testing.T) { t.Errorf("stateStream.CloseSend() = %v", err) } } + +func TestWorker_State_MultimapKeysSideInput(t *testing.T) { + for _, tt := range []struct { + name string + w typex.Window + }{ + { + name: "global window", + w: window.GlobalWindow{}, + }, + { + name: "interval window", + w: window.IntervalWindow{ + Start: 1000, + End: 2000, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + var encW []byte + if !tt.w.Equals(window.GlobalWindow{}) { + buf := bytes.Buffer{} + if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil { + t.Fatalf("error encoding window: %v, err: %v", tt.w, err) + } + encW = buf.Bytes() + } + wk, stateStream, done := serveTestWorkerStateStream(t) + defer done() + instID := wk.NextInst() + wk.activeInstructions[instID] = &B{ + MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{ + SideInputKey{ + TransformID: "transformID", + Local: "i1", + }: { + tt.w: map[string][][]byte{"a": {{1}}, "b": {{2}}}, + }, + }, + } + + stateStream.Send(&fnpb.StateRequest{ + Id: "first", + InstructionId: instID, + Request: &fnpb.StateRequest_Get{ + Get: &fnpb.StateGetRequest{}, + }, + StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapKeysSideInput_{ + MultimapKeysSideInput: &fnpb.StateKey_MultimapKeysSideInput{ + TransformId: "transformID", + SideInputId: "i1", + Window: encW, + }, + }}, + }) + + resp, err := stateStream.Recv() + if err != nil { + t.Fatal("couldn't receive state response:", err) + } + + want := []int{97, 98} + var got []int + for _, b := range resp.GetGet().GetData() { + got = append(got, int(b)) + } + sort.Ints(got) + + if !cmp.Equal(got, want) { + t.Errorf("didn't receive expected state response data: got %v, want %v", got, want) + } + }) + } +}