Skip to content

Commit

Permalink
Access to workflow/activity instance from context (temporalio#2384)
Browse files Browse the repository at this point in the history
  • Loading branch information
Quinn-With-Two-Ns authored Jan 22, 2025
1 parent 3ad0b0e commit b471e13
Show file tree
Hide file tree
Showing 17 changed files with 220 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,7 @@ public interface ActivityExecutionContext {
* an activity.
*/
WorkflowClient getWorkflowClient();

/** Get the currently running activity instance. */
Object getInstance();
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,9 @@ public Scope getMetricsScope() {
public WorkflowClient getWorkflowClient() {
return next.getWorkflowClient();
}

@Override
public Object getInstance() {
return next.getInstance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@
import com.uber.m3.tally.Scope;

public interface ActivityExecutionContextFactory {
InternalActivityExecutionContext createContext(ActivityInfoInternal info, Scope metricsScope);
InternalActivityExecutionContext createContext(
ActivityInfoInternal info, Object activity, Scope metricsScope);
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ public ActivityExecutionContextFactoryImpl(

@Override
public InternalActivityExecutionContext createContext(
ActivityInfoInternal info, Scope metricsScope) {
ActivityInfoInternal info, Object activity, Scope metricsScope) {
return new ActivityExecutionContextImpl(
client,
namespace,
activity,
info,
dataConverter,
heartbeatExecutor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
class ActivityExecutionContextImpl implements InternalActivityExecutionContext {
private final Lock lock = new ReentrantLock();
private final WorkflowClient client;
private final Object activity;
private final ManualActivityCompletionClientFactory manualCompletionClientFactory;
private final Functions.Proc completionHandle;
private final HeartbeatContext heartbeatContext;
Expand All @@ -61,6 +62,7 @@ class ActivityExecutionContextImpl implements InternalActivityExecutionContext {
ActivityExecutionContextImpl(
WorkflowClient client,
String namespace,
Object activity,
ActivityInfo info,
DataConverter dataConverter,
ScheduledExecutorService heartbeatExecutor,
Expand All @@ -71,6 +73,7 @@ class ActivityExecutionContextImpl implements InternalActivityExecutionContext {
Duration maxHeartbeatThrottleInterval,
Duration defaultHeartbeatThrottleInterval) {
this.client = client;
this.activity = activity;
this.metricsScope = metricsScope;
this.info = info;
this.completionHandle = completionHandle;
Expand Down Expand Up @@ -177,4 +180,9 @@ public Object getLastHeartbeatValue() {
public WorkflowClient getWorkflowClient() {
return client;
}

@Override
public Object getInstance() {
return activity;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public BaseActivityTaskExecutor(
@Override
public ActivityTaskHandler.Result execute(ActivityInfoInternal info, Scope metricsScope) {
InternalActivityExecutionContext context =
executionContextFactory.createContext(info, metricsScope);
executionContextFactory.createContext(info, getActivity(), metricsScope);
ActivityInfo activityInfo = context.getInfo();
ActivitySerializationContext serializationContext =
new ActivitySerializationContext(
Expand Down Expand Up @@ -144,6 +144,8 @@ public ActivityTaskHandler.Result execute(ActivityInfoInternal info, Scope metri

abstract ActivityInboundCallsInterceptor createRootInboundInterceptor();

abstract Object getActivity();

abstract Object[] provideArgs(
Optional<Payloads> input, DataConverter dataConverterWithActivityContext);

Expand Down Expand Up @@ -203,6 +205,11 @@ ActivityInboundCallsInterceptor createRootInboundInterceptor() {
activity, method);
}

@Override
Object getActivity() {
return activity;
}

@Override
Object[] provideArgs(Optional<Payloads> input, DataConverter dataConverterWithActivityContext) {
return dataConverterWithActivityContext.fromPayloads(
Expand Down Expand Up @@ -241,6 +248,11 @@ ActivityInboundCallsInterceptor createRootInboundInterceptor() {
activity);
}

@Override
Object getActivity() {
return activity;
}

@Override
Object[] provideArgs(Optional<Payloads> input, DataConverter dataConverterWithActivityContext) {
EncodedValues encodedValues = new EncodedValues(input, dataConverterWithActivityContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public LocalActivityExecutionContextFactoryImpl(WorkflowClient client) {

@Override
public InternalActivityExecutionContext createContext(
ActivityInfoInternal info, Scope metricsScope) {
return new LocalActivityExecutionContextImpl(client, info, metricsScope);
ActivityInfoInternal info, Object activity, Scope metricsScope) {
return new LocalActivityExecutionContextImpl(client, activity, info, metricsScope);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@

class LocalActivityExecutionContextImpl implements InternalActivityExecutionContext {
private final WorkflowClient client;
private final Object activity;
private final ActivityInfo info;
private final Scope metricsScope;

LocalActivityExecutionContextImpl(WorkflowClient client, ActivityInfo info, Scope metricsScope) {
LocalActivityExecutionContextImpl(
WorkflowClient client, Object activity, ActivityInfo info, Scope metricsScope) {
this.client = client;
this.activity = activity;
this.info = info;
this.metricsScope = metricsScope;
}
Expand Down Expand Up @@ -100,4 +103,9 @@ public Object getLastHeartbeatValue() {
public WorkflowClient getWorkflowClient() {
return client;
}

@Override
public Object getInstance() {
return activity;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
import io.temporal.common.interceptors.WorkflowOutboundCallsInterceptor;
import io.temporal.workflow.DynamicWorkflow;
import io.temporal.workflow.Functions;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.Nullable;

final class DynamicSyncWorkflowDefinition implements SyncWorkflowDefinition {

private final Functions.Func1<EncodedValues, ? extends DynamicWorkflow> factory;
private RootWorkflowInboundCallsInterceptor rootWorkflowInvoker;
private final WorkerInterceptor[] workerInterceptors;
// don't pass it down to other classes, it's a "cached" instance for internal usage only
private final DataConverter dataConverterWithWorkflowContext;
Expand All @@ -52,7 +55,9 @@ public DynamicSyncWorkflowDefinition(
@Override
public void initialize(Optional<Payloads> input) {
SyncWorkflowContext workflowContext = WorkflowInternal.getRootWorkflowContext();
workflowInvoker = new RootWorkflowInboundCallsInterceptor(workflowContext, input);
RootWorkflowInboundCallsInterceptor rootWorkflowInvoker =
new RootWorkflowInboundCallsInterceptor(workflowContext, input);
workflowInvoker = rootWorkflowInvoker;
for (WorkerInterceptor workerInterceptor : workerInterceptors) {
workflowInvoker = workerInterceptor.interceptWorkflow(workflowInvoker);
}
Expand All @@ -69,6 +74,13 @@ public Optional<Payloads> execute(Header header, Optional<Payloads> input) {
return dataConverterWithWorkflowContext.toPayloads(result.getResult());
}

@Nullable
@Override
public Object getInstance() {
Objects.requireNonNull(rootWorkflowInvoker, "getInstance called before initialize.");
return rootWorkflowInvoker.getInstance();
}

class RootWorkflowInboundCallsInterceptor extends BaseRootWorkflowInboundCallsInterceptor {
private DynamicWorkflow workflow;
private Optional<Payloads> input;
Expand All @@ -79,6 +91,10 @@ public RootWorkflowInboundCallsInterceptor(
this.input = input;
}

public DynamicWorkflow getInstance() {
return workflow;
}

@Override
public void init(WorkflowOutboundCallsInterceptor outboundCalls) {
super.init(outboundCalls);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import java.util.Objects;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -315,6 +316,7 @@ private class POJOWorkflowImplementation implements SyncWorkflowDefinition {
private final Class<?> workflowImplementationClass;
private final Method workflowMethod;
private final Constructor<?> ctor;
private RootWorkflowInboundCallsInterceptor rootWorkflowInvoker;
private WorkflowInboundCallsInterceptor workflowInvoker;
// don't pass it down to other classes, it's a "cached" instance for internal usage only
private final DataConverter dataConverterWithWorkflowContext;
Expand All @@ -333,7 +335,8 @@ public POJOWorkflowImplementation(
@Override
public void initialize(Optional<Payloads> input) {
SyncWorkflowContext workflowContext = WorkflowInternal.getRootWorkflowContext();
workflowInvoker = new RootWorkflowInboundCallsInterceptor(workflowContext, input);
rootWorkflowInvoker = new RootWorkflowInboundCallsInterceptor(workflowContext, input);
workflowInvoker = rootWorkflowInvoker;
for (WorkerInterceptor workerInterceptor : workerInterceptors) {
workflowInvoker = workerInterceptor.interceptWorkflow(workflowInvoker);
}
Expand All @@ -357,6 +360,13 @@ public Optional<Payloads> execute(Header header, Optional<Payloads> input)
return dataConverterWithWorkflowContext.toPayloads(result.getResult());
}

@Nullable
@Override
public Object getInstance() {
Objects.requireNonNull(rootWorkflowInvoker, "getInstance called before initialize.");
return rootWorkflowInvoker.getInstance();
}

private class RootWorkflowInboundCallsInterceptor
extends BaseRootWorkflowInboundCallsInterceptor {
private Object workflow;
Expand All @@ -368,6 +378,10 @@ public RootWorkflowInboundCallsInterceptor(
this.input = input;
}

public Object getInstance() {
return workflow;
}

@Override
public void init(WorkflowOutboundCallsInterceptor outboundCalls) {
super.init(outboundCalls);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public SyncWorkflow(
new SyncWorkflowContext(
namespace,
workflowExecution,
workflow,
signalDispatcher,
queryDispatcher,
updateDispatcher,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ final class SyncWorkflowContext implements WorkflowContext, WorkflowOutboundCall

private final String namespace;
private final WorkflowExecution workflowExecution;
private final SyncWorkflowDefinition workflowDefinition;
private final WorkflowImplementationOptions workflowImplementationOptions;
private final DataConverter dataConverter;
// to be used in this class, should not be passed down. Pass the original #dataConverter instead
Expand All @@ -125,15 +126,16 @@ final class SyncWorkflowContext implements WorkflowContext, WorkflowOutboundCall
private Map<String, NexusServiceOptions> nexusServiceOptionsMap;
private boolean readOnly = false;
private final WorkflowThreadLocal<UpdateInfo> currentUpdateInfo = new WorkflowThreadLocal<>();
// Map of all running update handlers. Key is the update Id of the update request.
// Map of all running update handlers. Key is the update ID of the update request.
private Map<String, UpdateHandlerInfo> runningUpdateHandlers = new HashMap<>();
// Map of all running signal handlers. Key is the event Id of the signal event.
// Map of all running signal handlers. Key is the event ID of the signal event.
private Map<Long, SignalHandlerInfo> runningSignalHandlers = new HashMap<>();
@Nullable private String currentDetails;

public SyncWorkflowContext(
@Nonnull String namespace,
@Nonnull WorkflowExecution workflowExecution,
@Nullable SyncWorkflowDefinition workflowDefinition,
SignalDispatcher signalDispatcher,
QueryDispatcher queryDispatcher,
UpdateDispatcher updateDispatcher,
Expand All @@ -142,6 +144,7 @@ public SyncWorkflowContext(
List<ContextPropagator> contextPropagators) {
this.namespace = namespace;
this.workflowExecution = workflowExecution;
this.workflowDefinition = workflowDefinition;
this.dataConverter = dataConverter;
this.dataConverterWithCurrentWorkflowContext =
dataConverter.withContext(
Expand Down Expand Up @@ -1492,6 +1495,11 @@ public void setCurrentDetails(String details) {
currentDetails = details;
}

@Nullable
public Object getInstance() {
return workflowDefinition.getInstance();
}

@Nullable
public String getCurrentDetails() {
return currentDetails;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,20 @@
import io.temporal.api.common.v1.Payloads;
import io.temporal.common.interceptors.Header;
import java.util.Optional;
import javax.annotation.Nullable;

/** Workflow wrapper used by the workflow thread to start a workflow */
interface SyncWorkflowDefinition {

/** Always called first. */
void initialize(Optional<Payloads> input);

/**
* Returns the workflow instance that is executing this code. Must be called after {@link
* #initialize(Optional)}.
*/
@Nullable
Object getInstance();

Optional<Payloads> execute(Header header, Optional<Payloads> input);
}
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,11 @@ public static String getCurrentDetails() {
return getRootWorkflowContext().getCurrentDetails();
}

@Nullable
public static Object getInstance() {
return getRootWorkflowContext().getInstance();
}

static WorkflowOutboundCallsInterceptor getWorkflowOutboundInterceptor() {
return getRootWorkflowContext().getWorkflowOutboundInterceptor();
}
Expand Down
15 changes: 15 additions & 0 deletions temporal-sdk/src/main/java/io/temporal/workflow/Workflow.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.temporal.common.SearchAttributeUpdate;
import io.temporal.common.SearchAttributes;
import io.temporal.common.converter.DataConverter;
import io.temporal.common.interceptors.WorkflowOutboundCallsInterceptor;
import io.temporal.failure.ActivityFailure;
import io.temporal.failure.CanceledFailure;
import io.temporal.failure.ChildWorkflowFailure;
Expand Down Expand Up @@ -1385,6 +1386,20 @@ public static String getCurrentDetails() {
return WorkflowInternal.getCurrentDetails();
}

/**
* Get the currently running workflow instance.
*
* @apiNote The instance is only available after it has been initialized. This function will
* return null if called before the workflow has been initialized. For example, this could
* happen if the function is called from a {@link WorkflowInit} constructor or {@link
* io.temporal.common.interceptors.WorkflowInboundCallsInterceptor#init(WorkflowOutboundCallsInterceptor)}.
*/
@Experimental
@Nullable
public static Object getInstance() {
return WorkflowInternal.getInstance();
}

/** Prohibit instantiation. */
private Workflow() {}
}
Loading

0 comments on commit b471e13

Please sign in to comment.