diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go index c60a8bf2a3f5..2d3425af33c6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go @@ -178,7 +178,7 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb ckvERSID: coder(urns.CoderKV, ckvERID, cSID), } - // PCollections only have two new ones. + // There are only two new PCollections. // INPUT -> same as ordinary DoFn // PWR, uses ckvER // SPLITnSIZED, uses ckvERS @@ -201,7 +201,7 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb nSPLITnSIZEDID: pcol(nSPLITnSIZEDID, ckvERSID), } - // PTransforms have 3 new ones, with process sized elements and restrictions + // There are 3 new PTransforms, with process sized elements and restrictions // taking the brunt of the complexity, consuming the inputs ePWRID := "e" + tid + "_pwr" @@ -209,15 +209,19 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb eProcessID := "e" + tid + "_processandsplit" tform := func(name, urn, in, out string) *pipepb.PTransform { + // Apparently we also send side inputs to PairWithRestriction + // and SplitAndSize. We should consider wether we could simply + // drop the side inputs from the ParDo payload instead, which + // could lead to an additional fusion oppportunity. + newInputs := maps.Clone(t.GetInputs()) + newInputs[inputLocalID] = in return &pipepb.PTransform{ UniqueName: name, Spec: &pipepb.FunctionSpec{ Urn: urn, Payload: pardoPayload, }, - Inputs: map[string]string{ - inputLocalID: in, - }, + Inputs: newInputs, Outputs: map[string]string{ "i0": out, }, diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index 95f6af18ac74..ed7f168e36ee 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -18,6 +18,7 @@ package internal import ( "fmt" "sort" + "strings" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" @@ -438,7 +439,8 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa t := comps.GetTransforms()[link.Transform] var sis map[string]*pipepb.SideInput - if t.GetSpec().GetUrn() == urns.TransformParDo { + switch t.GetSpec().GetUrn() { + case urns.TransformParDo, urns.TransformProcessSizedElements, urns.TransformPairWithRestriction, urns.TransformSplitAndSize, urns.TransformTruncate: pardo := &pipepb.ParDoPayload{} if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil { return fmt.Errorf("unable to decode ParDoPayload for %v", link.Transform) @@ -485,7 +487,17 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa // Quick check that this is lead by a flatten node, and that it's handled runner side. t := comps.GetTransforms()[stg.transforms[0]] if !(t.GetSpec().GetUrn() == urns.TransformFlatten && t.GetEnvironmentId() == "") { - return fmt.Errorf("expected runner flatten node, but wasn't: %v -- %v", stg.transforms, mainInputs) + formatMap := func(in map[string]string) string { + var b strings.Builder + for k, v := range in { + b.WriteString(k) + b.WriteString(" : ") + b.WriteString(v) + b.WriteString("\n\t") + } + return b.String() + } + return fmt.Errorf("stage requires multiple parallel inputs but wasn't a flatten:\n\ttransforms\n\t%v\n\tmain inputs\n\t%v\n\tsidinputs\n\t%v", strings.Join(stg.transforms, "\n\t\t"), formatMap(mainInputs), sideInputs) } } return nil diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index da374c96cafe..d4abed293534 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -275,7 +275,11 @@ progress: } func getSideInputs(t *pipepb.PTransform) (map[string]*pipepb.SideInput, error) { - if t.GetSpec().GetUrn() != urns.TransformParDo { + switch t.GetSpec().GetUrn() { + case urns.TransformParDo, urns.TransformProcessSizedElements, urns.TransformPairWithRestriction, urns.TransformSplitAndSize, urns.TransformTruncate: + // Intentionally empty since these are permitted to have side inputs. + default: + // Nothing else is allowed to have side inputs. return nil, nil } // TODO, memoize this, so we don't need to repeatedly unmarshal. @@ -334,6 +338,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng return col } + // Update coders for Stateful transforms. for _, tid := range stg.transforms { t := comps.GetTransforms()[tid] @@ -461,10 +466,11 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng } // Update side inputs to point to new PCollection with any replaced coders. transforms[si.Transform].GetInputs()[si.Local] = newGlobal + // TODO: replace si.Global with newGlobal? } - prepSide, err := handleSideInput(si, comps, coders, em) + prepSide, err := handleSideInput(si, comps, transforms, pcollections, coders, em) if err != nil { - slog.Error("buildDescriptor: handleSideInputs", err, slog.String("transformID", si.Transform)) + slog.Error("buildDescriptor: handleSideInputs", "error", err, slog.String("transformID", si.Transform)) return err } prepareSides = append(prepareSides, prepSide) @@ -556,8 +562,8 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng } // handleSideInput returns a closure that will look up the data for a side input appropriate for the given watermark. -func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[string]*pipepb.Coder, em *engine.ElementManager) (func(b *worker.B, watermark mtime.Time), error) { - t := comps.GetTransforms()[link.Transform] +func handleSideInput(link engine.LinkID, comps *pipepb.Components, transforms map[string]*pipepb.PTransform, pcols map[string]*pipepb.PCollection, coders map[string]*pipepb.Coder, em *engine.ElementManager) (func(b *worker.B, watermark mtime.Time), error) { + t := transforms[link.Transform] sis, err := getSideInputs(t) if err != nil { return nil, err @@ -570,7 +576,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st slog.String("local", link.Local), slog.String("global", link.Global)) - col := comps.GetPcollections()[link.Global] + col := pcols[link.Global] // The returned coders are unused here, but they add the side input coders // to the stage components for use SDK side. @@ -594,7 +600,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st slog.String("sourceTransform", t.GetUniqueName()), slog.String("local", link.Local), slog.String("global", link.Global)) - col := comps.GetPcollections()[link.Global] + col := pcols[link.Global] kvc := comps.GetCoders()[col.GetCoderId()] if kvc.GetSpec().GetUrn() != urns.CoderKV { @@ -633,7 +639,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st }] = windowed }, nil default: - return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", link.Local, link.Global, si.GetAccessPattern().GetUrn()) + return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", link.Local, link.Global, prototext.Format(si.GetAccessPattern())) } } diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index d25c173e8c2f..f9ec03793488 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -452,7 +452,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { // TODO: move data handling to be pcollection based. key := req.GetStateKey() - slog.Debug("StateRequest_Get", prototext.Format(req), "bundle", b) + slog.Debug("StateRequest_Get", "request", prototext.Format(req), "bundle", b) var data [][]byte switch key.GetType().(type) { case *fnpb.StateKey_IterableSideInput_: