Skip to content

Commit

Permalink
Validate workflow task start time when complete (#4663)
Browse files Browse the repository at this point in the history
<!-- Describe what has changed in this PR -->
**What changed?**
Add start time to workflow task token and validate it on close.

<!-- Tell your future self why you have made these changes -->
**Why?**
To reject concurrent speculative workflow tasks that share the same startedEventID.

<!-- How have you verified this change? Tested locally? Added a unit
test? Checked in staging env? -->
**How did you test it?**
Integration tests.

<!-- Assuming the worst case, what can be broken when deploying this
change to production? -->
**Potential risks**
No

<!-- Is this PR a hotfix candidate or require that a notification be
sent to the broader community? (Yes/No) -->
**Is hotfix candidate?**
Yes
  • Loading branch information
yiminc authored and yycptt committed Jul 21, 2023
1 parent 655c51a commit 1d5601a
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 752 deletions.
211 changes: 164 additions & 47 deletions api/token/v1/message.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions proto/internal/temporal/server/api/token/v1/message.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ package temporal.server.api.token.v1;

option go_package = "go.temporal.io/server/api/token/v1;token";

import "google/protobuf/timestamp.proto";
import "dependencies/gogoproto/gogo.proto";
import "temporal/server/api/clock/v1/message.proto";
import "temporal/server/api/history/v1/message.proto";

Expand Down Expand Up @@ -62,6 +64,8 @@ message Task {
string activity_type = 8;
temporal.server.api.clock.v1.VectorClock clock = 9;
int64 started_event_id = 10;
int64 version = 11;
google.protobuf.Timestamp started_time = 12 [(gogoproto.stdtime) = true];
}

message QueryTask {
Expand Down
1 change: 1 addition & 0 deletions service/frontend/workflow_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,7 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(
RunId: taskToken.GetRunId(),
ScheduledEventId: histResp.StartedResponse.GetScheduledEventId(),
StartedEventId: histResp.StartedResponse.GetStartedEventId(),
StartedTime: histResp.StartedResponse.GetStartedTime(),
Attempt: histResp.StartedResponse.GetAttempt(),
}
token, err := wh.tokenSerializer.Serialize(taskToken)
Expand Down
3 changes: 2 additions & 1 deletion service/history/api/startworkflow/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ import (
historypb "go.temporal.io/api/history/v1"
"go.temporal.io/api/serviceerror"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/server/api/historyservice/v1"

tokenspb "go.temporal.io/server/api/token/v1"

"go.temporal.io/server/api/historyservice/v1"
"go.temporal.io/server/common"
"go.temporal.io/server/common/definition"
"go.temporal.io/server/common/metrics"
Expand Down Expand Up @@ -542,6 +542,7 @@ func (s *Starter) generateResponse(
RunId: runID,
ScheduledEventId: workflowTaskInfo.ScheduledEventID,
StartedEventId: workflowTaskInfo.StartedEventID,
StartedTime: workflowTaskInfo.StartedTime,
Attempt: workflowTaskInfo.Attempt,
Clock: clock,
}
Expand Down
34 changes: 22 additions & 12 deletions service/history/workflowTaskHandlerCallbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,13 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskFailed(
scheduledEventID := token.GetScheduledEventId()
workflowTask := mutableState.GetWorkflowTaskByID(scheduledEventID)

if workflowTask == nil || workflowTask.Attempt != token.Attempt || workflowTask.StartedEventID == common.EmptyEventID ||
(token.StartedEventId != common.EmptyEventID && token.StartedEventId != workflowTask.StartedEventID) {
if workflowTask == nil ||
workflowTask.StartedEventID == common.EmptyEventID ||
(token.StartedEventId != common.EmptyEventID && token.StartedEventId != workflowTask.StartedEventID) ||
(token.StartedTime != nil && workflowTask.StartedTime != nil && !token.StartedTime.Equal(*workflowTask.StartedTime)) ||
workflowTask.Attempt != token.Attempt {
// We have not altered the mutable state yet, so release it with nil to avoid clearing the mutable state (MS).
workflowContext.GetReleaseFn()(nil)
return nil, serviceerror.NewNotFound("Workflow task not found.")
}

Expand Down Expand Up @@ -390,6 +395,21 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
if err != nil {
return nil, err
}
weContext := workflowContext.GetContext()
ms := workflowContext.GetMutableState()

currentWorkflowTask := ms.GetWorkflowTaskByID(token.GetScheduledEventId())
if !ms.IsWorkflowExecutionRunning() ||
currentWorkflowTask == nil ||
currentWorkflowTask.StartedEventID == common.EmptyEventID ||
(token.StartedEventId != common.EmptyEventID && token.StartedEventId != currentWorkflowTask.StartedEventID) ||
(token.StartedTime != nil && currentWorkflowTask.StartedTime != nil && !token.StartedTime.Equal(*currentWorkflowTask.StartedTime)) ||
currentWorkflowTask.Attempt != token.Attempt {
// We have not altered the mutable state yet, so release it with nil to avoid clearing the mutable state (MS).
workflowContext.GetReleaseFn()(nil)
return nil, serviceerror.NewNotFound("Workflow task not found.")
}

defer func() { workflowContext.GetReleaseFn()(retError) }()

var effects effect.Buffer
Expand All @@ -407,16 +427,6 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted(
effects.Apply(ctx)
}()

weContext := workflowContext.GetContext()
ms := workflowContext.GetMutableState()

currentWorkflowTask := ms.GetWorkflowTaskByID(token.GetScheduledEventId())
if !ms.IsWorkflowExecutionRunning() || currentWorkflowTask == nil || currentWorkflowTask.Attempt != token.Attempt ||
currentWorkflowTask.StartedEventID == common.EmptyEventID ||
(token.StartedEventId != common.EmptyEventID && token.StartedEventId != currentWorkflowTask.StartedEventID) {
return nil, serviceerror.NewNotFound("Workflow task not found.")
}

// It's an error if the workflow has used versioning in the past but this task has no versioning info.
if ms.GetWorkerVersionStamp().GetUseVersioning() && !request.GetWorkerVersionStamp().GetUseVersioning() {
return nil, serviceerror.NewInvalidArgument("Workflow using versioning must continue to use versioning.")
Expand Down
Loading

0 comments on commit 1d5601a

Please sign in to comment.