diff --git a/sdks/go/pkg/beam/core/runtime/exec/data.go b/sdks/go/pkg/beam/core/runtime/exec/data.go index fdc1e368a52b..71954819a748 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/data.go +++ b/sdks/go/pkg/beam/core/runtime/exec/data.go @@ -57,10 +57,12 @@ type SideCache interface { // DataManager manages external data byte streams. Each data stream can be // opened by one consumer only. type DataManager interface { - // OpenRead opens a closable byte stream for reading. - OpenRead(ctx context.Context, id StreamID) (io.ReadCloser, error) - // OpenWrite opens a closable byte stream for writing. + // OpenElementChan opens a channel for data and timers. + OpenElementChan(ctx context.Context, id StreamID, expectedTimerTransforms []string) (<-chan Elements, error) + // OpenWrite opens a closable byte stream for data writing. OpenWrite(ctx context.Context, id StreamID) (io.WriteCloser, error) + // OpenTimerWrite opens a byte stream for writing timers + OpenTimerWrite(ctx context.Context, id StreamID, family string) (io.WriteCloser, error) } // StateReader is the interface for reading side input data. @@ -91,4 +93,10 @@ type StateReader interface { GetSideInputCache() SideCache } -// TODO(herohde) 7/20/2018: user state management +// Elements holds data or timers sent across the data channel. +// If TimerFamilyID is populated, it's a timer, otherwise it's +// data elements. +type Elements struct { + Data, Timers []byte + TimerFamilyID, PtransformID string +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go index a6347fc8d0e1..66c081862f13 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go @@ -30,6 +30,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/ioutilx" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" "github.com/apache/beam/sdks/v2/go/pkg/beam/log" + "golang.org/x/exp/maps" ) // DataSource is a Root execution unit. @@ -40,9 +41,12 @@ type DataSource struct { Coder *coder.Coder Out Node PCol PCollection // Handles size metrics. Value instead of pointer so it's initialized by default in tests. + // OnTimerTransforms maps PtransformIDs to their execution nodes that handle OnTimer callbacks. + OnTimerTransforms map[string]*ParDo - source DataManager - state StateReader + source DataManager + state StateReader + curInst string index int64 splitIdx int64 @@ -94,20 +98,79 @@ func (n *DataSource) Up(ctx context.Context) error { // StartBundle initializes this datasource for the bundle. func (n *DataSource) StartBundle(ctx context.Context, id string, data DataContext) error { n.mu.Lock() + n.curInst = id n.source = data.Data n.state = data.State n.start = time.Now() - n.index = -1 + n.index = 0 n.splitIdx = math.MaxInt64 n.mu.Unlock() return n.Out.StartBundle(ctx, id, data) } +// splitSuccess is a marker error to indicate we've reached the split index. +// Akin to io.EOF. +var splitSuccess = errors.New("split index reached") + +// process handles converting elements from the data source to timers. +// +// The data and timer callback functions must return an io.EOF if the reader terminates to signal that an additional +// buffer is desired. On successful splits, [splitSuccess] must be returned to indicate that the +// PTransform is done processing data for this instruction. 
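// Editorial note, not part of this patch: a minimal sketch of the callback
// contract described above. `decodeOne` is a hypothetical per-element decoder;
// the real callbacks are constructed in DataSource.Process below.
//
//	err := n.process(ctx,
//		func(bcr *byteCountReader, ptransformID string) error {
//			for {
//				if err := decodeOne(bcr); err != nil {
//					return err // io.EOF: buffer drained, request the next one.
//				}
//				if n.incrementIndexAndCheckSplit() {
//					return splitSuccess // split reached; stop consuming for this PTransform.
//				}
//			}
//		},
//		func(bcr *byteCountReader, ptransformID, timerFamilyID string) error {
//			return nil // timer decoding elided in this sketch.
//		})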
+func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader, ptransformID string) error, timer func(bcr *byteCountReader, ptransformID, timerFamilyID string) error) error { + // The SID contains this instruction's expected data processing transform (this one). + elms, err := n.source.OpenElementChan(ctx, n.SID, maps.Keys(n.OnTimerTransforms)) + if err != nil { + return err + } + + n.PCol.resetSize() // initialize the size distribution for this bundle. + var r bytes.Reader + + var byteCount int + bcr := byteCountReader{reader: &r, count: &byteCount} + + splitPrimaryComplete := map[string]bool{} + for { + var err error + select { + case e, ok := <-elms: + // Channel closed, so time to exit + if !ok { + return nil + } + if splitPrimaryComplete[e.PtransformID] { + continue + } + if len(e.Data) > 0 { + r.Reset(e.Data) + err = data(&bcr, e.PtransformID) + } + if len(e.Timers) > 0 { + r.Reset(e.Timers) + err = timer(&bcr, e.PtransformID, e.TimerFamilyID) + } + + if err == splitSuccess { + // Returning splitSuccess means we've split, and aren't consuming the remaining buffer. + // We mark the PTransform done to ignore further data. + splitPrimaryComplete[e.PtransformID] = true + } else if err != nil && err != io.EOF { + return errors.Wrap(err, "source failed") + } + // io.EOF means the reader successfully drained. + // We're ready for a new buffer. + case <-ctx.Done(): + return nil + } + } +} + // ByteCountReader is a passthrough reader that counts all the bytes read through it. // It trusts the nested reader to return accurate byte information. type byteCountReader struct { count *int - reader io.ReadCloser + reader io.Reader } func (r *byteCountReader) Read(p []byte) (int, error) { @@ -117,7 +180,10 @@ func (r *byteCountReader) Read(p []byte) (int, error) { } func (r *byteCountReader) Close() error { - return r.reader.Close() + if c, ok := r.reader.(io.Closer); ok { + c.Close() + } + return nil } func (r *byteCountReader) reset() int { @@ -128,15 +194,6 @@ func (r *byteCountReader) reset() int { // Process opens the data source, reads and decodes data, kicking off element processing. func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { - r, err := n.source.OpenRead(ctx, n.SID) - if err != nil { - return nil, err - } - defer r.Close() - n.PCol.resetSize() // initialize the size distribution for this bundle. - var byteCount int - bcr := byteCountReader{reader: r, count: &byteCount} - c := coder.SkipW(n.Coder) wc := MakeWindowDecoder(n.Coder.Window) @@ -155,58 +212,63 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { } var checkpoints []*Checkpoint - for { - if n.incrementIndexAndCheckSplit() { - break - } - // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes? - ws, t, pn, err := DecodeWindowedValueHeader(wc, r) - if err != nil { - if err == io.EOF { - break + err := n.process(ctx, func(bcr *byteCountReader, ptransformID string) error { + for { + // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes? + ws, t, pn, err := DecodeWindowedValueHeader(wc, bcr.reader) + if err != nil { + return err } - return nil, errors.Wrap(err, "source failed") - } - - // Decode key or parallel element. 
- pe, err := cp.Decode(&bcr) - if err != nil { - return nil, errors.Wrap(err, "source decode failed") - } - pe.Timestamp = t - pe.Windows = ws - pe.Pane = pn - var valReStreams []ReStream - for _, cv := range cvs { - values, err := n.makeReStream(ctx, cv, &bcr, len(cvs) == 1 && n.singleIterate) + // Decode key or parallel element. + pe, err := cp.Decode(bcr) if err != nil { - return nil, err + return errors.Wrap(err, "source decode failed") } - valReStreams = append(valReStreams, values) - } + pe.Timestamp = t + pe.Windows = ws + pe.Pane = pn - if err := n.Out.ProcessElement(ctx, pe, valReStreams...); err != nil { - return nil, err - } - // Collect the actual size of the element, and reset the bytecounter reader. - n.PCol.addSize(int64(bcr.reset())) - bcr.reader = r - - // Check if there's a continuation and return residuals - // Needs to be done immeadiately after processing to not lose the element. - if c := n.getProcessContinuation(); c != nil { - cp, err := n.checkpointThis(ctx, c) - if err != nil { - // Errors during checkpointing should fail a bundle. - return nil, err + var valReStreams []ReStream + for _, cv := range cvs { + values, err := n.makeReStream(ctx, cv, bcr, len(cvs) == 1 && n.singleIterate) + if err != nil { + return err + } + valReStreams = append(valReStreams, values) } - if cp != nil { - checkpoints = append(checkpoints, cp) + + if err := n.Out.ProcessElement(ctx, pe, valReStreams...); err != nil { + return err + } + // Collect the actual size of the element, and reset the bytecounter reader. + n.PCol.addSize(int64(bcr.reset())) + + // Check if there's a continuation and return residuals + // Needs to be done immediately after processing to not lose the element. + if c := n.getProcessContinuation(); c != nil { + cp, err := n.checkpointThis(ctx, c) + if err != nil { + // Errors during checkpointing should fail a bundle. + return err + } + if cp != nil { + checkpoints = append(checkpoints, cp) + } + } + // We've finished processing an element, check if we have finished a split. + if n.incrementIndexAndCheckSplit() { + return splitSuccess } } - } - return checkpoints, nil + }, + func(bcr *byteCountReader, ptransformID, timerFamilyID string) error { + tmap, err := decodeTimer(cp, wc, bcr) + log.Infof(ctx, "DEBUGLOG: timer received for: %v and %v - %+v err: %v", ptransformID, timerFamilyID, tmap, err) + return nil + }) + + return checkpoints, err } func (n *DataSource) makeReStream(ctx context.Context, cv ElementDecoder, bcr *byteCountReader, onlyStream bool) (ReStream, error) { @@ -313,7 +375,7 @@ func (n *DataSource) makeReStream(ctx context.Context, cv ElementDecoder, bcr *b } } -func readStreamToBuffer(cv ElementDecoder, r io.ReadCloser, size int64, buf []FullValue) ([]FullValue, error) { +func readStreamToBuffer(cv ElementDecoder, r io.Reader, size int64, buf []FullValue) ([]FullValue, error) { for i := int64(0); i < size; i++ { value, err := cv.Decode(r) if err != nil { @@ -472,7 +534,7 @@ func (n *DataSource) checkpointThis(ctx context.Context, pc sdf.ProcessContinuat // The bufSize param specifies the estimated number of elements that will be // sent to this DataSource, and is used to be able to perform accurate splits // even if the DataSource has not yet received all its elements. A bufSize of -// 0 or less indicates that its unknown, and so uses the current known size. +// 0 or less indicates that it's unknown, and so uses the current known size. 
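// Editorial note, not part of this patch: an illustrative call; the split
// indices and fraction are placeholders.
//
//	res, err := n.Split(ctx, []int64{2, 5}, -1.0, 0) // bufSize 0: fall back to the currently known size.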
func (n *DataSource) Split(ctx context.Context, splits []int64, frac float64, bufSize int64) (SplitResult, error) { if n == nil { return SplitResult{}, fmt.Errorf("failed to split at requested splits: {%v}, DataSource not initialized", splits) diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go index 2da3284f016a..ebede8538083 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go @@ -16,7 +16,6 @@ package exec import ( - "bytes" "context" "fmt" "io" @@ -43,20 +42,20 @@ func TestDataSource_PerElement(t *testing.T) { name string expected []any Coder *coder.Coder - driver func(*coder.Coder, io.WriteCloser, []any) + driver func(*coder.Coder, *chanWriter, []any) }{ { name: "perElement", expected: []any{int64(1), int64(2), int64(3), int64(4), int64(5)}, Coder: coder.NewW(coder.NewVarInt(), coder.NewGlobalWindow()), - driver: func(c *coder.Coder, pw io.WriteCloser, expected []any) { + driver: func(c *coder.Coder, cw *chanWriter, expected []any) { wc := MakeWindowEncoder(c.Window) ec := MakeElementEncoder(coder.SkipW(c)) for _, v := range expected { - EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), pw) - ec.Encode(&FullValue{Elm: v}, pw) + EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), cw) + ec.Encode(&FullValue{Elm: v}, cw) } - pw.Close() + cw.Close() }, }, } @@ -70,11 +69,11 @@ func TestDataSource_PerElement(t *testing.T) { Coder: test.Coder, Out: out, } - pr, pw := io.Pipe() - go test.driver(source.Coder, pw, test.expected) + cw := makeChanWriter() + go test.driver(source.Coder, cw, test.expected) constructAndExecutePlanWithContext(t, []Unit{out, source}, DataContext{ - Data: &TestDataManager{R: pr}, + Data: &TestDataManager{Ch: cw.Ch}, }) validateSource(t, out, source, makeValues(test.expected...)) @@ -98,14 +97,14 @@ func TestDataSource_Iterators(t *testing.T) { name string keys, vals []any Coder *coder.Coder - driver func(c *coder.Coder, dmw io.WriteCloser, siwFn func() io.WriteCloser, ks, vs []any) + driver func(c *coder.Coder, dmw *chanWriter, siwFn func() io.WriteCloser, ks, vs []any) }{ { name: "beam:coder:iterable:v1-singleChunk", keys: []any{int64(42), int64(53)}, vals: []any{int64(1), int64(2), int64(3), int64(4), int64(5)}, Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()), - driver: func(c *coder.Coder, dmw io.WriteCloser, _ func() io.WriteCloser, ks, vs []any) { + driver: func(c *coder.Coder, dmw *chanWriter, _ func() io.WriteCloser, ks, vs []any) { wc, kc, vc := extractCoders(c) for _, k := range ks { EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), dmw) @@ -123,7 +122,7 @@ func TestDataSource_Iterators(t *testing.T) { keys: []any{int64(42), int64(53)}, vals: []any{int64(1), int64(2), int64(3), int64(4), int64(5)}, Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()), - driver: func(c *coder.Coder, dmw io.WriteCloser, _ func() io.WriteCloser, ks, vs []any) { + driver: func(c *coder.Coder, dmw *chanWriter, _ func() io.WriteCloser, ks, vs []any) { wc, kc, vc := extractCoders(c) for _, k := range ks { EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), dmw) @@ -144,7 +143,7 @@ func TestDataSource_Iterators(t *testing.T) { 
keys: []any{int64(42), int64(53)}, vals: []any{int64(1), int64(2), int64(3), int64(4), int64(5)}, Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()), - driver: func(c *coder.Coder, dmw io.WriteCloser, swFn func() io.WriteCloser, ks, vs []any) { + driver: func(c *coder.Coder, dmw *chanWriter, swFn func() io.WriteCloser, ks, vs []any) { wc, kc, vc := extractCoders(c) for _, k := range ks { EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), dmw) @@ -155,6 +154,8 @@ func TestDataSource_Iterators(t *testing.T) { token := []byte(tokenString) coder.EncodeVarInt(int64(len(token)), dmw) // token. dmw.Write(token) + dmw.Flush() // Flush here to allow state IO from this goroutine. + // Each state stream needs to be a different writer, so get a new writer. sw := swFn() for _, v := range vs { @@ -170,6 +171,7 @@ func TestDataSource_Iterators(t *testing.T) { for _, singleIterate := range []bool{true, false} { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + fmt.Println(test.name) capture := &IteratorCaptureNode{CaptureNode: CaptureNode{UID: 1}} out := Node(capture) units := []Unit{out} @@ -187,8 +189,7 @@ func TestDataSource_Iterators(t *testing.T) { Out: out, } units = append(units, source) - dmr, dmw := io.Pipe() - + cw := makeChanWriter() // Simulate individual state channels with pipes and a channel. sRc := make(chan io.ReadCloser) swFn := func() io.WriteCloser { @@ -196,10 +197,10 @@ func TestDataSource_Iterators(t *testing.T) { sRc <- sr return sw } - go test.driver(source.Coder, dmw, swFn, test.keys, test.vals) + go test.driver(source.Coder, cw, swFn, test.keys, test.vals) constructAndExecutePlanWithContext(t, units, DataContext{ - Data: &TestDataManager{R: dmr}, + Data: &TestDataManager{Ch: cw.Ch}, State: &TestStateReader{Rc: sRc}, }) if len(capture.CapturedInputs) == 0 { @@ -240,7 +241,7 @@ func TestDataSource_Iterators(t *testing.T) { func TestDataSource_Split(t *testing.T) { elements := []any{int64(1), int64(2), int64(3), int64(4), int64(5)} - initSourceTest := func(name string) (*DataSource, *CaptureNode, io.ReadCloser) { + initSourceTest := func(name string) (*DataSource, *CaptureNode, chan Elements) { out := &CaptureNode{UID: 1} c := coder.NewW(coder.NewVarInt(), coder.NewGlobalWindow()) source := &DataSource{ @@ -250,7 +251,7 @@ func TestDataSource_Split(t *testing.T) { Coder: c, Out: out, } - pr, pw := io.Pipe() + cw := makeChanWriter() go func(c *coder.Coder, pw io.WriteCloser, elements []any) { wc := MakeWindowEncoder(c.Window) @@ -260,8 +261,8 @@ func TestDataSource_Split(t *testing.T) { ec.Encode(&FullValue{Elm: v}, pw) } pw.Close() - }(c, pw, elements) - return source, out, pr + }(c, cw, elements) + return source, out, cw.Ch } tests := []struct { @@ -289,12 +290,12 @@ func TestDataSource_Split(t *testing.T) { test.expected = elements[:test.splitIdx] } t.Run(test.name, func(t *testing.T) { - source, out, pr := initSourceTest(test.name) + source, out, ch := initSourceTest(test.name) p, err := NewPlan("a", []Unit{out, source}) if err != nil { t.Fatalf("failed to construct plan: %v", err) } - dc := DataContext{Data: &TestDataManager{R: pr}} + dc := DataContext{Data: &TestDataManager{Ch: ch}} ctx := context.Background() // StartBundle resets the source, so no splits can be actuated before then, @@ -358,7 +359,7 @@ func TestDataSource_Split(t *testing.T) { test.expected = elements[:test.splitIdx] } t.Run(test.name, func(t *testing.T) { - source, out, pr := 
initSourceTest(test.name) + source, out, ch := initSourceTest(test.name) unblockCh, blockedCh := make(chan struct{}), make(chan struct{}, 1) // Block on the one less than the desired split, // so the desired split is the first valid split. @@ -401,7 +402,7 @@ func TestDataSource_Split(t *testing.T) { }() constructAndExecutePlanWithContext(t, []Unit{out, blocker, source}, DataContext{ - Data: &TestDataManager{R: pr}, + Data: &TestDataManager{Ch: ch}, }) validateSource(t, out, source, makeValues(test.expected...)) @@ -427,12 +428,12 @@ func TestDataSource_Split(t *testing.T) { expected: elements[:3], } - source, out, pr := initSourceTest("bufSize") + source, out, ch := initSourceTest("bufSize") p, err := NewPlan("a", []Unit{out, source}) if err != nil { t.Fatalf("failed to construct plan: %v", err) } - dc := DataContext{Data: &TestDataManager{R: pr}} + dc := DataContext{Data: &TestDataManager{Ch: ch}} ctx := context.Background() // StartBundle resets the source, so no splits can be actuated before then, @@ -490,7 +491,7 @@ func TestDataSource_Split(t *testing.T) { test := test name := fmt.Sprintf("withFraction_%v", test.fraction) t.Run(name, func(t *testing.T) { - source, out, pr := initSourceTest(name) + source, out, ch := initSourceTest(name) unblockCh, blockedCh := make(chan struct{}), make(chan struct{}, 1) // Block on the one less than the desired split, // so the desired split is the first valid split. @@ -527,10 +528,10 @@ func TestDataSource_Split(t *testing.T) { t.Errorf("error in Split: got sub-element split = %t, want %t", isSubElm, test.isSubElm) } if isSubElm { - if got, want := splitRes.TId, testTransformId; got != want { + if got, want := splitRes.TId, testTransformID; got != want { t.Errorf("error in Split: got incorrect Transform Id = %v, want %v", got, want) } - if got, want := splitRes.InId, testInputId; got != want { + if got, want := splitRes.InId, testInputID; got != want { t.Errorf("error in Split: got incorrect Input Id = %v, want %v", got, want) } if _, ok := splitRes.OW["output1"]; !ok { @@ -558,7 +559,7 @@ func TestDataSource_Split(t *testing.T) { }() constructAndExecutePlanWithContext(t, []Unit{out, blocker, source}, DataContext{ - Data: &TestDataManager{R: pr}, + Data: &TestDataManager{Ch: ch}, }) validateSource(t, out, source, makeValues(elements[:test.splitIdx]...)) @@ -571,12 +572,12 @@ func TestDataSource_Split(t *testing.T) { // Test expects splitting errors, but for processing to be successful. t.Run("errors", func(t *testing.T) { - source, out, pr := initSourceTest("noSplitsUntilStarted") + source, out, ch := initSourceTest("noSplitsUntilStarted") p, err := NewPlan("a", []Unit{out, source}) if err != nil { t.Fatalf("failed to construct plan: %v", err) } - dc := DataContext{Data: &TestDataManager{R: pr}} + dc := DataContext{Data: &TestDataManager{Ch: ch}} ctx := context.Background() if sr, err := p.Split(ctx, SplitPoints{Splits: []int64{0, 3}, Frac: -1}); err != nil || !sr.Unsuccessful { @@ -620,8 +621,8 @@ func TestDataSource_Split(t *testing.T) { }) } -const testTransformId = "transform_id" -const testInputId = "input_id" +const testTransformID = "transform_id" +const testInputID = "input_id" // TestSplittableUnit is an implementation of the SplittableUnit interface // for DataSource tests. @@ -651,12 +652,12 @@ func (n *TestSplittableUnit) GetProgress() float64 { // GetTransformId returns a constant transform ID that can be tested for. 
func (n *TestSplittableUnit) GetTransformId() string { - return testTransformId + return testTransformID } // GetInputId returns a constant input ID that can be tested for. func (n *TestSplittableUnit) GetInputId() string { - return testInputId + return testInputID } // GetOutputWatermark gets the current output watermark of the splittable unit @@ -966,20 +967,21 @@ func TestCheckpointing(t *testing.T) { } enc := MakeElementEncoder(wvERSCoder) - var buf bytes.Buffer + cw := makeChanWriter() // We encode the element several times to ensure we don't // drop any residuals, the root of issue #24931. wantCount := 3 for i := 0; i < wantCount; i++ { - if err := enc.Encode(value, &buf); err != nil { + if err := enc.Encode(value, cw); err != nil { t.Fatalf("couldn't encode value: %v", err) } } + cw.Close() if err := root.StartBundle(ctx, "testBund", DataContext{ Data: &TestDataManager{ - R: io.NopCloser(&buf), + Ch: cw.Ch, }, }, ); err != nil { @@ -1017,17 +1019,44 @@ func runOnRoots(ctx context.Context, t *testing.T, p *Plan, name string, mthd fu } type TestDataManager struct { - R io.ReadCloser + Ch chan Elements } -func (dm *TestDataManager) OpenRead(ctx context.Context, id StreamID) (io.ReadCloser, error) { - return dm.R, nil +func (dm *TestDataManager) OpenElementChan(ctx context.Context, id StreamID, expectedTimerTransforms []string) (<-chan Elements, error) { + return dm.Ch, nil } func (dm *TestDataManager) OpenWrite(ctx context.Context, id StreamID) (io.WriteCloser, error) { return nil, nil } +func (dm *TestDataManager) OpenTimerWrite(ctx context.Context, id StreamID, family string) (io.WriteCloser, error) { + return nil, nil +} + +type chanWriter struct { + Ch chan Elements + Buf []byte +} + +func (cw *chanWriter) Write(p []byte) (int, error) { + cw.Buf = append(cw.Buf, p...) + return len(p), nil +} + +func (cw *chanWriter) Close() error { + cw.Flush() + close(cw.Ch) + return nil +} + +func (cw *chanWriter) Flush() { + cw.Ch <- Elements{Data: cw.Buf, PtransformID: "myPTransform"} + cw.Buf = nil +} + +func makeChanWriter() *chanWriter { return &chanWriter{Ch: make(chan Elements, 20)} } + // TestSideInputReader simulates state reads using channels. type TestStateReader struct { StateReader diff --git a/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go b/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go index 1fa3ae94c866..84c84a8d3164 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go @@ -75,10 +75,10 @@ func TestDynamicSplit(t *testing.T) { plan, out := createSdfPlan(t, t.Name(), dfn, cdr) // Create thread to send element to pipeline. - pr, pw := io.Pipe() + cw := makeChanWriter() elm := createElm() - go writeElm(elm, cdr, pw) - dc := DataContext{Data: &TestDataManager{R: pr}} + go writeElm(elm, cdr, cw) + dc := DataContext{Data: &TestDataManager{Ch: cw.Ch}} // Call driver to coordinate processing & splitting threads. 
splitRes, procRes := test.driver(context.Background(), plan, dc, sdf) @@ -92,7 +92,7 @@ func TestDynamicSplit(t *testing.T) { RI: 1, PS: nil, RS: nil, - TId: testTransformId, + TId: testTransformID, InId: indexToInputId(0), } if diff := cmp.Diff(splitRes.split, wantSplit, cmpopts.IgnoreFields(SplitResult{}, "PS", "RS")); diff != "" { @@ -263,7 +263,7 @@ func createSplitTestInCoder() *coder.Coder { func createSdfPlan(t *testing.T, name string, fn *graph.DoFn, cdr *coder.Coder) (*Plan, *CaptureNode) { out := &CaptureNode{UID: 0} n := &ParDo{UID: 1, Fn: fn, Out: []Node{out}} - sdf := &ProcessSizedElementsAndRestrictions{PDo: n, TfId: testTransformId} + sdf := &ProcessSizedElementsAndRestrictions{PDo: n, TfId: testTransformID} ds := &DataSource{ UID: 2, SID: StreamID{PtransformID: "DataSource"}, @@ -281,8 +281,8 @@ func createSdfPlan(t *testing.T, name string, fn *graph.DoFn, cdr *coder.Coder) } // writeElm is meant to be the goroutine for feeding an element to the -// DataSourc of the test pipeline. -func writeElm(elm *FullValue, cdr *coder.Coder, pw *io.PipeWriter) { +// DataSource of the test pipeline. +func writeElm(elm *FullValue, cdr *coder.Coder, pw io.WriteCloser) { wc := MakeWindowEncoder(cdr.Window) ec := MakeElementEncoder(coder.SkipW(cdr)) if err := EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), pw); err != nil { diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go index 5a9b536b2889..6d9b9c00452b 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go @@ -17,8 +17,10 @@ package harness import ( "context" + "fmt" "io" "sync" + "sync/atomic" "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" @@ -35,8 +37,9 @@ const ( // ScopedDataManager scopes the global gRPC data manager to a single instruction. // The indirection makes it easier to control access. type ScopedDataManager struct { - mgr *DataChannelManager - instID instructionID + mgr *DataChannelManager + instID instructionID + openPorts []exec.Port closed bool mu sync.Mutex @@ -47,22 +50,31 @@ func NewScopedDataManager(mgr *DataChannelManager, instID instructionID) *Scoped return &ScopedDataManager{mgr: mgr, instID: instID} } -// OpenRead opens an io.ReadCloser on the given stream. -func (s *ScopedDataManager) OpenRead(ctx context.Context, id exec.StreamID) (io.ReadCloser, error) { +// OpenWrite opens an io.WriteCloser on the given stream. +func (s *ScopedDataManager) OpenWrite(ctx context.Context, id exec.StreamID) (io.WriteCloser, error) { ch, err := s.open(ctx, id.Port) if err != nil { return nil, err } - return ch.OpenRead(ctx, id.PtransformID, s.instID), nil + return ch.OpenWrite(ctx, id.PtransformID, s.instID), nil } -// OpenWrite opens an io.WriteCloser on the given stream. -func (s *ScopedDataManager) OpenWrite(ctx context.Context, id exec.StreamID) (io.WriteCloser, error) { +// OpenElementChan returns a channel of exec.Elements on the given stream. 
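// Editorial note, not part of this patch: a consumer-side sketch. The channel
// is drained until the data channel's read loop closes it; data and timer
// payloads are distinguished by which Elements fields are populated. The
// timer-transform ID below is a placeholder.
//
//	elms, err := s.OpenElementChan(ctx, id, []string{"someTimerTransformID"})
//	if err != nil {
//		return err
//	}
//	for e := range elms {
//		if len(e.Timers) > 0 {
//			// decode timers for e.PtransformID / e.TimerFamilyID
//		}
//		if len(e.Data) > 0 {
//			// decode data elements for e.PtransformID
//		}
//	}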
+func (s *ScopedDataManager) OpenElementChan(ctx context.Context, id exec.StreamID, expectedTimerTransforms []string) (<-chan exec.Elements, error) { ch, err := s.open(ctx, id.Port) if err != nil { return nil, err } - return ch.OpenWrite(ctx, id.PtransformID, s.instID), nil + return ch.OpenElementChan(ctx, id.PtransformID, s.instID, expectedTimerTransforms) +} + +// OpenTimerWrite opens an io.WriteCloser on the given stream to write timers +func (s *ScopedDataManager) OpenTimerWrite(ctx context.Context, id exec.StreamID, family string) (io.WriteCloser, error) { + ch, err := s.open(ctx, id.Port) + if err != nil { + return nil, err + } + return ch.OpenTimerWrite(ctx, id.PtransformID, s.instID, family), nil } func (s *ScopedDataManager) open(ctx context.Context, port exec.Port) (*DataChannel, error) { @@ -71,6 +83,7 @@ func (s *ScopedDataManager) open(ctx context.Context, port exec.Port) (*DataChan s.mu.Unlock() return nil, errors.Errorf("instruction %v no longer processing", s.instID) } + s.openPorts = append(s.openPorts, port) local := s.mgr s.mu.Unlock() @@ -82,9 +95,9 @@ func (s *ScopedDataManager) Close() error { s.mu.Lock() defer s.mu.Unlock() s.closed = true - s.mgr.closeInstruction(s.instID) + err := s.mgr.closeInstruction(s.instID, s.openPorts) s.mgr = nil - return nil + return err } // DataChannelManager manages data channels over the Data API. A fixed number of channels @@ -124,28 +137,37 @@ func (m *DataChannelManager) Open(ctx context.Context, port exec.Port) (*DataCha return ch, nil } -func (m *DataChannelManager) closeInstruction(instID instructionID) { +func (m *DataChannelManager) closeInstruction(instID instructionID, ports []exec.Port) error { m.mu.Lock() defer m.mu.Unlock() - for _, ch := range m.ports { - ch.removeInstruction(instID) + var firstNonNilError error + for _, port := range ports { + ch, ok := m.ports[port.URL] + if !ok { + continue + } + err := ch.removeInstruction(instID) + if err != nil && firstNonNilError == nil { + firstNonNilError = err + } } + return firstNonNilError } // clientID identifies a client of a connected channel. type clientID struct { - ptransformID string instID instructionID + ptransformID string } // This is a reduced version of the full gRPC interface to help with testing. -// TODO(wcn): need a compile-time assertion to make sure this stays synced with what's -// in fnpb.BeamFnData_DataClient type dataClient interface { Send(*fnpb.Elements) error Recv() (*fnpb.Elements, error) } +var _ dataClient = (fnpb.BeamFnData_DataClient)(nil) // Assert our interfaces are compatible. + // DataChannel manages a single gRPC stream over the Data API. Data from // multiple bundles can be multiplexed over this stream. Data is pushed // over the channel, so data for a reader may arrive before the reader @@ -155,8 +177,9 @@ type DataChannel struct { id string client dataClient - writers map[instructionID]map[string]*dataWriter - readers map[instructionID]map[string]*dataReader + writers map[instructionID]map[string]*dataWriter // PTransformID + timerWriters map[instructionID]map[timerKey]*timerWriter + channels map[instructionID]*elementsChan // recently terminated instructions endedInstructions map[instructionID]struct{} @@ -172,6 +195,58 @@ type DataChannel struct { mu sync.Mutex // guards mutable internal data, notably the maps and readErr. } +type timerKey struct { + ptransformID, family string +} + +// elementsChan abstracts the management for this instruction's channel. 
+// +// The only runner signal that all data for an instruction has been received +// is when the ch channel has been closed. However, we may receive all data +// before the instruction begins consuming it, and there may be multiple PTransforms +// in the instruction that may need data through this channel. Until the instruction +// arrives, received data needs to be cached, and we cannot close the channel. +// +// The channel may only close if the want == got and want > 0. +// want is set once when the Source requests it. +// got is incremented only if we receive an IsLast signal for a given +// instruction/transform pair. +type elementsChan struct { + closed uint32 // Closed if != 0 + instID instructionID + + mu sync.Mutex + want, got int32 + + ch chan exec.Elements // must only be closed by the read loop + + done chan struct{} // Forces escape from a blocked write to allow channel close. +} + +// InstructionEnded signals the read loop to close the channel. +func (ec *elementsChan) InstructionEnded() { + close(ec.done) +} + +// Closed indicates if all expected streams are complete +func (ec *elementsChan) Closed() bool { + return atomic.LoadUint32(&ec.closed) != 0 +} + +// PTransformDone signals that a PTransform has no more data coming to it. +// If permitted, PTransformDone closes the channel. +func (ec *elementsChan) PTransformDone() { + ec.mu.Lock() + defer ec.mu.Unlock() + ec.got++ + if ec.want > 0 && ec.want == ec.got { + if !ec.Closed() { + atomic.StoreUint32(&ec.closed, 1) + close(ec.ch) + } + } +} + func newDataChannel(ctx context.Context, port exec.Port) (*DataChannel, error) { ctx, cancelFn := context.WithCancel(ctx) cc, err := dial(ctx, port.URL, "data", 15*time.Second) @@ -196,7 +271,8 @@ func makeDataChannel(ctx context.Context, id string, client dataClient, cancelFn id: id, client: client, writers: make(map[instructionID]map[string]*dataWriter), - readers: make(map[instructionID]map[string]*dataReader), + timerWriters: make(map[instructionID]map[timerKey]*timerWriter), + channels: make(map[instructionID]*elementsChan), endedInstructions: make(map[instructionID]struct{}), cancelFn: cancelFn, } @@ -214,25 +290,68 @@ func (c *DataChannel) terminateStreamOnError(err error) { } } -// OpenRead returns an io.ReadCloser of the data elements for the given instruction and ptransform. -func (c *DataChannel) OpenRead(ctx context.Context, ptransformID string, instID instructionID) io.ReadCloser { +// OpenWrite returns an io.WriteCloser of the data elements for the given instruction and ptransform. +func (c *DataChannel) OpenWrite(ctx context.Context, ptransformID string, instID instructionID) io.WriteCloser { + return c.makeWriter(ctx, clientID{ptransformID: ptransformID, instID: instID}) +} + +// OpenElementChan returns a channel of typex.Elements for the given instruction and ptransform. +func (c *DataChannel) OpenElementChan(ctx context.Context, ptransformID string, instID instructionID, expectedTimerTransforms []string) (<-chan exec.Elements, error) { c.mu.Lock() defer c.mu.Unlock() cid := clientID{ptransformID: ptransformID, instID: instID} if c.readErr != nil { - log.Errorf(ctx, "opening a reader %v on a closed channel", cid) - return &errReader{c.readErr} + return nil, fmt.Errorf("opening a reader %v on a closed channel. Original error: %w", cid, c.readErr) } - return c.makeReader(ctx, cid) + return c.makeChannel(true, cid, expectedTimerTransforms...).ch, nil } -// OpenWrite returns an io.WriteCloser of the data elements for the given instruction and ptransform. 
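// Editorial note, not part of this patch: the close accounting implemented by
// makeChannel and PTransformDone, in brief. want is fixed once by the Source
// (its own data stream plus one per expected timer transform), got counts
// IsLast signals, and only their equality permits closing ec.ch.
//
//	want := 1 + int32(len(expectedTimerTransforms)) // set once, when the Source opens the channel.
//	// For each IsLast seen for this instruction/transform pair:
//	got++
//	if want > 0 && want == got {
//		close(ec.ch) // all expected streams have finished.
//	}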
-func (c *DataChannel) OpenWrite(ctx context.Context, ptransformID string, instID instructionID) io.WriteCloser { - return c.makeWriter(ctx, clientID{ptransformID: ptransformID, instID: instID}) +// makeChannel creates a channel of exec.Elements. It expects to be called while c.mu is held. +func (c *DataChannel) makeChannel(fromSource bool, id clientID, additionalTransforms ...string) *elementsChan { + if ec, ok := c.channels[id.instID]; ok { + ec.mu.Lock() + defer ec.mu.Unlock() + if fromSource { + ec.want = (1 + int32(len(additionalTransforms))) + } + if _, ok := c.endedInstructions[id.instID]; ok || (ec.want > 0 && ec.want == ec.got) { + atomic.StoreUint32(&ec.closed, 1) + close(ec.ch) + } + return ec + } + + ec := &elementsChan{ + instID: id.instID, + ch: make(chan exec.Elements, 20), + done: make(chan struct{}), + } + if fromSource { + ec.want = 1 + int32(len(additionalTransforms)) + } + + // Just in case initial data for an instruction arrives *after* an instructon has ended. + // eg. it was blocked by another reader being slow, or the other instruction failed. + // So we provide a pre-completed reader, and do not cache it, as there's no further cleanup for it. + if _, ok := c.endedInstructions[id.instID]; ok { + // Since this is freshly created, we can set the close conditions immeadiately. + atomic.StoreUint32(&ec.closed, 1) + close(ec.ch) + return ec + } + + c.channels[id.instID] = ec + return ec +} + +// OpenTimerWrite returns io.WriteCloser for the given timerFamilyID, instruction and ptransform. +func (c *DataChannel) OpenTimerWrite(ctx context.Context, ptransformID string, instID instructionID, family string) io.WriteCloser { + return c.makeTimerWriter(ctx, clientID{ptransformID: ptransformID, instID: instID}, family) } func (c *DataChannel) read(ctx context.Context) { - cache := make(map[clientID]*dataReader) + cache := make(map[instructionID]*elementsChan) + seenLast := make([]clientID, 0, 5) for { msg, err := c.client.Recv() if err != nil { @@ -240,25 +359,19 @@ func (c *DataChannel) read(ctx context.Context) { c.mu.Lock() c.readErr = err // prevent not yet opened readers from hanging. // Readers must be closed from this goroutine, since we can't - // close the r.buf channels twice, or send on a closed channel. - // Any other approach is racy, and may cause one of the above - // panics. - for _, m := range c.readers { - for _, r := range m { - log.Errorf(ctx, "DataChannel.read %v reader %v closing due to error on channel", c.id, r.id) - if !r.completed { - r.completed = true - r.err = err - close(r.buf) - } - delete(cache, r.id) + // close the elementsChan channel twice, or send on those closed channels. + // Any other approach is racy, and may cause one of the above panics. + for instID, ec := range c.channels { + if !ec.Closed() { + atomic.StoreUint32(&ec.closed, 1) + close(ec.ch) } + delete(cache, instID) } c.terminateStreamOnError(err) c.mu.Unlock() if err == io.EOF { - log.Warnf(ctx, "DataChannel.read %v closed", c.id) return } log.Errorf(ctx, "DataChannel.read %v bad: %v", c.id, err) @@ -270,120 +383,87 @@ func (c *DataChannel) read(ctx context.Context) { // Each message may contain segments for multiple streams, so we // must treat each segment in isolation. We maintain a local cache // to reduce lock contention. 
- - for _, elm := range msg.GetData() { - id := clientID{ptransformID: elm.TransformId, instID: instructionID(elm.GetInstructionId())} - - var r *dataReader - if local, ok := cache[id]; ok { - r = local - } else { - c.mu.Lock() - r = c.makeReader(ctx, id) - c.mu.Unlock() - cache[id] = r - } - - if elm.GetIsLast() { - // If this reader hasn't closed yet, do so now. - if !r.completed { - // Use the last segment if any. - if len(elm.GetData()) != 0 { - // In case of local side closing, send with select. - select { - case r.buf <- elm.GetData(): - case <-r.done: - } - } - // Close buffer to signal EOF. - r.completed = true - close(r.buf) + iterateElements(c, cache, &seenLast, msg.GetTimers(), + func(elm *fnpb.Elements_Timers) exec.Elements { + return exec.Elements{Timers: elm.GetTimers(), PtransformID: elm.GetTransformId(), TimerFamilyID: elm.GetTimerFamilyId()} + }) + + iterateElements(c, cache, &seenLast, msg.GetData(), + func(elm *fnpb.Elements_Data) exec.Elements { + return exec.Elements{Data: elm.GetData(), PtransformID: elm.GetTransformId()} + }) + + // Mark all readers that we've seen the last of as done, after queuing their elements. + if len(seenLast) > 0 { + c.mu.Lock() + for _, id := range seenLast { + r, ok := cache[id.instID] + if !ok { + continue // we've already closed this cached reader, skip + } + r.PTransformDone() + if r.Closed() { + // Clean up local bookkeeping. We'll never see another message + // for it again. We have to be careful not to remove the real + // one, because readers may be initialized after we've seen + // the full stream. + delete(cache, id.instID) } - - // Clean up local bookkeeping. We'll never see another message - // for it again. We have to be careful not to remove the real - // one, because readers may be initialized after we've seen - // the full stream. - delete(cache, id) - continue - } - - if r.completed { - // The local reader has closed but the remote is still sending data. - // Just ignore it. We keep the reader config in the cache so we don't - // treat it as a new reader. Eventually the stream will finish and go - // through normal teardown. - continue - } - - // This send is deliberately blocking, if we exceed the buffering for - // a reader. We can't buffer the entire main input, if some user code - // is slow (or gets stuck). If the local side closes, the reader - // will be marked as completed and further remote data will be ignored. - select { - case r.buf <- elm.GetData(): - case <-r.done: - r.completed = true - close(r.buf) } + seenLast = seenLast[:0] // reset for re-use + c.mu.Unlock() } } } -type errReader struct { - err error +// dataEle is a light interface against the proto Data and Timer Elements. +type dataEle interface { + GetTransformId() string + GetInstructionId() string + GetIsLast() bool } -func (r *errReader) Read(_ []byte) (int, error) { - return 0, r.err -} - -func (r *errReader) Close() error { - return r.err -} - -// makeReader creates a dataReader. It expects to be called while c.mu is held. 
-func (c *DataChannel) makeReader(ctx context.Context, id clientID) *dataReader { - var m map[string]*dataReader - var ok bool - if m, ok = c.readers[id.instID]; !ok { - m = make(map[string]*dataReader) - c.readers[id.instID] = m - } - - if r, ok := m[id.ptransformID]; ok { - return r - } - - r := &dataReader{id: id, buf: make(chan []byte, bufElements), done: make(chan bool, 1), channel: c} +func iterateElements[E dataEle](c *DataChannel, cache map[instructionID]*elementsChan, seenLast *[]clientID, elms []E, wrap func(E) exec.Elements) { + for _, elm := range elms { + id := clientID{ptransformID: elm.GetTransformId(), instID: instructionID(elm.GetInstructionId())} - // Just in case initial data for an instruction arrives *after* an instructon has ended. - // eg. it was blocked by another reader being slow, or the other instruction failed. - // So we provide a pre-completed reader, and do not cache it, as there's no further cleanup for it. - if _, ok := c.endedInstructions[id.instID]; ok { - r.completed = true - close(r.buf) - r.err = io.EOF // In case of any actual data readers, so they terminate without error. - return r - } + var ec *elementsChan + if local, ok := cache[id.instID]; ok { + ec = local + } else { + c.mu.Lock() + ec = c.makeChannel(false, id) + c.mu.Unlock() + cache[id.instID] = ec + } - m[id.ptransformID] = r - return r -} + if ec.Closed() { + continue + } -func (c *DataChannel) removeReader(id clientID) { - c.mu.Lock() - if m, ok := c.readers[id.instID]; ok { - delete(m, id.ptransformID) + // This send deliberately blocks if we exceed the buffering for + // a reader. We can't buffer the entire main input, if some user code + // is slow (or gets stuck). If the local side closes, the reader + // will be marked as completed and further remote data will be ignored. + select { + case ec.ch <- wrap(elm): + case <-ec.done: // In case of out of band cancels. + ec.mu.Lock() + atomic.StoreUint32(&ec.closed, 1) + close(ec.ch) + ec.mu.Unlock() + } + if elm.GetIsLast() { + *seenLast = append(*seenLast, id) + } } - c.mu.Unlock() } const endedInstructionCap = 32 // removeInstruction closes all readers and writers registered for the instruction // and deletes this instruction from the channel's reader and writer maps. -func (c *DataChannel) removeInstruction(instID instructionID) { +func (c *DataChannel) removeInstruction(instID instructionID) error { c.mu.Lock() // We don't want to leak memory, so cap the endedInstructions list. @@ -395,21 +475,29 @@ func (c *DataChannel) removeInstruction(instID instructionID) { c.endedInstructions[instID] = struct{}{} c.rmQueue = append(c.rmQueue, instID) - rs := c.readers[instID] ws := c.writers[instID] + tws := c.timerWriters[instID] + ec := c.channels[instID] // Prevent other users while we iterate. - delete(c.readers, instID) delete(c.writers, instID) + delete(c.timerWriters, instID) + delete(c.channels, instID) + + // Return readErr to defend against data loss via short reads. + err := c.readErr c.mu.Unlock() - // Close grabs the channel lock, so this must be outside the critical section. 
- for _, r := range rs { - r.Close() - } for _, w := range ws { w.Close() } + for _, tw := range tws { + tw.Close() + } + if ec != nil { + ec.InstructionEnded() + } + return err } func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { @@ -436,48 +524,6 @@ func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { return w } -type dataReader struct { - id clientID - buf chan []byte - done chan bool - cur []byte - channel *DataChannel - completed bool - err error -} - -func (r *dataReader) Close() error { - r.done <- true - r.channel.removeReader(r.id) - return nil -} - -func (r *dataReader) Read(buf []byte) (int, error) { - if r.cur == nil { - b, ok := <-r.buf - if !ok { - if r.err == nil { - return 0, io.EOF - } - return 0, r.err - } - r.cur = b - } - - // We don't need to check for a 0 length copy from r.cur here, since that's - // checked before buffers are handed to the r.buf channel. - n := copy(buf, r.cur) - - switch { - case len(r.cur) == n: - r.cur = nil - default: - r.cur = r.cur[n:] - } - - return n, nil -} - type dataWriter struct { buf []byte @@ -574,3 +620,99 @@ func (w *dataWriter) Write(p []byte) (n int, err error) { w.buf = append(w.buf, p...) return len(p), nil } + +func (c *DataChannel) makeTimerWriter(ctx context.Context, id clientID, family string) *timerWriter { + c.mu.Lock() + defer c.mu.Unlock() + + var m map[timerKey]*timerWriter + var ok bool + if m, ok = c.timerWriters[id.instID]; !ok { + m = make(map[timerKey]*timerWriter) + c.timerWriters[id.instID] = m + } + tk := timerKey{ptransformID: id.ptransformID, family: family} + if w, ok := m[tk]; ok { + return w + } + + // We don't check for finished instructions for writers, as writers + // can only be created if an instruction is in scope, and aren't + // runner or user directed. + + w := &timerWriter{ch: c, id: id, timerFamilyID: family} + m[tk] = w + return w +} + +type timerWriter struct { + id clientID + timerFamilyID string + ch *DataChannel +} + +// send requires the ch.mu lock to be held. +func (w *timerWriter) send(msg *fnpb.Elements) error { + recordStreamSend(msg) + if err := w.ch.client.Send(msg); err != nil { + if err == io.EOF { + log.Warnf(context.TODO(), "timerWriter[%v;%v] EOF on send; fetching real error", w.id, w.ch.id) + err = nil + for err == nil { + // Per GRPC stream documentation, if there's an EOF, we must call Recv + // until a non-nil error is returned, to ensure resources are cleaned up. 
+ // https://pkg.go.dev/google.golang.org/grpc#ClientConn.NewStream + _, err = w.ch.client.Recv() + } + } + log.Warnf(context.TODO(), "timerWriter[%v;%v] error on send: %v", w.id, w.ch.id, err) + w.ch.terminateStreamOnError(err) + return err + } + return nil +} + +func (w *timerWriter) Close() error { + w.ch.mu.Lock() + defer w.ch.mu.Unlock() + delete(w.ch.timerWriters[w.id.instID], timerKey{w.id.ptransformID, w.timerFamilyID}) + var msg *fnpb.Elements + msg = &fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + { + InstructionId: string(w.id.instID), + TransformId: w.id.ptransformID, + TimerFamilyId: w.timerFamilyID, + IsLast: true, + }, + }, + } + return w.send(msg) +} + +func (w *timerWriter) writeTimers(p []byte) error { + w.ch.mu.Lock() + defer w.ch.mu.Unlock() + + log.Infof(context.TODO(), "DEBUGLOG: timer write for %+v: %v", w.id, p) + + msg := &fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + { + InstructionId: string(w.id.instID), + TransformId: w.id.ptransformID, + TimerFamilyId: w.timerFamilyID, + Timers: p, + }, + }, + } + return w.send(msg) +} + +func (w *timerWriter) Write(p []byte) (n int, err error) { + // write timers directly without buffering. + if err := w.writeTimers(p); err != nil { + return 0, err + } + return len(p), nil +} diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go index f69d9abde49b..c7f8ac5858c1 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go @@ -21,11 +21,13 @@ import ( "fmt" "io" "log" + "runtime" "strings" "sync" "testing" "time" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" ) @@ -101,103 +103,349 @@ func (f *fakeDataClient) Send(*fnpb.Elements) error { return nil } -func TestDataChannelTerminate_dataReader(t *testing.T) { - // The logging of channels closed is quite noisy for this test - log.SetOutput(io.Discard) +type fakeChanClient struct { + ch chan *fnpb.Elements + err error +} - expectedError := fmt.Errorf("EXPECTED ERROR") +func (f *fakeChanClient) Recv() (*fnpb.Elements, error) { + e, ok := <-f.ch + if !ok { + return nil, f.err + } + return e, nil +} + +func (f *fakeChanClient) Send(e *fnpb.Elements) error { + f.ch <- e + return nil +} + +func (f *fakeChanClient) Close() { + f.err = io.EOF + close(f.ch) +} + +func (f *fakeChanClient) CloseWith(err error) { + f.err = err + close(f.ch) +} + +func TestElementChan(t *testing.T) { + const instID = "inst_ref" + dataID := "dataTransform" + timerID := "timerTransform" + timerFamily := "timerFamily" + setupClient := func(t *testing.T) (context.Context, *fakeChanClient, *DataChannel) { + t.Helper() + client := &fakeChanClient{ch: make(chan *fnpb.Elements, bufElements)} + ctx, cancelFn := context.WithCancel(context.Background()) + t.Cleanup(cancelFn) + t.Cleanup(func() { client.Close() }) + + c := makeDataChannel(ctx, "id", client, cancelFn) + return ctx, client, c + } + drainAndSum := func(t *testing.T, elms <-chan exec.Elements) (sum, count int) { + t.Helper() + for e := range elms { // only exits if data channel is closed. 
+ if len(e.Data) != 0 { + sum += int(e.Data[0]) + count++ + } + if len(e.Timers) != 0 { + if e.TimerFamilyID != timerFamily { + t.Errorf("timer received without family set: %v, state= sum %v, count %v", e, sum, count) + } + sum += int(e.Timers[0]) + count++ + } + } + return sum, count + } + timerElm := func(val byte, isLast bool) *fnpb.Elements_Timers { + return &fnpb.Elements_Timers{InstructionId: instID, TransformId: timerID, Timers: []byte{val}, IsLast: isLast, TimerFamilyId: timerFamily} + } + dataElm := func(val byte, isLast bool) *fnpb.Elements_Data { + return &fnpb.Elements_Data{InstructionId: instID, TransformId: dataID, Data: []byte{val}, IsLast: isLast} + } + noTimerElm := func() *fnpb.Elements_Timers { + return &fnpb.Elements_Timers{InstructionId: instID, TransformId: timerID, Timers: []byte{}, IsLast: true} + } + noDataElm := func() *fnpb.Elements_Data { + return &fnpb.Elements_Data{InstructionId: instID, TransformId: dataID, Data: []byte{}, IsLast: true} + } + openChan := func(ctx context.Context, t *testing.T, c *DataChannel, timers ...string) <-chan exec.Elements { + t.Helper() + runtime.Gosched() // Encourage the "read" goroutine to schedule before this call, if necessary. + elms, err := c.OpenElementChan(ctx, dataID, instID, timers) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + return elms + } + + // Most Cases tests := []struct { - name string - expectedError error - caseFn func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) + name string + sequenceFn func(context.Context, *testing.T, *fakeChanClient, *DataChannel) <-chan exec.Elements + wantSum, wantCount int }{ { - name: "onClose", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // We don't read up all the buffered data, but immediately close the reader. - // Previously, since nothing was consuming the incoming gRPC data, the whole - // data channel would get stuck, and the client.Recv() call was eventually - // no longer called. - r.Close() - - // If done is signaled, that means client.Recv() has been called to flush the - // channel, meaning consumer code isn't stuck. 
- <-client.done + name: "ReadThenData_singleRecv", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + elms := openChan(ctx, t, c) + client.Send(&fnpb.Elements{ + Data: []*fnpb.Elements_Data{ + dataElm(1, false), + dataElm(2, false), + dataElm(3, true), + }, + }) + return elms + }, + wantSum: 6, wantCount: 3, + }, { + name: "ReadThenData_multipleRecv", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + elms := openChan(ctx, t, c) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(2, false)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + return elms + }, + wantSum: 6, wantCount: 3, + }, { + name: "ReadThenNoData", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + elms := openChan(ctx, t, c) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + return elms + }, + wantSum: 0, wantCount: 0, + }, { + name: "NoDataThenRead", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + elms := openChan(ctx, t, c) + return elms + }, + wantSum: 0, wantCount: 0, + }, { + name: "NoDataInstEndsThenRead", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + c.removeInstruction(instID) + elms := openChan(ctx, t, c) + return elms + }, + wantSum: 0, wantCount: 0, + }, { + name: "ReadThenDataAndTimers", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + elms := openChan(ctx, t, c, timerID) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{timerElm(2, true)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + return elms + }, + wantSum: 6, wantCount: 3, + }, { + name: "AllDataAndTimersThenRead", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{timerElm(2, true)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + elms := openChan(ctx, t, c, timerID) + return elms }, + wantSum: 6, wantCount: 3, }, { - name: "onSentinel", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // fakeDataClient eventually returns a sentinel element. 
+ name: "FillBufferThenAbortThenRead", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + for i := 0; i < bufElements+2; i++ { + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) + } + elms := openChan(ctx, t, c, timerID) + c.removeInstruction(instID) + + // These will be ignored + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(2, false)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + return elms + }, + wantSum: bufElements, wantCount: bufElements, + }, { + name: "DataThenReaderThenLast", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{ + Data: []*fnpb.Elements_Data{ + dataElm(1, false), + dataElm(2, false), + }, + }) + elms := openChan(ctx, t, c) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + return elms }, + wantSum: 6, wantCount: 3, }, { - name: "onIsLast_withData", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // Set the last call with data to use is_last. - client.isLastCall = 2 + name: "PartialTimersAllDataReadThenLastTimer", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + timerElm(1, false), + timerElm(2, false), + }, + }) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + + elms := openChan(ctx, t, c, timerID) + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{timerElm(3, true)}}) + + return elms }, + wantSum: 6, wantCount: 3, }, { - name: "onIsLast_withoutData", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // Set the call without data to use is_last. - client.isLastCall = 3 + name: "AllTimerThenReaderThenDataClose", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + timerElm(1, false), + timerElm(2, false), + timerElm(3, true), + }, + }) + + elms := openChan(ctx, t, c, timerID) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + + return elms }, + wantSum: 6, wantCount: 3, }, { - name: "onRecvError", - expectedError: expectedError, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // The SDK starts reading in a goroutine immeadiately after open. - // Set the 2nd Recv call to have an error. 
- client.err = expectedError + name: "NoTimersThenReaderThenNoData", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{noTimerElm()}}) + elms := openChan(ctx, t, c, timerID) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + return elms }, + wantSum: 0, wantCount: 0, }, { - name: "onInstructionEnd", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - c.removeInstruction("inst_ref") + name: "SomeTimersThenReaderThenAData", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{timerElm(1, false), timerElm(2, true)}}) + elms := openChan(ctx, t, c, timerID) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + return elms }, + wantSum: 6, wantCount: 3, + }, { + name: "SomeTimersAndADataThenReader", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{timerElm(1, false), timerElm(2, true)}, + Data: []*fnpb.Elements_Data{dataElm(3, true)}, + }) + elms := openChan(ctx, t, c, timerID) + return elms + }, + wantSum: 6, wantCount: 3, + }, { + name: "PartialReadThenEndInstruction", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{ + Data: []*fnpb.Elements_Data{ + dataElm(1, false), + dataElm(2, false), + }, + }) + elms := openChan(ctx, t, c) + var sum int + e := <-elms + sum += int(e.Data[0]) + e = <-elms + sum += int(e.Data[0]) + + if got, want := sum, 3; got != want { + t.Errorf("got sum %v, want sum %v", got, want) + } + + // Simulate a split, where the remaining buffer wouldn't be read further, and the instruction ends. + c.removeInstruction(instID) + + // Instruction is ended, so further data for this instruction is ignored. + client.Send(&fnpb.Elements{ + Data: []*fnpb.Elements_Data{ + dataElm(3, false), + dataElm(4, true), + }, + }) + + elms = openChan(ctx, t, c) + return elms + }, + wantSum: 0, wantCount: 0, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - done := make(chan bool, 1) - client := &fakeDataClient{t: t, done: done} + ctx, client, c := setupClient(t) + elms := test.sequenceFn(ctx, t, client, c) + sum, count := drainAndSum(t, elms) + if wantSum, wantCount := test.wantSum, test.wantCount; sum != wantSum || count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantCount) + } + }) + } +} + +func BenchmarkElementChan(b *testing.B) { + benches := []struct { + size int + }{ + {1}, + {10}, + {100}, + {1000}, + {10000}, + } + + for _, bench := range benches { + b.Run(fmt.Sprintf("batchSize:%v", bench.size), func(b *testing.B) { + client := &fakeChanClient{ch: make(chan *fnpb.Elements, bufElements)} ctx, cancelFn := context.WithCancel(context.Background()) c := makeDataChannel(ctx, "id", client, cancelFn) - r := c.OpenRead(ctx, "ptr", "inst_ref") - - n, err := r.Read(make([]byte, 4)) + const instID = "inst_ref" + dataID := "dataTransform" + elms, err := c.OpenElementChan(ctx, dataID, instID, nil) if err != nil { - t.Errorf("Unexpected error from read: %v, read %d bytes.", err, n) - } - test.caseFn(t, r, client, c) - // Drain the reader. 
- i := 1 // For the earlier Read. - for err == nil { - read := make([]byte, 4) - _, err = r.Read(read) - i++ - } - - if got, want := err, test.expectedError; got != want { - t.Errorf("Unexpected error from read %d: got %v, want %v", i, got, want) + b.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) } - // Verify that new readers return the same error on their reads after client.Recv is done. - if n, err := c.OpenRead(ctx, "ptr", "inst_ref").Read(make([]byte, 4)); err != test.expectedError { - t.Errorf("Unexpected error from read: got %v, want, %v read %d bytes.", err, test.expectedError, n) + e := &fnpb.Elements_Data{InstructionId: instID, TransformId: dataID, Data: []byte{1}, IsLast: false} + es := make([]*fnpb.Elements_Data, 0, bench.size) + for i := 0; i < bench.size; i++ { + es = append(es, e) } - - select { - case <-ctx.Done(): // Assert that the context must have been cancelled on read failures. - return - case <-time.After(time.Second * 5): - t.Fatal("context wasn't cancelled") + batch := &fnpb.Elements{Data: es} + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for range elms { + } + }() + // Batch elements sizes. + for i := 0; i < b.N; i += bench.size { + client.Send(batch) } + client.Close() + // Wait until we've consumed all sent batches. + wg.Wait() }) } } @@ -213,16 +461,9 @@ func TestDataChannelRemoveInstruction_dataAfterClose(t *testing.T) { client.blocked.Unlock() - r := c.OpenRead(ctx, "ptr", "inst_ref") - - dr := r.(*dataReader) - if !dr.completed || dr.err != io.EOF { - t.Errorf("Expected a closed reader, but was still open: completed: %v, err: %v", dr.completed, dr.err) - } - - n, err := r.Read(make([]byte, 4)) - if err != io.EOF { - t.Errorf("Unexpected error from read: %v, read %d bytes.", err, n) + _, err := c.OpenElementChan(ctx, "ptr", "inst_ref", nil) + if err != nil { + t.Errorf("Unexpected error from read: %v,", err) } } @@ -234,7 +475,7 @@ func TestDataChannelRemoveInstruction_limitInstructionCap(t *testing.T) { for i := 0; i < endedInstructionCap+10; i++ { instID := instructionID(fmt.Sprintf("inst_ref%d", i)) - c.OpenRead(ctx, "ptr", instID) + c.OpenElementChan(ctx, "ptr", instID, nil) c.removeInstruction(instID) } if got, want := len(c.endedInstructions), endedInstructionCap; got != want {