Skip to content

Commit

Permalink
Handle MultimapKeysSideInput in State GetRequests (#31632)
Browse files Browse the repository at this point in the history
* Handle MultimapKeysSideInput in State GetRequests

* Assign data to keys

* Fix test name

* Fix import sort
  • Loading branch information
damondouglas authored Jun 21, 2024
1 parent e4a0208 commit c6c3fd0
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
15 changes: 15 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,21 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error {

data = winMap[w]

case *fnpb.StateKey_MultimapKeysSideInput_:
mmkey := key.GetMultimapKeysSideInput()
wKey := mmkey.GetWindow()
var w typex.Window = window.GlobalWindow{}
if len(wKey) > 0 {
w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
if err != nil {
panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey, err))
}
}
winMap := b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), Local: mmkey.GetSideInputId()}]
for k := range winMap[w] {
data = append(data, []byte(k))
}

case *fnpb.StateKey_MultimapSideInput_:
mmkey := key.GetMultimapSideInput()
wKey := mmkey.GetWindow()
Expand Down
95 changes: 95 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@ package worker
import (
"bytes"
"context"
"github.com/google/go-cmp/cmp"
"net"
"sort"
"sync"
"testing"
"time"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
Expand Down Expand Up @@ -97,6 +101,23 @@ func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) {
return ctx, w, clientConn
}

type closeSend func()

func serveTestWorkerStateStream(t *testing.T) (*W, fnpb.BeamFnState_StateClient, closeSend) {
ctx, wk, clientConn := serveTestWorker(t)

stateCli := fnpb.NewBeamFnStateClient(clientConn)
stateStream, err := stateCli.State(ctx)
if err != nil {
t.Fatal("couldn't create state client:", err)
}
return wk, stateStream, func() {
if err := stateStream.CloseSend(); err != nil {
t.Errorf("stateStream.CloseSend() = %v", err)
}
}
}

func TestWorker_Logging(t *testing.T) {
ctx, _, clientConn := serveTestWorker(t)

Expand Down Expand Up @@ -291,3 +312,77 @@ func TestWorker_State_Iterable(t *testing.T) {
t.Errorf("stateStream.CloseSend() = %v", err)
}
}

func TestWorker_State_MultimapKeysSideInput(t *testing.T) {
for _, tt := range []struct {
name string
w typex.Window
}{
{
name: "global window",
w: window.GlobalWindow{},
},
{
name: "interval window",
w: window.IntervalWindow{
Start: 1000,
End: 2000,
},
},
} {
t.Run(tt.name, func(t *testing.T) {
var encW []byte
if !tt.w.Equals(window.GlobalWindow{}) {
buf := bytes.Buffer{}
if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil {
t.Fatalf("error encoding window: %v, err: %v", tt.w, err)
}
encW = buf.Bytes()
}
wk, stateStream, done := serveTestWorkerStateStream(t)
defer done()
instID := wk.NextInst()
wk.activeInstructions[instID] = &B{
MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{
SideInputKey{
TransformID: "transformID",
Local: "i1",
}: {
tt.w: map[string][][]byte{"a": {{1}}, "b": {{2}}},
},
},
}

stateStream.Send(&fnpb.StateRequest{
Id: "first",
InstructionId: instID,
Request: &fnpb.StateRequest_Get{
Get: &fnpb.StateGetRequest{},
},
StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapKeysSideInput_{
MultimapKeysSideInput: &fnpb.StateKey_MultimapKeysSideInput{
TransformId: "transformID",
SideInputId: "i1",
Window: encW,
},
}},
})

resp, err := stateStream.Recv()
if err != nil {
t.Fatal("couldn't receive state response:", err)
}

want := []int{97, 98}
var got []int
for _, b := range resp.GetGet().GetData() {
got = append(got, int(b))
}
sort.Ints(got)

if !cmp.Equal(got, want) {
t.Errorf("didn't receive expected state response data: got %v, want %v", got, want)
}
})
}
}

0 comments on commit c6c3fd0

Please sign in to comment.