diff --git a/internal/examples/helloworld/helloworld.go b/internal/examples/helloworld/helloworld.go index a5aa55f2..51c94723 100644 --- a/internal/examples/helloworld/helloworld.go +++ b/internal/examples/helloworld/helloworld.go @@ -9,6 +9,7 @@ import ( "fmt" "time" + "go.temporal.io/sdk/activity" "go.temporal.io/sdk/worker" "go.temporal.io/sdk/workflow" ) @@ -31,7 +32,12 @@ func PickGreeting(ctx context.Context) (string, error) { return "Hello", nil } +func TestIntercept(ctx context.Context) (string, error) { + return "Ok", nil +} + func RegisterWorkflowsAndActivities(r worker.Registry) { r.RegisterWorkflow(Greet) r.RegisterActivity(PickGreeting) + r.RegisterActivityWithOptions(TestIntercept, activity.RegisterOptions{Name: "TestIntercept"}) } diff --git a/internal/examples/helloworld/testinterceptor.go b/internal/examples/helloworld/testinterceptor.go new file mode 100644 index 00000000..29147ca7 --- /dev/null +++ b/internal/examples/helloworld/testinterceptor.go @@ -0,0 +1,57 @@ +package helloworld + +import ( + "time" + + "go.temporal.io/sdk/interceptor" + "go.temporal.io/sdk/workflow" +) + +var _ interceptor.Interceptor = &Interceptor{} + +type Interceptor struct { + interceptor.InterceptorBase +} + +type WorkflowInterceptor struct { + interceptor.WorkflowInboundInterceptorBase +} + +func NewTestInterceptor() *Interceptor { + return &Interceptor{} +} + +func (i *Interceptor) InterceptClient(next interceptor.ClientOutboundInterceptor) interceptor.ClientOutboundInterceptor { + return i.InterceptorBase.InterceptClient(next) +} + +func (i *Interceptor) InterceptWorkflow(ctx workflow.Context, next interceptor.WorkflowInboundInterceptor) interceptor.WorkflowInboundInterceptor { + return &WorkflowInterceptor{ + WorkflowInboundInterceptorBase: interceptor.WorkflowInboundInterceptorBase{ + Next: next, + }, + } +} + +func (i *WorkflowInterceptor) Init(outbound interceptor.WorkflowOutboundInterceptor) error { + return i.Next.Init(outbound) +} + +func (i *WorkflowInterceptor) ExecuteWorkflow(ctx workflow.Context, in *interceptor.ExecuteWorkflowInput) (interface{}, error) { + version := workflow.GetVersion(ctx, "version", workflow.DefaultVersion, 1) + var err error + + if version != workflow.DefaultVersion { + var vpt string + err = workflow.ExecuteLocalActivity( + workflow.WithLocalActivityOptions(ctx, workflow.LocalActivityOptions{ScheduleToCloseTimeout: time.Second}), + "TestIntercept", + ).Get(ctx, &vpt) + + if err != nil { + return nil, err + } + } + + return i.Next.ExecuteWorkflow(ctx, in) +} diff --git a/temporaltest/options.go b/temporaltest/options.go index 487b15e5..fdeb603c 100644 --- a/temporaltest/options.go +++ b/temporaltest/options.go @@ -8,6 +8,7 @@ import ( "testing" "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" "github.com/DataDog/temporalite" ) @@ -33,6 +34,17 @@ func WithBaseClientOptions(o client.Options) TestServerOption { }) } +// With WithBaseWorkerOptions configures default options for workers connected to the test server. +// +// WorkflowPanicPolicy is always set to worker.FailWorkflow so that workflow executions +// fail fast when workflow code panics or detects non-determinism. +func WithBaseWorkerOptions(o worker.Options) TestServerOption { + o.WorkflowPanicPolicy = worker.FailWorkflow + return newApplyFuncContainer(func(server *TestServer) { + server.defaultWorkerOptions = o + }) +} + // WithTemporaliteOptions provides the ability to use additional Temporalite options, including temporalite.WithUpstreamOptions. func WithTemporaliteOptions(options ...temporalite.ServerOption) TestServerOption { return newApplyFuncContainer(func(server *TestServer) { diff --git a/temporaltest/server.go b/temporaltest/server.go index 9519326b..93eabcf9 100644 --- a/temporaltest/server.go +++ b/temporaltest/server.go @@ -28,6 +28,7 @@ type TestServer struct { workers []worker.Worker t *testing.T defaultClientOptions client.Options + defaultWorkerOptions worker.Options serverOptions []temporalite.ServerOption } @@ -40,9 +41,25 @@ func (ts *TestServer) fatal(err error) { // Worker registers and starts a Temporal worker on the specified task queue. func (ts *TestServer) Worker(taskQueue string, registerFunc func(registry worker.Registry)) worker.Worker { - w := worker.New(ts.Client(), taskQueue, worker.Options{ - WorkflowPanicPolicy: worker.FailWorkflow, - }) + w := worker.New(ts.Client(), taskQueue, ts.defaultWorkerOptions) + registerFunc(w) + ts.workers = append(ts.workers, w) + + if err := w.Start(); err != nil { + ts.fatal(err) + } + + return w +} + +// NewWorkerWithOptions returns a Temporal worker on the specified task queue. +// +// WorkflowPanicPolicy is always set to worker.FailWorkflow so that workflow executions +// fail fast when workflow code panics or detects non-determinism. +func (ts *TestServer) NewWorkerWithOptions(taskQueue string, registerFunc func(registry worker.Registry), opts worker.Options) worker.Worker { + opts.WorkflowPanicPolicy = worker.FailWorkflow + + w := worker.New(ts.Client(), taskQueue, opts) registerFunc(w) ts.workers = append(ts.workers, w) diff --git a/temporaltest/server_test.go b/temporaltest/server_test.go index 7df05bd3..f657b8d4 100644 --- a/temporaltest/server_test.go +++ b/temporaltest/server_test.go @@ -82,6 +82,120 @@ func TestNewServer(t *testing.T) { } } +func TestNewWorkerWithOptions(t *testing.T) { + ts := temporaltest.NewServer(temporaltest.WithT(t)) + + ts.NewWorkerWithOptions( + "hello_world", + func(registry worker.Registry) { + helloworld.RegisterWorkflowsAndActivities(registry) + }, + worker.Options{ + MaxConcurrentActivityExecutionSize: 1, + MaxConcurrentLocalActivityExecutionSize: 1, + }, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + wfr, err := ts.Client().ExecuteWorkflow( + ctx, + client.StartWorkflowOptions{TaskQueue: "hello_world"}, + helloworld.Greet, + "world", + ) + if err != nil { + t.Fatal(err) + } + + var result string + if err := wfr.Get(ctx, &result); err != nil { + t.Fatal(err) + } + + if result != "Hello world" { + t.Fatalf("unexpected result: %q", result) + } + +} + +func TestDefaultWorkerOptions(t *testing.T) { + ts := temporaltest.NewServer( + temporaltest.WithT(t), + temporaltest.WithBaseWorkerOptions( + worker.Options{ + MaxConcurrentActivityExecutionSize: 1, + MaxConcurrentLocalActivityExecutionSize: 1, + }, + ), + ) + + ts.Worker("hello_world", func(registry worker.Registry) { + helloworld.RegisterWorkflowsAndActivities(registry) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + wfr, err := ts.Client().ExecuteWorkflow( + ctx, + client.StartWorkflowOptions{TaskQueue: "hello_world"}, + helloworld.Greet, + "world", + ) + if err != nil { + t.Fatal(err) + } + + var result string + if err := wfr.Get(ctx, &result); err != nil { + t.Fatal(err) + } + + if result != "Hello world" { + t.Fatalf("unexpected result: %q", result) + } +} + +func TestClientWithDefaultInterceptor(t *testing.T) { + var opts client.Options + opts.Interceptors = append(opts.Interceptors, helloworld.NewTestInterceptor()) + ts := temporaltest.NewServer( + temporaltest.WithT(t), + temporaltest.WithBaseClientOptions(opts), + ) + + ts.Worker( + "hello_world", + func(registry worker.Registry) { + helloworld.RegisterWorkflowsAndActivities(registry) + }, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + wfr, err := ts.Client().ExecuteWorkflow( + ctx, + client.StartWorkflowOptions{TaskQueue: "hello_world"}, + helloworld.Greet, + "world", + ) + if err != nil { + t.Fatal(err) + } + + var result string + if err := wfr.Get(ctx, &result); err != nil { + t.Fatal(err) + } + + if result != "Hello world" { + t.Fatalf("unexpected result: %q", result) + } +} + func BenchmarkRunWorkflow(b *testing.B) { ts := temporaltest.NewServer() defer ts.Stop()