Skip to content

Commit

Permalink
Ensure truncate element is wrapped in *FullValue (#25908)
Browse files Browse the repository at this point in the history
Co-authored-by: lostluck <[email protected]>
  • Loading branch information
lostluck and lostluck authored Mar 20, 2023
1 parent b5ce110 commit 6cb7b8e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 27 deletions.
27 changes: 11 additions & 16 deletions sdks/go/pkg/beam/core/runtime/exec/sdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,10 @@ func (n *TruncateSizedRestriction) StartBundle(ctx context.Context, id string, d
// Input Diagram:
//
// *FullValue {
// Elm: *FullValue {
// Elm: *FullValue (original input)
// Elm: *FullValue { -- mainElm
// Elm: *FullValue (original input) -- inp
// Elm2: *FullValue {
// Elm: Restriction
// Elm: Restriction -- rest
// Elm2: Watermark estimator state
// }
// }
Expand All @@ -325,24 +325,19 @@ func (n *TruncateSizedRestriction) StartBundle(ctx context.Context, id string, d
// }
func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error {
mainElm := elm.Elm.(*FullValue)
inp := mainElm.Elm
// For the main element, the way we fill it out depends on whether the input element
// is a KV or single-element. Single-elements might have been lifted out of
// their FullValue if they were decoded, so we need to have a case for that.
// TODO(https://github.com/apache/beam/issues/20196): Optimize this so it's decided in exec/translate.go
// instead of checking per-element.
if e, ok := mainElm.Elm.(*FullValue); ok {
mainElm = e
inp = e
}
rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm

// If receiving directly from a datasource,
// the element may not be wrapped in a *FullValue
inp := convertIfNeeded(mainElm.Elm, &FullValue{})

rest := mainElm.Elm2.(*FullValue).Elm

rt, err := n.ctInv.Invoke(ctx, rest)
if err != nil {
return err
}

newRest, err := n.truncateInv.Invoke(ctx, rt, mainElm)
newRest, err := n.truncateInv.Invoke(ctx, rt, inp)
if err != nil {
return err
}
Expand All @@ -351,7 +346,7 @@ func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *Full
return nil
}

size, err := n.sizeInv.Invoke(ctx, mainElm, newRest)
size, err := n.sizeInv.Invoke(ctx, inp, newRest)
if err != nil {
return err
}
Expand Down
24 changes: 13 additions & 11 deletions sdks/go/test/integration/primitives/drain.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

func init() {
register.DoFn3x1[*sdf.LockRTracker, []byte, func(int64), sdf.ProcessContinuation](&TruncateFn{})
register.DoFn4x1[context.Context, *sdf.LockRTracker, []byte, func(int64), sdf.ProcessContinuation](&TruncateFn{})

register.Emitter1[int64]()
}
Expand Down Expand Up @@ -83,39 +83,41 @@ func (fn *TruncateFn) SplitRestriction(_ []byte, rest offsetrange.Restriction) [
}

// TruncateRestriction truncates the restriction during drain.
func (fn *TruncateFn) TruncateRestriction(rt *sdf.LockRTracker, _ []byte) offsetrange.Restriction {
start := rt.GetRestriction().(offsetrange.Restriction).Start
func (fn *TruncateFn) TruncateRestriction(ctx context.Context, rt *sdf.LockRTracker, _ []byte) offsetrange.Restriction {
rest := rt.GetRestriction().(offsetrange.Restriction)
start := rest.Start
newEnd := start + 20

done, remaining := rt.GetProgress()
log.Infof(ctx, "Draining at: done %v, remaining %v, start %v, end %v, newEnd %v", done, remaining, start, rest.End, newEnd)

return offsetrange.Restriction{
Start: start,
End: newEnd,
}
}

// ProcessElement continually gets the start position of the restriction and emits the element as it is.
func (fn *TruncateFn) ProcessElement(rt *sdf.LockRTracker, _ []byte, emit func(int64)) sdf.ProcessContinuation {
func (fn *TruncateFn) ProcessElement(ctx context.Context, rt *sdf.LockRTracker, _ []byte, emit func(int64)) sdf.ProcessContinuation {
position := rt.GetRestriction().(offsetrange.Restriction).Start
counter := 0
for {
if rt.TryClaim(position) {
log.Infof(ctx, "Claimed position: %v", position)
// Successful claim, emit the value and move on.
emit(position)
position++
counter++
} else if rt.GetError() != nil || rt.IsDone() {
// Stop processing on error or completion
if err := rt.GetError(); err != nil {
log.Errorf(context.Background(), "error in restriction tracker, got %v", err)
log.Errorf(ctx, "error in restriction tracker, got %v", err)
}
log.Infof(ctx, "Restriction done at position %v.", position)
return sdf.StopProcessing()
} else {
log.Infof(ctx, "Checkpointed at position %v, resuming later.", position)
// Resume later.
return sdf.ResumeProcessingIn(5 * time.Second)
}

if counter >= 10 {
return sdf.ResumeProcessingIn(1 * time.Second)
}
time.Sleep(1 * time.Second)
}
}
Expand Down

0 comments on commit 6cb7b8e

Please sign in to comment.