Skip to content

Commit

Permalink
Allow completing a Nexus operation after workflow reset (#6434)
Browse files Browse the repository at this point in the history
## What changed?

Added a request ID to the nexus completion token that can be used to
check a completion against a state machine in a workflow post reset,
e.g. with a new run ID.

## Why?

Make reset safer.

## How did you test it?

Added a functional test.
  • Loading branch information
bergundy authored Aug 23, 2024
1 parent 13d6cd8 commit 34e628e
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 115 deletions.
232 changes: 122 additions & 110 deletions api/token/v1/message.pb.go

Large diffs are not rendered by default.

24 changes: 19 additions & 5 deletions components/nexusoperations/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ import (
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
historypb "go.temporal.io/api/history/v1"
"go.temporal.io/api/serviceerror"
commonnexus "go.temporal.io/server/common/nexus"
"go.temporal.io/server/service/history/hsm"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func handleSuccessfulOperationResult(
Expand Down Expand Up @@ -132,23 +131,38 @@ func CompletionHandler(
ctx context.Context,
env hsm.Environment,
ref hsm.Ref,
requestID string,
result *commonpb.Payload,
opFailedError *nexus.UnsuccessfulOperationError,
) error {
return env.Access(ctx, ref, hsm.AccessWrite, func(node *hsm.Node) error {
// The initial version of the completion token did not include a request ID.
// Only retry Access without a run ID if the request ID is not empty.
isRetryableNotFoundErr := requestID != ""
err := env.Access(ctx, ref, hsm.AccessWrite, func(node *hsm.Node) error {
if err := node.CheckRunning(); err != nil {
return status.Errorf(codes.NotFound, "operation not found")
return serviceerror.NewNotFound("operation not found")
}
err := hsm.MachineTransition(node, func(operation Operation) (hsm.TransitionOutput, error) {
if requestID != "" && operation.RequestId != requestID {
isRetryableNotFoundErr = false
return hsm.TransitionOutput{}, serviceerror.NewNotFound("operation not found")
}
if opFailedError != nil {
return handleUnsuccessfulOperationError(node, operation, opFailedError, CompletionSourceCallback)
}
return handleSuccessfulOperationResult(node, operation, result, CompletionSourceCallback)
})
// TODO(bergundy): Remove this once the operation auto-deletes itself from the tree on completion.
if errors.Is(err, hsm.ErrInvalidTransition) {
return status.Errorf(codes.NotFound, "operation not found")
isRetryableNotFoundErr = false
return serviceerror.NewNotFound("operation not found")
}
return err
})
if errors.As(err, new(*serviceerror.NotFound)) && isRetryableNotFoundErr && ref.WorkflowKey.RunID != "" {
// Try again without a run ID in case the original run was reset.
ref.WorkflowKey.RunID = ""
return CompletionHandler(ctx, env, ref, requestID, result, opFailedError)
}
return err
}
1 change: 1 addition & 0 deletions components/nexusoperations/executors.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ func (e taskExecutor) executeInvocationTask(ctx context.Context, env hsm.Environ
WorkflowId: ref.WorkflowKey.WorkflowID,
RunId: ref.WorkflowKey.RunID,
Ref: smRef,
RequestId: args.requestID,
})
if err != nil {
return fmt.Errorf("%w: %w", queues.NewUnprocessableTaskError("failed to generate a callback token"), err)
Expand Down
3 changes: 3 additions & 0 deletions proto/internal/temporal/server/api/token/v1/message.proto
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,7 @@ message NexusOperationCompletion {
// Reference including the path to the backing Operation state machine and a version + transition count for
// staleness checks.
temporal.server.api.persistence.v1.StateMachineRef ref = 4;
// Request ID embedded in the NexusOperationScheduledEvent.
// Allows completing a started operation after a workflow has been reset.
string request_id = 5;
}
1 change: 1 addition & 0 deletions service/history/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2340,6 +2340,7 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse
ctx,
engine.StateMachineEnvironment(),
ref,
request.Completion.RequestId,
request.GetSuccess(),
opErr,
)
Expand Down
158 changes: 158 additions & 0 deletions tests/nexus_workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,164 @@ func (s *ClientFunctionalSuite) TestNexusOperationAsyncCompletionInternalAuth()
s.Equal("result", result)
}

func (s *ClientFunctionalSuite) TestNexusOperationAsyncCompletionAfterReset() {
ctx := NewContext()
taskQueue := s.randomizeStr(s.T().Name())
endpointName := RandomizedNexusEndpoint(s.T().Name())

var callbackToken, publicCallbackUrl string

h := nexustest.Handler{
OnStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) {
callbackToken = options.CallbackHeader.Get(commonnexus.CallbackTokenHeader)
publicCallbackUrl = options.CallbackURL
return &nexus.HandlerStartOperationResultAsync{OperationID: "test"}, nil
},
}
listenAddr := nexustest.AllocListenAddress(s.T())
nexustest.NewNexusServer(s.T(), listenAddr, h)

_, err := s.operatorClient.CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{
Spec: &nexuspb.EndpointSpec{
Name: endpointName,
Target: &nexuspb.EndpointTarget{
Variant: &nexuspb.EndpointTarget_External_{
External: &nexuspb.EndpointTarget_External{
Url: "http://" + listenAddr,
},
},
},
},
})
s.NoError(err)

run, err := s.sdkClient.ExecuteWorkflow(ctx, client.StartWorkflowOptions{
TaskQueue: taskQueue,
}, "workflow")
s.NoError(err)

pollResp, err := s.client.PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{
Namespace: s.namespace,
TaskQueue: &taskqueue.TaskQueue{
Name: taskQueue,
Kind: enumspb.TASK_QUEUE_KIND_NORMAL,
},
Identity: "test",
})
s.NoError(err)
_, err = s.client.RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{
Identity: "test",
TaskToken: pollResp.TaskToken,
Commands: []*commandpb.Command{
{
CommandType: enumspb.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION,
Attributes: &commandpb.Command_ScheduleNexusOperationCommandAttributes{
ScheduleNexusOperationCommandAttributes: &commandpb.ScheduleNexusOperationCommandAttributes{
Endpoint: endpointName,
Service: "service",
Operation: "operation",
Input: s.mustToPayload("input"),
},
},
},
},
})
s.NoError(err)

// Poll and verify that the "started" event was recorded.
pollResp, err = s.client.PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{
Namespace: s.namespace,
TaskQueue: &taskqueue.TaskQueue{
Name: taskQueue,
Kind: enumspb.TASK_QUEUE_KIND_NORMAL,
},
Identity: "test",
})
s.NoError(err)
_, err = s.client.RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{
Identity: "test",
TaskToken: pollResp.TaskToken,
})
s.NoError(err)

startedEventIdx := slices.IndexFunc(pollResp.History.Events, func(e *historypb.HistoryEvent) bool {
return e.GetNexusOperationStartedEventAttributes() != nil
})
s.Greater(startedEventIdx, 0)

// Remember the workflow task completed event ID (next after the last WFT started), we'll use it to test reset
// below.
wftCompletedEventID := int64(len(pollResp.History.Events))

// Reset the workflow and check that the started event has been reapplied.
resetResp, err := s.client.ResetWorkflowExecution(ctx, &workflowservice.ResetWorkflowExecutionRequest{
Namespace: s.namespace,
WorkflowExecution: pollResp.WorkflowExecution,
Reason: "test",
RequestId: uuid.NewString(),
WorkflowTaskFinishEventId: wftCompletedEventID,
})
s.NoError(err)

hist := s.sdkClient.GetWorkflowHistory(ctx, run.GetID(), resetResp.RunId, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT)

seenCompletedEvent := false
for hist.HasNext() {
event, err := hist.Next()
s.NoError(err)
if event.EventType == enumspb.EVENT_TYPE_NEXUS_OPERATION_STARTED {
seenCompletedEvent = true
}
}
s.True(seenCompletedEvent)
completion, err := nexus.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexus.OperationCompletionSuccesfulOptions{
Serializer: commonnexus.PayloadSerializer,
})
s.NoError(err)

res, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackUrl, completion, callbackToken)
s.Equal(http.StatusOK, res.StatusCode)

// Poll again and verify the completion is recorded and triggers workflow progress.
pollResp, err = s.client.PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{
Namespace: s.namespace,
TaskQueue: &taskqueue.TaskQueue{
Name: taskQueue,
Kind: enumspb.TASK_QUEUE_KIND_NORMAL,
},
Identity: "test",
})
s.NoError(err)
completedEventIdx := slices.IndexFunc(pollResp.History.Events, func(e *historypb.HistoryEvent) bool {
return e.GetNexusOperationCompletedEventAttributes() != nil
})
s.Greater(completedEventIdx, 0)

_, err = s.client.RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{
Identity: "test",
TaskToken: pollResp.TaskToken,
Commands: []*commandpb.Command{
{
CommandType: enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION,
Attributes: &commandpb.Command_CompleteWorkflowExecutionCommandAttributes{
CompleteWorkflowExecutionCommandAttributes: &commandpb.CompleteWorkflowExecutionCommandAttributes{
Result: &commonpb.Payloads{
Payloads: []*commonpb.Payload{
pollResp.History.Events[completedEventIdx].GetNexusOperationCompletedEventAttributes().Result,
},
},
},
},
},
},
})
s.NoError(err)
var result string
run = s.sdkClient.GetWorkflow(ctx, run.GetID(), resetResp.RunId)
s.NoError(run.Get(ctx, &result))
s.Equal("result", result)
}

func (s *FunctionalTestBase) sendNexusCompletionRequest(
ctx context.Context,
t *testing.T,
Expand Down

0 comments on commit 34e628e

Please sign in to comment.