Skip to content

Commit

Permalink
[apache#31992][prism] Send side inputs for all SDF phases. (apache#32042
Browse files Browse the repository at this point in the history
)

* [apache#31992] Send side inputs for all SDF phases.

* delint.

---------

Co-authored-by: lostluck <[email protected]>
  • Loading branch information
2 people authored and reeba212 committed Dec 4, 2024
1 parent cc017ba commit f6bfe45
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 16 deletions.
14 changes: 9 additions & 5 deletions sdks/go/pkg/beam/runners/prism/internal/handlepardo.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb
ckvERSID: coder(urns.CoderKV, ckvERID, cSID),
}

// PCollections only have two new ones.
// There are only two new PCollections.
// INPUT -> same as ordinary DoFn
// PWR, uses ckvER
// SPLITnSIZED, uses ckvERS
Expand All @@ -201,23 +201,27 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb
nSPLITnSIZEDID: pcol(nSPLITnSIZEDID, ckvERSID),
}

// PTransforms have 3 new ones, with process sized elements and restrictions
// There are 3 new PTransforms, with process sized elements and restrictions
// taking the brunt of the complexity, consuming the inputs

ePWRID := "e" + tid + "_pwr"
eSPLITnSIZEDID := "e" + tid + "_splitnsize"
eProcessID := "e" + tid + "_processandsplit"

tform := func(name, urn, in, out string) *pipepb.PTransform {
// Apparently we also send side inputs to PairWithRestriction
// and SplitAndSize. We should consider wether we could simply
// drop the side inputs from the ParDo payload instead, which
// could lead to an additional fusion oppportunity.
newInputs := maps.Clone(t.GetInputs())
newInputs[inputLocalID] = in
return &pipepb.PTransform{
UniqueName: name,
Spec: &pipepb.FunctionSpec{
Urn: urn,
Payload: pardoPayload,
},
Inputs: map[string]string{
inputLocalID: in,
},
Inputs: newInputs,
Outputs: map[string]string{
"i0": out,
},
Expand Down
16 changes: 14 additions & 2 deletions sdks/go/pkg/beam/runners/prism/internal/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package internal
import (
"fmt"
"sort"
"strings"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
Expand Down Expand Up @@ -438,7 +439,8 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa
t := comps.GetTransforms()[link.Transform]

var sis map[string]*pipepb.SideInput
if t.GetSpec().GetUrn() == urns.TransformParDo {
switch t.GetSpec().GetUrn() {
case urns.TransformParDo, urns.TransformProcessSizedElements, urns.TransformPairWithRestriction, urns.TransformSplitAndSize, urns.TransformTruncate:
pardo := &pipepb.ParDoPayload{}
if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil {
return fmt.Errorf("unable to decode ParDoPayload for %v", link.Transform)
Expand Down Expand Up @@ -485,7 +487,17 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa
// Quick check that this is lead by a flatten node, and that it's handled runner side.
t := comps.GetTransforms()[stg.transforms[0]]
if !(t.GetSpec().GetUrn() == urns.TransformFlatten && t.GetEnvironmentId() == "") {
return fmt.Errorf("expected runner flatten node, but wasn't: %v -- %v", stg.transforms, mainInputs)
formatMap := func(in map[string]string) string {
var b strings.Builder
for k, v := range in {
b.WriteString(k)
b.WriteString(" : ")
b.WriteString(v)
b.WriteString("\n\t")
}
return b.String()
}
return fmt.Errorf("stage requires multiple parallel inputs but wasn't a flatten:\n\ttransforms\n\t%v\n\tmain inputs\n\t%v\n\tsidinputs\n\t%v", strings.Join(stg.transforms, "\n\t\t"), formatMap(mainInputs), sideInputs)
}
}
return nil
Expand Down
22 changes: 14 additions & 8 deletions sdks/go/pkg/beam/runners/prism/internal/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,11 @@ progress:
}

func getSideInputs(t *pipepb.PTransform) (map[string]*pipepb.SideInput, error) {
if t.GetSpec().GetUrn() != urns.TransformParDo {
switch t.GetSpec().GetUrn() {
case urns.TransformParDo, urns.TransformProcessSizedElements, urns.TransformPairWithRestriction, urns.TransformSplitAndSize, urns.TransformTruncate:
// Intentionally empty since these are permitted to have side inputs.
default:
// Nothing else is allowed to have side inputs.
return nil, nil
}
// TODO, memoize this, so we don't need to repeatedly unmarshal.
Expand Down Expand Up @@ -334,6 +338,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
return col
}

// Update coders for Stateful transforms.
for _, tid := range stg.transforms {
t := comps.GetTransforms()[tid]

Expand Down Expand Up @@ -461,10 +466,11 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
}
// Update side inputs to point to new PCollection with any replaced coders.
transforms[si.Transform].GetInputs()[si.Local] = newGlobal
// TODO: replace si.Global with newGlobal?
}
prepSide, err := handleSideInput(si, comps, coders, em)
prepSide, err := handleSideInput(si, comps, transforms, pcollections, coders, em)
if err != nil {
slog.Error("buildDescriptor: handleSideInputs", err, slog.String("transformID", si.Transform))
slog.Error("buildDescriptor: handleSideInputs", "error", err, slog.String("transformID", si.Transform))
return err
}
prepareSides = append(prepareSides, prepSide)
Expand Down Expand Up @@ -556,8 +562,8 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
}

// handleSideInput returns a closure that will look up the data for a side input appropriate for the given watermark.
func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[string]*pipepb.Coder, em *engine.ElementManager) (func(b *worker.B, watermark mtime.Time), error) {
t := comps.GetTransforms()[link.Transform]
func handleSideInput(link engine.LinkID, comps *pipepb.Components, transforms map[string]*pipepb.PTransform, pcols map[string]*pipepb.PCollection, coders map[string]*pipepb.Coder, em *engine.ElementManager) (func(b *worker.B, watermark mtime.Time), error) {
t := transforms[link.Transform]
sis, err := getSideInputs(t)
if err != nil {
return nil, err
Expand All @@ -570,7 +576,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st
slog.String("local", link.Local),
slog.String("global", link.Global))

col := comps.GetPcollections()[link.Global]
col := pcols[link.Global]
// The returned coders are unused here, but they add the side input coders
// to the stage components for use SDK side.

Expand All @@ -594,7 +600,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st
slog.String("sourceTransform", t.GetUniqueName()),
slog.String("local", link.Local),
slog.String("global", link.Global))
col := comps.GetPcollections()[link.Global]
col := pcols[link.Global]

kvc := comps.GetCoders()[col.GetCoderId()]
if kvc.GetSpec().GetUrn() != urns.CoderKV {
Expand Down Expand Up @@ -633,7 +639,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st
}] = windowed
}, nil
default:
return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", link.Local, link.Global, si.GetAccessPattern().GetUrn())
return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", link.Local, link.Global, prototext.Format(si.GetAccessPattern()))
}
}

Expand Down
2 changes: 1 addition & 1 deletion sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error {
// TODO: move data handling to be pcollection based.

key := req.GetStateKey()
slog.Debug("StateRequest_Get", prototext.Format(req), "bundle", b)
slog.Debug("StateRequest_Get", "request", prototext.Format(req), "bundle", b)
var data [][]byte
switch key.GetType().(type) {
case *fnpb.StateKey_IterableSideInput_:
Expand Down

0 comments on commit f6bfe45

Please sign in to comment.