Skip to content

Commit

Permalink
[prism] Fail jobs on SDK disconnect. (#28193)
Browse files Browse the repository at this point in the history
* [prism] Fail jobs on SDK disconnect.

* Reduce flaky short name for passert test.

* [prism] better workerID, warn on pre-bundle fail, buffer done chan

* Add causes, extract bundle failures to RunPipeline

* Return bundle errors through execPipeline.

---------

Co-authored-by: lostluck <[email protected]>
  • Loading branch information
lostluck and lostluck authored Sep 1, 2023
1 parent 06cdd5e commit 3463aa3
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 86 deletions.
17 changes: 14 additions & 3 deletions sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,11 @@ func (rb RunBundle) LogValue() slog.Value {
// remaining.
func (em *ElementManager) Bundles(ctx context.Context, nextBundID func() string) <-chan RunBundle {
runStageCh := make(chan RunBundle)
ctx, cancelFn := context.WithCancel(ctx)
ctx, cancelFn := context.WithCancelCause(ctx)
go func() {
em.pendingElements.Wait()
slog.Info("no more pending elements: terminating pipeline")
cancelFn()
slog.Debug("no more pending elements: terminating pipeline")
cancelFn(fmt.Errorf("elementManager out of elements, cleaning up"))
// Ensure the watermark evaluation goroutine exits.
em.refreshCond.Broadcast()
}()
Expand Down Expand Up @@ -394,6 +394,17 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol
em.addRefreshAndClearBundle(stage.ID, rb.BundleID)
}

// FailBundle drops the in-progress record for the given bundle so that
// execution is able to shut down.
//
// While holding the stage lock, it subtracts the number of elements recorded
// for the bundle from the pending element count and removes the bundle's
// in-progress entry. After releasing the lock it calls
// addRefreshAndClearBundle for the bundle's stage.
func (em *ElementManager) FailBundle(rb RunBundle) {
	stageID, bundleID := rb.StageID, rb.BundleID
	ss := em.stages[stageID]

	ss.mu.Lock()
	// A bundle with no in-progress entry yields a zero value here, which
	// contributes no elements to the adjustment below.
	inflight := ss.inprogress[bundleID]
	em.pendingElements.Add(-len(inflight.es))
	delete(ss.inprogress, bundleID)
	ss.mu.Unlock()

	// Called only after the stage lock is released, as in the original ordering.
	em.addRefreshAndClearBundle(stageID, bundleID)
}

// ReturnResiduals is called after a successful split, so the remaining work
// can be re-assigned to a new bundle.
func (em *ElementManager) ReturnResiduals(rb RunBundle, firstRsIndex int, inputInfo PColInfo, residuals [][]byte) {
Expand Down
43 changes: 29 additions & 14 deletions sdks/go/pkg/beam/runners/prism/internal/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func RunPipeline(j *jobservices.Job) {
return
}
env, _ := getOnlyPair(envs)
wk := worker.New(env) // Cheating by having the worker id match the environment id.
wk := worker.New(j.String()+"_"+env, env) // Cheating by having the worker id match the environment id.
go wk.Serve()
timeout := time.Minute
time.AfterFunc(timeout, func() {
Expand All @@ -69,7 +69,7 @@ func RunPipeline(j *jobservices.Job) {
// When this function exits, we cancel the context to clear
// any related job resources.
defer func() {
j.CancelFn(nil)
j.CancelFn(fmt.Errorf("runPipeline returned, cleaning up"))
}()
go runEnvironment(j.RootCtx, j, env, wk)

Expand Down Expand Up @@ -102,10 +102,10 @@ func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *wor
case urns.EnvExternal:
ep := &pipepb.ExternalPayload{}
if err := (proto.UnmarshalOptions{}).Unmarshal(e.GetPayload(), ep); err != nil {
slog.Error("unmarshing environment payload", err, slog.String("envID", wk.ID))
slog.Error("unmarshing environment payload", err, slog.String("envID", wk.Env))
}
externalEnvironment(ctx, ep, wk)
slog.Info("environment stopped", slog.String("envID", wk.String()), slog.String("job", j.String()))
slog.Debug("environment stopped", slog.String("envID", wk.String()), slog.String("job", j.String()))
default:
panic(fmt.Sprintf("environment %v with urn %v unimplemented", env, e.GetUrn()))
}
Expand Down Expand Up @@ -271,7 +271,7 @@ func executePipeline(ctx context.Context, wk *worker.W, j *jobservices.Job) erro
}
stages[stage.ID] = stage
wk.Descriptors[stage.ID] = stage.desc
case wk.ID:
case wk.Env:
// Great! this is for this environment. // Broken abstraction.
if err := buildDescriptor(stage, comps, wk); err != nil {
return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err)
Expand All @@ -296,16 +296,31 @@ func executePipeline(ctx context.Context, wk *worker.W, j *jobservices.Job) erro
// Use a channel to limit max parallelism for the pipeline.
maxParallelism := make(chan struct{}, 8)
// Execute stages here
for rb := range em.Bundles(ctx, wk.NextInst) {
maxParallelism <- struct{}{}
go func(rb engine.RunBundle) {
defer func() { <-maxParallelism }()
s := stages[rb.StageID]
s.Execute(ctx, j, wk, comps, em, rb)
}(rb)
bundleFailed := make(chan error)
bundles := em.Bundles(ctx, wk.NextInst)
for {
select {
case <-ctx.Done():
return context.Cause(ctx)
case rb, ok := <-bundles:
if !ok {
slog.Debug("pipeline done!", slog.String("job", j.String()))
return nil
}
maxParallelism <- struct{}{}
go func(rb engine.RunBundle) {
defer func() { <-maxParallelism }()
s := stages[rb.StageID]
if err := s.Execute(ctx, j, wk, comps, em, rb); err != nil {
// Ensure we clean up on bundle failure
em.FailBundle(rb)
bundleFailed <- err
}
}(rb)
case err := <-bundleFailed:
return err
}
}
slog.Info("pipeline done!", slog.String("job", j.String()))
return nil
}

func collectionPullDecoder(coldCId string, coders map[string]*pipepb.Coder, comps *pipepb.Components) func(io.Reader) []byte {
Expand Down
12 changes: 8 additions & 4 deletions sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,13 @@ func (j *Job) SendMsg(msg string) {
func (j *Job) sendState(state jobpb.JobState_Enum) {
j.streamCond.L.Lock()
defer j.streamCond.L.Unlock()
j.stateTime = time.Now()
j.stateIdx++
j.state.Store(state)
old := j.state.Load()
// Never overwrite a failed state with another one.
if old != jobpb.JobState_FAILED {
j.state.Store(state)
j.stateTime = time.Now()
j.stateIdx++
}
j.streamCond.Broadcast()
}

Expand All @@ -163,5 +167,5 @@ func (j *Job) Failed(err error) {
slog.Error("job failed", slog.Any("job", j), slog.Any("error", err))
j.failureErr = err
j.sendState(jobpb.JobState_FAILED)
j.CancelFn(err)
j.CancelFn(fmt.Errorf("jobFailed %v: %w", j, err))
}
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.Jo
job.streamCond.Wait()
select { // Quit out if the external connection is done.
case <-stream.Context().Done():
return stream.Context().Err()
return context.Cause(stream.Context())
default:
}
}
Expand Down
31 changes: 12 additions & 19 deletions sdks/go/pkg/beam/runners/prism/internal/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,7 @@ type stage struct {
OutputsToCoders map[string]engine.PColInfo
}

func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) {
select {
case <-ctx.Done():
return
default:
}
func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) error {
slog.Debug("Execute: starting bundle", "bundle", rb)

var b *worker.B
Expand All @@ -103,7 +98,7 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c
closed := make(chan struct{})
close(closed)
dataReady = closed
case wk.ID:
case wk.Env:
b = &worker.B{
PBDID: s.ID,
InstID: rb.BundleID,
Expand All @@ -122,15 +117,10 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c

slog.Debug("Execute: processing", "bundle", rb)
defer b.Cleanup(wk)
b.Fail = func(errMsg string) {
slog.Error("job failed", "bundle", rb, "job", j)
err := fmt.Errorf("%v", errMsg)
j.Failed(err)
}
dataReady = b.ProcessOn(ctx, wk)
default:
err := fmt.Errorf("unknown environment[%v]", s.envID)
slog.Error("Execute", err)
slog.Error("Execute", "error", err)
panic(err)
}

Expand All @@ -145,20 +135,20 @@ progress:
progTick.Stop()
break progress // exit progress loop on close.
case <-progTick.C:
resp, err := b.Progress(wk)
resp, err := b.Progress(ctx, wk)
if err != nil {
slog.Debug("SDK Error from progress, aborting progress", "bundle", rb, "error", err.Error())
break progress
}
index, unknownIDs := j.ContributeTentativeMetrics(resp)
if len(unknownIDs) > 0 {
md := wk.MonitoringMetadata(unknownIDs)
md := wk.MonitoringMetadata(ctx, unknownIDs)
j.AddMetricShortIDs(md)
}
slog.Debug("progress report", "bundle", rb, "index", index)
// Progress for the bundle hasn't advanced. Try splitting.
if previousIndex == index && !splitsDone {
sr, err := b.Split(wk, 0.5 /* fraction of remainder */, nil /* allowed splits */)
sr, err := b.Split(ctx, wk, 0.5 /* fraction of remainder */, nil /* allowed splits */)
if err != nil {
slog.Warn("SDK Error from split, aborting splits", "bundle", rb, "error", err.Error())
break progress
Expand Down Expand Up @@ -200,16 +190,18 @@ progress:
var resp *fnpb.ProcessBundleResponse
select {
case resp = <-b.Resp:
if b.BundleErr != nil {
return b.BundleErr
}
case <-ctx.Done():
// Ensures we clean up on failure, if the response is blocked.
return
return context.Cause(ctx)
}

// Tally metrics immediately so they're available before
// pipeline termination.
unknownIDs := j.ContributeFinalMetrics(resp)
if len(unknownIDs) > 0 {
md := wk.MonitoringMetadata(unknownIDs)
md := wk.MonitoringMetadata(ctx, unknownIDs)
j.AddMetricShortIDs(md)
}
// TODO handle side input data properly.
Expand Down Expand Up @@ -239,6 +231,7 @@ progress:
}
em.PersistBundle(rb, s.OutputsToCoders, b.OutputData, s.inputInfo, residualData, minOutputWatermark)
b.OutputData = engine.TentativeData{} // Clear the data.
return nil
}

func getSideInputs(t *pipepb.PTransform) (map[string]*pipepb.SideInput, error) {
Expand Down
24 changes: 13 additions & 11 deletions sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,11 @@ type B struct {
dataSema atomic.Int32
OutputData engine.TentativeData

// TODO move response channel to an atomic and an additional
// block on the DataWait channel, to allow progress & splits for
// no output DoFns.
Resp chan *fnpb.ProcessBundleResponse
Resp chan *fnpb.ProcessBundleResponse
BundleErr error
responded bool

SinkToPCollection map[string]string

Fail func(err string) // Called if bundle returns an error.
}

// Init initializes the bundle's internal state for waiting on all
Expand Down Expand Up @@ -90,8 +87,13 @@ func (b *B) LogValue() slog.Value {
}

func (b *B) Respond(resp *fnpb.InstructionResponse) {
if b.responded {
slog.Warn("additional bundle response", "bundle", b, "resp", resp)
return
}
b.responded = true
if resp.GetError() != "" {
b.Fail(resp.GetError())
b.BundleErr = fmt.Errorf("bundle %v failed:%v", resp.GetInstructionId(), resp.GetError())
close(b.Resp)
return
}
Expand Down Expand Up @@ -152,8 +154,8 @@ func (b *B) Cleanup(wk *W) {
}

// Progress sends a progress request for the given bundle to the passed in worker, blocking on the response.
func (b *B) Progress(wk *W) (*fnpb.ProcessBundleProgressResponse, error) {
resp := wk.sendInstruction(&fnpb.InstructionRequest{
func (b *B) Progress(ctx context.Context, wk *W) (*fnpb.ProcessBundleProgressResponse, error) {
resp := wk.sendInstruction(ctx, &fnpb.InstructionRequest{
Request: &fnpb.InstructionRequest_ProcessBundleProgress{
ProcessBundleProgress: &fnpb.ProcessBundleProgressRequest{
InstructionId: b.InstID,
Expand All @@ -167,8 +169,8 @@ func (b *B) Progress(wk *W) (*fnpb.ProcessBundleProgressResponse, error) {
}

// Split sends a split request for the given bundle to the passed in worker, blocking on the response.
func (b *B) Split(wk *W, fraction float64, allowedSplits []int64) (*fnpb.ProcessBundleSplitResponse, error) {
resp := wk.sendInstruction(&fnpb.InstructionRequest{
func (b *B) Split(ctx context.Context, wk *W, fraction float64, allowedSplits []int64) (*fnpb.ProcessBundleSplitResponse, error) {
resp := wk.sendInstruction(ctx, &fnpb.InstructionRequest{
Request: &fnpb.InstructionRequest_ProcessBundleSplit{
ProcessBundleSplit: &fnpb.ProcessBundleSplitRequest{
InstructionId: b.InstID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
)

func TestBundle_ProcessOn(t *testing.T) {
wk := New("test")
wk := New("test", "testEnv")
b := &B{
InstID: "testInst",
PBDID: "testPBDID",
Expand Down
Loading

0 comments on commit 3463aa3

Please sign in to comment.