Skip to content

Commit

Permalink
Add timer interface to Tasks (#323)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmrdavid authored Sep 24, 2021
1 parent 67f0a87 commit 1aac4bc
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 7 deletions.
9 changes: 5 additions & 4 deletions azure/durable_functions/models/DurableOrchestrationContext.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from azure.durable_functions.models.actions.SignalEntityAction import SignalEntityAction
from azure.durable_functions.models.actions.CallEntityAction import CallEntityAction
from azure.durable_functions.models.Task import TaskBase
from azure.durable_functions.models.Task import TaskBase, TimerTask
from azure.durable_functions.models.actions.CallHttpAction import CallHttpAction
from azure.durable_functions.models.DurableHttpRequest import DurableHttpRequest
from azure.durable_functions.models.actions.CallSubOrchestratorWithRetryAction import \
Expand Down Expand Up @@ -100,7 +100,8 @@ def from_json(cls, json_string: str):
def _generate_task(self, action: Action,
retry_options: Optional[RetryOptions] = None,
id_: Optional[Union[int, str]] = None,
parent: Optional[TaskBase] = None) -> Union[AtomicTask, RetryAbleTask]:
parent: Optional[TaskBase] = None,
task_constructor=AtomicTask) -> Union[AtomicTask, RetryAbleTask, TimerTask]:
"""Generate an atomic or retryable Task based on an input.
Parameters
Expand All @@ -124,7 +125,7 @@ def _generate_task(self, action: Action,
action_payload = [action]
else:
action_payload = action
task = AtomicTask(id_, action_payload)
task = task_constructor(id_, action_payload)
task.parent = parent

# if task is retryable, provide the retryable wrapper class
Expand Down Expand Up @@ -517,7 +518,7 @@ def create_timer(self, fire_at: datetime.datetime) -> TaskBase:
A Durable Timer Task that schedules the timer to wake up the activity
"""
action = CreateTimerAction(fire_at)
task = self._generate_task(action)
task = self._generate_task(action, task_constructor=TimerTask)
return task

def wait_for_external_event(self, name: str) -> TaskBase:
Expand Down
51 changes: 50 additions & 1 deletion azure/durable_functions/models/Task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from azure.durable_functions.models.actions.Action import Action
from azure.durable_functions.models.actions.WhenAnyAction import WhenAnyAction
from azure.durable_functions.models.actions.WhenAllAction import WhenAllAction
from azure.durable_functions.models.actions.CreateTimerAction import CreateTimerAction

import enum
from typing import Any, List, Optional, Set, Type, Union
Expand Down Expand Up @@ -56,6 +57,14 @@ def __init__(self, id_: Union[int, str], actions: Union[List[Action], Action]):
self.action_repr: Union[List[Action], Action] = actions
self.is_played = False

@property
def is_completed(self) -> bool:
"""Get indicator of whether the task completed.
Note that completion is not equivalent to success.
"""
return not(self.state is TaskState.RUNNING)

def set_is_played(self, is_played: bool):
"""Set the is_played flag for the Task.
Expand Down Expand Up @@ -208,7 +217,47 @@ def try_set_value(self, child: TaskBase):
class AtomicTask(TaskBase):
"""A Task with no subtasks."""

pass
def _get_action(self) -> Action:
action: Action
if isinstance(self.action_repr, list):
action = self.action_repr[0]
else:
action = self.action_repr
return action


class TimerTask(AtomicTask):
"""A Timer Task."""

def __init__(self, id_: Union[int, str], action: CreateTimerAction):
super().__init__(id_, action)
self.action_repr: Union[List[CreateTimerAction], CreateTimerAction]

@property
def is_cancelled(self) -> bool:
"""Check if the Timer is cancelled.
Returns
-------
bool
Returns whether a timer has been cancelled or not
"""
action: CreateTimerAction = self._get_action()
return action.is_cancelled

def cancel(self):
"""Cancel a timer.
Raises
------
ValueError
Raises an error if the task is already completed and an attempt is made to cancel it
"""
if not self.is_completed:
action: CreateTimerAction = self._get_action()
action.is_cancelled = True
else:
raise ValueError("Cannot cancel a completed task.")


class WhenAllTask(CompoundTask):
Expand Down
37 changes: 36 additions & 1 deletion tests/orchestrator/test_create_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ def generator_function(context):
yield context.create_timer(fire_at)
return "Done!"

def generator_function_timer_can_be_cancelled(context):
time_limit1 = context.current_utc_datetime + timedelta(minutes=5)
timer_task1 = context.create_timer(time_limit1)

time_limit2 = context.current_utc_datetime + timedelta(minutes=10)
timer_task2 = context.create_timer(time_limit2)

winner = yield context.task_any([timer_task1, timer_task2])
if winner == timer_task1:
timer_task2.cancel()
return "Done!"
else:
raise Exception("timer task 1 should complete before timer task 2")

def add_timer_action(state: OrchestratorState, fire_at: datetime):
action = CreateTimerAction(fire_at=fire_at)
state._actions.append([action])
Expand Down Expand Up @@ -64,4 +78,25 @@ def test_timers_comparison_with_relaxed_precision():
#assert_valid_schema(result)
# TODO: getting the following error when validating the schema
# "Additional properties are not allowed ('fireAt', 'isCanceled' were unexpected)">
assert_orchestration_state_equals(expected, result)
assert_orchestration_state_equals(expected, result)

def test_timers_can_be_cancelled():

context_builder = ContextBuilder("test_timers_can_be_cancelled")
fire_at1 = context_builder.current_datetime + timedelta(minutes=5)
fire_at2 = context_builder.current_datetime + timedelta(minutes=10)
add_timer_fired_events(context_builder, 0, str(fire_at1))
add_timer_fired_events(context_builder, 1, str(fire_at2))

result = get_orchestration_state_result(
context_builder, generator_function_timer_can_be_cancelled)

expected_state = base_expected_state(output='Done!')
expected_state._actions.append(
[CreateTimerAction(fire_at=fire_at1), CreateTimerAction(fire_at=fire_at2, is_cancelled=True)])

expected_state._is_done = True
expected = expected_state.to_json()

assert_orchestration_state_equals(expected, result)
assert result["actions"][0][1]["isCanceled"]
6 changes: 5 additions & 1 deletion tests/orchestrator/test_sequential_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,14 @@ def generator_function_new_guid(context):
outputs.append(str(output3))
return outputs


def base_expected_state(output=None, replay_schema: ReplaySchema = ReplaySchema.V1) -> OrchestratorState:
return OrchestratorState(is_done=False, actions=[], output=output, replay_schema=replay_schema)

def add_timer_fired_events(context_builder: ContextBuilder, id_: int, timestamp: str):
fire_at: str = context_builder.add_timer_created_event(id_, timestamp)
context_builder.add_orchestrator_completed_event()
context_builder.add_orchestrator_started_event()
context_builder.add_timer_fired_event(id_=id_, fire_at=fire_at)

def add_hello_action(state: OrchestratorState, input_: str):
action = CallActivityAction(function_name='Hello', input_=input_)
Expand Down

0 comments on commit 1aac4bc

Please sign in to comment.