From ccb8a632e941b95e4bb823dd50113f89f792ccbe Mon Sep 17 00:00:00 2001 From: Chris Gillum Date: Mon, 6 Jan 2025 08:35:13 -0800 Subject: [PATCH] Update version to 0.2b1, require Python 3.9+, and enhance GitHub Actions workflow (#1) (#35) - Bump version in `pyproject.toml` to 0.2b1 and update Python requirement to >=3.9. - Add `protobuf` dependency in `requirements.txt`. - Update GitHub Actions workflow to support Python versions 3.9 to 3.13 and upgrade action versions. - Refactor type hints in various files to use `Optional` and `list` instead of `Union` and `List`. - Improve handling of custom status in orchestration context and related functions. - Fix purge implementation to pass required parameters. --- .github/workflows/pr-validation.yml | 19 +++++++-- .vscode/settings.json | 5 ++- durabletask/client.py | 40 +++++++++---------- durabletask/internal/grpc_interceptor.py | 3 +- durabletask/internal/helpers.py | 32 +++++++-------- durabletask/internal/shared.py | 15 ++++--- durabletask/task.py | 30 ++++++++------ durabletask/worker.py | 50 ++++++++++++------------ examples/fanout_fanin.py | 7 ++-- pyproject.toml | 4 +- requirements.txt | 1 + tests/test_activity_executor.py | 4 +- tests/test_orchestration_e2e.py | 2 +- tests/test_orchestration_executor.py | 3 +- 14 files changed, 118 insertions(+), 97 deletions(-) diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 4c09e6b..70ff470 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -16,12 +16,12 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -35,3 +35,16 @@ jobs: - name: Pytest unit tests run: | pytest -m "not e2e" --verbose + + # Sidecar for running e2e tests requires Go SDK + - name: Install Go SDK + uses: actions/setup-go@v5 + with: + go-version: 'stable' + + # Install and run the durabletask-go sidecar for running e2e tests + - name: Pytest e2e tests + run: | + go install github.com/microsoft/durabletask-go@main + durabletask-go --port 4001 & + pytest -m "e2e" --verbose diff --git a/.vscode/settings.json b/.vscode/settings.json index d737b0b..1c929ac 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,7 +3,7 @@ "editor.defaultFormatter": "ms-python.autopep8", "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": true, + "source.organizeImports": "explicit" }, "editor.rulers": [ 119 @@ -29,5 +29,6 @@ "coverage.xml", "jacoco.xml", "coverage.cobertura.xml" - ] + ], + "makefile.configureOnOpen": false } \ No newline at end of file diff --git a/durabletask/client.py b/durabletask/client.py index 82f920a..31953ae 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, List, Tuple, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import grpc from google.protobuf import wrappers_pb2 @@ -42,10 +42,10 @@ class OrchestrationState: runtime_status: OrchestrationStatus created_at: datetime last_updated_at: datetime - serialized_input: Union[str, None] - serialized_output: Union[str, None] - serialized_custom_status: Union[str, None] - failure_details: Union[task.FailureDetails, None] + serialized_input: Optional[str] + serialized_output: Optional[str] + serialized_custom_status: Optional[str] + failure_details: Optional[task.FailureDetails] def raise_if_failed(self): if self.failure_details is not None: @@ -64,7 +64,7 @@ def failure_details(self): return self._failure_details -def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Union[OrchestrationState, None]: +def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Optional[OrchestrationState]: if not res.exists: return None @@ -92,20 +92,20 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Un class TaskHubGrpcClient: def __init__(self, *, - host_address: Union[str, None] = None, - metadata: Union[List[Tuple[str, str]], None] = None, - log_handler = None, - log_formatter: Union[logging.Formatter, None] = None, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler: Optional[logging.Handler] = None, + log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False): channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, - input: Union[TInput, None] = None, - instance_id: Union[str, None] = None, - start_at: Union[datetime, None] = None, - reuse_id_policy: Union[pb.OrchestrationIdReusePolicy, None] = None) -> str: + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + start_at: Optional[datetime] = None, + reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str: name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) @@ -122,14 +122,14 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu res: pb.CreateInstanceResponse = self._stub.StartInstance(req) return res.instanceId - def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Union[OrchestrationState, None]: + def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) res: pb.GetInstanceResponse = self._stub.GetInstance(req) return new_orchestration_state(req.instanceId, res) def wait_for_orchestration_start(self, instance_id: str, *, fetch_payloads: bool = False, - timeout: int = 60) -> Union[OrchestrationState, None]: + timeout: int = 60) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") @@ -144,7 +144,7 @@ def wait_for_orchestration_start(self, instance_id: str, *, def wait_for_orchestration_completion(self, instance_id: str, *, fetch_payloads: bool = True, - timeout: int = 60) -> Union[OrchestrationState, None]: + timeout: int = 60) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") @@ -170,7 +170,7 @@ def wait_for_orchestration_completion(self, instance_id: str, *, raise def raise_orchestration_event(self, instance_id: str, event_name: str, *, - data: Union[Any, None] = None): + data: Optional[Any] = None): req = pb.RaiseEventRequest( instanceId=instance_id, name=event_name, @@ -180,7 +180,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *, self._stub.RaiseEvent(req) def terminate_orchestration(self, instance_id: str, *, - output: Union[Any, None] = None, + output: Optional[Any] = None, recursive: bool = True): req = pb.TerminateRequest( instanceId=instance_id, @@ -203,4 +203,4 @@ def resume_orchestration(self, instance_id: str): def purge_orchestration(self, instance_id: str, recursive: bool = True): req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) self._logger.info(f"Purging instance '{instance_id}'.") - self._stub.PurgeInstances() + self._stub.PurgeInstances(req) diff --git a/durabletask/internal/grpc_interceptor.py b/durabletask/internal/grpc_interceptor.py index 5b12ace..738fca9 100644 --- a/durabletask/internal/grpc_interceptor.py +++ b/durabletask/internal/grpc_interceptor.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. from collections import namedtuple -from typing import List, Tuple import grpc @@ -26,7 +25,7 @@ class DefaultClientInterceptorImpl ( StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an interceptor to add additional headers to all calls as needed.""" - def __init__(self, metadata: List[Tuple[str, str]]): + def __init__(self, metadata: list[tuple[str, str]]): super().__init__() self._metadata = metadata diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index c7354e5..6b36586 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -3,7 +3,7 @@ import traceback from datetime import datetime -from typing import List, Union +from typing import Optional from google.protobuf import timestamp_pb2, wrappers_pb2 @@ -12,14 +12,14 @@ # TODO: The new_xxx_event methods are only used by test code and should be moved elsewhere -def new_orchestrator_started_event(timestamp: Union[datetime, None] = None) -> pb.HistoryEvent: +def new_orchestrator_started_event(timestamp: Optional[datetime] = None) -> pb.HistoryEvent: ts = timestamp_pb2.Timestamp() if timestamp is not None: ts.FromDatetime(timestamp) return pb.HistoryEvent(eventId=-1, timestamp=ts, orchestratorStarted=pb.OrchestratorStartedEvent()) -def new_execution_started_event(name: str, instance_id: str, encoded_input: Union[str, None] = None) -> pb.HistoryEvent: +def new_execution_started_event(name: str, instance_id: str, encoded_input: Optional[str] = None) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), @@ -49,7 +49,7 @@ def new_timer_fired_event(timer_id: int, fire_at: datetime) -> pb.HistoryEvent: ) -def new_task_scheduled_event(event_id: int, name: str, encoded_input: Union[str, None] = None) -> pb.HistoryEvent: +def new_task_scheduled_event(event_id: int, name: str, encoded_input: Optional[str] = None) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=event_id, timestamp=timestamp_pb2.Timestamp(), @@ -57,7 +57,7 @@ def new_task_scheduled_event(event_id: int, name: str, encoded_input: Union[str, ) -def new_task_completed_event(event_id: int, encoded_output: Union[str, None] = None) -> pb.HistoryEvent: +def new_task_completed_event(event_id: int, encoded_output: Optional[str] = None) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), @@ -77,7 +77,7 @@ def new_sub_orchestration_created_event( event_id: int, name: str, instance_id: str, - encoded_input: Union[str, None] = None) -> pb.HistoryEvent: + encoded_input: Optional[str] = None) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=event_id, timestamp=timestamp_pb2.Timestamp(), @@ -88,7 +88,7 @@ def new_sub_orchestration_created_event( ) -def new_sub_orchestration_completed_event(event_id: int, encoded_output: Union[str, None] = None) -> pb.HistoryEvent: +def new_sub_orchestration_completed_event(event_id: int, encoded_output: Optional[str] = None) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), @@ -116,7 +116,7 @@ def new_failure_details(ex: Exception) -> pb.TaskFailureDetails: ) -def new_event_raised_event(name: str, encoded_input: Union[str, None] = None) -> pb.HistoryEvent: +def new_event_raised_event(name: str, encoded_input: Optional[str] = None) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), @@ -140,7 +140,7 @@ def new_resume_event() -> pb.HistoryEvent: ) -def new_terminated_event(*, encoded_output: Union[str, None] = None) -> pb.HistoryEvent: +def new_terminated_event(*, encoded_output: Optional[str] = None) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), @@ -150,7 +150,7 @@ def new_terminated_event(*, encoded_output: Union[str, None] = None) -> pb.Histo ) -def get_string_value(val: Union[str, None]) -> Union[wrappers_pb2.StringValue, None]: +def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]: if val is None: return None else: @@ -160,9 +160,9 @@ def get_string_value(val: Union[str, None]) -> Union[wrappers_pb2.StringValue, N def new_complete_orchestration_action( id: int, status: pb.OrchestrationStatus, - result: Union[str, None] = None, - failure_details: Union[pb.TaskFailureDetails, None] = None, - carryover_events: Union[List[pb.HistoryEvent], None] = None) -> pb.OrchestratorAction: + result: Optional[str] = None, + failure_details: Optional[pb.TaskFailureDetails] = None, + carryover_events: Optional[list[pb.HistoryEvent]] = None) -> pb.OrchestratorAction: completeOrchestrationAction = pb.CompleteOrchestrationAction( orchestrationStatus=status, result=get_string_value(result), @@ -178,7 +178,7 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction return pb.OrchestratorAction(id=id, createTimer=pb.CreateTimerAction(fireAt=timestamp)) -def new_schedule_task_action(id: int, name: str, encoded_input: Union[str, None]) -> pb.OrchestratorAction: +def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str]) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, scheduleTask=pb.ScheduleTaskAction( name=name, input=get_string_value(encoded_input) @@ -194,8 +194,8 @@ def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp: def new_create_sub_orchestration_action( id: int, name: str, - instance_id: Union[str, None], - encoded_input: Union[str, None]) -> pb.OrchestratorAction: + instance_id: Optional[str], + encoded_input: Optional[str]) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction( name=name, instanceId=instance_id, diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 80c3d56..400529a 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -5,7 +5,7 @@ import json import logging from types import SimpleNamespace -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Optional import grpc @@ -20,7 +20,10 @@ def get_default_host_address() -> str: return "localhost:4001" -def get_grpc_channel(host_address: Union[str, None], metadata: Union[List[Tuple[str, str]], None], secure_channel: bool = False) -> grpc.Channel: +def get_grpc_channel( + host_address: Optional[str], + metadata: Optional[list[tuple[str, str]]], + secure_channel: bool = False) -> grpc.Channel: if host_address is None: host_address = get_default_host_address() @@ -36,8 +39,8 @@ def get_grpc_channel(host_address: Union[str, None], metadata: Union[List[Tuple[ def get_logger( name_suffix: str, - log_handler: Union[logging.Handler, None] = None, - log_formatter: Union[logging.Formatter, None] = None) -> logging.Logger: + log_handler: Optional[logging.Handler] = None, + log_formatter: Optional[logging.Formatter] = None) -> logging.Logger: logger = logging.Logger(f"durabletask-{name_suffix}") # Add a default log handler if none is provided @@ -78,7 +81,7 @@ def default(self, obj): if dataclasses.is_dataclass(obj): # Dataclasses are not serializable by default, so we convert them to a dict and mark them for # automatic deserialization by the receiver - d = dataclasses.asdict(obj) + d = dataclasses.asdict(obj) # type: ignore d[AUTO_SERIALIZED] = True return d elif isinstance(obj, SimpleNamespace): @@ -94,7 +97,7 @@ class InternalJSONDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): super().__init__(object_hook=self.dict_to_object, *args, **kwargs) - def dict_to_object(self, d: Dict[str, Any]): + def dict_to_object(self, d: dict[str, Any]): # If the object was serialized by the InternalJSONEncoder, deserialize it as a SimpleNamespace if d.pop(AUTO_SERIALIZED, False): return SimpleNamespace(**d) diff --git a/durabletask/task.py b/durabletask/task.py index a9f85de..a40602b 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -7,8 +7,7 @@ import math from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import (Any, Callable, Generator, Generic, List, Optional, TypeVar, - Union) +from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb @@ -72,8 +71,13 @@ def is_replaying(self) -> bool: pass @abstractmethod - def set_custom_status(self, custom_status: str) -> None: - """Set the custom status. + def set_custom_status(self, custom_status: Any) -> None: + """Set the orchestration instance's custom status. + + Parameters + ---------- + custom_status: Any + A JSON-serializable custom status value to set. """ pass @@ -254,9 +258,9 @@ def get_exception(self) -> TaskFailedError: class CompositeTask(Task[T]): """A task that is composed of other tasks.""" - _tasks: List[Task] + _tasks: list[Task] - def __init__(self, tasks: List[Task]): + def __init__(self, tasks: list[Task]): super().__init__() self._tasks = tasks self._completed_tasks = 0 @@ -266,17 +270,17 @@ def __init__(self, tasks: List[Task]): if task.is_complete: self.on_child_completed(task) - def get_tasks(self) -> List[Task]: + def get_tasks(self) -> list[Task]: return self._tasks @abstractmethod def on_child_completed(self, task: Task[T]): pass -class WhenAllTask(CompositeTask[List[T]]): +class WhenAllTask(CompositeTask[list[T]]): """A task that completes when all of its child tasks complete.""" - def __init__(self, tasks: List[Task[T]]): + def __init__(self, tasks: list[Task[T]]): super().__init__(tasks) self._completed_tasks = 0 self._failed_tasks = 0 @@ -340,7 +344,7 @@ def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, def increment_attempt_count(self) -> None: self._attempt_count += 1 - def compute_next_delay(self) -> Union[timedelta, None]: + def compute_next_delay(self) -> Optional[timedelta]: if self._attempt_count >= self._retry_policy.max_number_of_attempts: return None @@ -375,7 +379,7 @@ def set_retryable_parent(self, retryable_task: RetryableTask): class WhenAnyTask(CompositeTask[Task]): """A task that completes when any of its child tasks complete.""" - def __init__(self, tasks: List[Task]): + def __init__(self, tasks: list[Task]): super().__init__(tasks) def on_child_completed(self, task: Task): @@ -385,12 +389,12 @@ def on_child_completed(self, task: Task): self._result = task -def when_all(tasks: List[Task[T]]) -> WhenAllTask[T]: +def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]: """Returns a task that completes when all of the provided tasks complete or when one of the tasks fail.""" return WhenAllTask(tasks) -def when_any(tasks: List[Task]) -> WhenAnyTask: +def when_any(tasks: list[Task]) -> WhenAnyTask: """Returns a task that completes when any of the provided tasks complete or fail.""" return WhenAnyTask(tasks) diff --git a/durabletask/worker.py b/durabletask/worker.py index bcc1a30..75e2e37 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -6,8 +6,7 @@ from datetime import datetime, timedelta from threading import Event, Thread from types import GeneratorType -from typing import (Any, Dict, Generator, List, Optional, Sequence, Tuple, - TypeVar, Union) +from typing import Any, Generator, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import empty_pb2, wrappers_pb2 @@ -25,8 +24,8 @@ class _Registry: - orchestrators: Dict[str, task.Orchestrator] - activities: Dict[str, task.Activity] + orchestrators: dict[str, task.Orchestrator] + activities: dict[str, task.Activity] def __init__(self): self.orchestrators = {} @@ -86,7 +85,7 @@ class TaskHubGrpcWorker: def __init__(self, *, host_address: Optional[str] = None, - metadata: Optional[List[Tuple[str, str]]] = None, + metadata: Optional[list[tuple[str, str]]] = None, log_handler=None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False): @@ -140,7 +139,7 @@ def run_loop(): # The stream blocks until either a work item is received or the stream is canceled # by another thread (see the stop() method). - for work_item in self._response_stream: + for work_item in self._response_stream: # type: ignore request_type = work_item.WhichOneof('request') self._logger.debug(f'Received "{request_type}" work item') if work_item.HasField('orchestratorRequest'): @@ -189,7 +188,10 @@ def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHub try: executor = _OrchestrationExecutor(self._registry, self._logger) result = executor.execute(req.instanceId, req.pastEvents, req.newEvents) - res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=result.actions, customStatus=wrappers_pb2.StringValue(value=result.custom_status)) + res = pb.OrchestratorResponse( + instanceId=req.instanceId, + actions=result.actions, + customStatus=pbh.get_string_value(result.encoded_custom_status)) except Exception as ex: self._logger.exception(f"An error occurred while trying to execute instance '{req.instanceId}': {ex}") failure_details = pbh.new_failure_details(ex) @@ -232,17 +234,17 @@ def __init__(self, instance_id: str): self._is_replaying = True self._is_complete = False self._result = None - self._pending_actions: Dict[int, pb.OrchestratorAction] = {} - self._pending_tasks: Dict[int, task.CompletableTask] = {} + self._pending_actions: dict[int, pb.OrchestratorAction] = {} + self._pending_tasks: dict[int, task.CompletableTask] = {} self._sequence_number = 0 self._current_utc_datetime = datetime(1000, 1, 1) self._instance_id = instance_id self._completion_status: Optional[pb.OrchestrationStatus] = None - self._received_events: Dict[str, List[Any]] = {} - self._pending_events: Dict[str, List[task.CompletableTask]] = {} + self._received_events: dict[str, list[Any]] = {} + self._pending_events: dict[str, list[task.CompletableTask]] = {} self._new_input: Optional[Any] = None self._save_events = False - self._custom_status: str = "" + self._encoded_custom_status: Optional[str] = None def run(self, generator: Generator[task.Task, Any, Any]): self._generator = generator @@ -314,10 +316,10 @@ def set_continued_as_new(self, new_input: Any, save_events: bool): self._new_input = new_input self._save_events = save_events - def get_actions(self) -> List[pb.OrchestratorAction]: + def get_actions(self) -> list[pb.OrchestratorAction]: if self._completion_status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW: # When continuing-as-new, we only return a single completion action. - carryover_events: Optional[List[pb.HistoryEvent]] = None + carryover_events: Optional[list[pb.HistoryEvent]] = None if self._save_events: carryover_events = [] # We need to save the current set of pending events so that they can be @@ -356,8 +358,8 @@ def is_replaying(self) -> bool: def current_utc_datetime(self, value: datetime): self._current_utc_datetime = value - def set_custom_status(self, custom_status: str) -> None: - self._custom_status = custom_status + def set_custom_status(self, custom_status: Any) -> None: + self._encoded_custom_status = shared.to_json(custom_status) if custom_status is not None else None def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: return self.create_timer_internal(fire_at) @@ -462,12 +464,12 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None: class ExecutionResults: - actions: List[pb.OrchestratorAction] - custom_status: str + actions: list[pb.OrchestratorAction] + encoded_custom_status: Optional[str] - def __init__(self, actions: List[pb.OrchestratorAction], custom_status: str): + def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]): self.actions = actions - self.custom_status = custom_status + self.encoded_custom_status = encoded_custom_status class _OrchestrationExecutor: _generator: Optional[task.Orchestrator] = None @@ -476,7 +478,7 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._registry = registry self._logger = logger self._is_suspended = False - self._suspended_events: List[pb.HistoryEvent] = [] + self._suspended_events: list[pb.HistoryEvent] = [] def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_events: Sequence[pb.HistoryEvent]) -> ExecutionResults: if not new_events: @@ -513,7 +515,7 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e actions = ctx.get_actions() if self._logger.level <= logging.DEBUG: self._logger.debug(f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}") - return ExecutionResults(actions=actions, custom_status=ctx._custom_status) + return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status) def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None: if self._is_suspended and _is_suspendable(event): @@ -829,7 +831,7 @@ def _get_new_event_summary(new_events: Sequence[pb.HistoryEvent]) -> str: elif len(new_events) == 1: return f"[{new_events[0].WhichOneof('eventType')}]" else: - counts: Dict[str, int] = {} + counts: dict[str, int] = {} for event in new_events: event_type = event.WhichOneof('eventType') counts[event_type] = counts.get(event_type, 0) + 1 @@ -843,7 +845,7 @@ def _get_action_summary(new_actions: Sequence[pb.OrchestratorAction]) -> str: elif len(new_actions) == 1: return f"[{new_actions[0].WhichOneof('orchestratorActionType')}]" else: - counts: Dict[str, int] = {} + counts: dict[str, int] = {} for action in new_actions: action_type = action.WhichOneof('orchestratorActionType') counts[action_type] = counts.get(action_type, 0) + 1 diff --git a/examples/fanout_fanin.py b/examples/fanout_fanin.py index 3e054df..c53744f 100644 --- a/examples/fanout_fanin.py +++ b/examples/fanout_fanin.py @@ -3,12 +3,11 @@ to complete, and prints an aggregate summary of the outputs.""" import random import time -from typing import List from durabletask import client, task, worker -def get_work_items(ctx: task.ActivityContext, _) -> List[str]: +def get_work_items(ctx: task.ActivityContext, _) -> list[str]: """Activity function that returns a list of work items""" # return a random number of work items count = random.randint(2, 10) @@ -32,11 +31,11 @@ def orchestrator(ctx: task.OrchestrationContext, _): activity functions in parallel, waits for them all to complete, and prints an aggregate summary of the outputs""" - work_items: List[str] = yield ctx.call_activity(get_work_items) + work_items: list[str] = yield ctx.call_activity(get_work_items) # execute the work-items in parallel and wait for them all to return tasks = [ctx.call_activity(process_work_item, input=item) for item in work_items] - results: List[int] = yield task.when_all(tasks) + results: list[int] = yield task.when_all(tasks) # return an aggregate summary of the results return { diff --git a/pyproject.toml b/pyproject.toml index d57957d..577824b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "durabletask" -version = "0.1.1-alpha.1" +version = "0.2b1" description = "A Durable Task Client SDK for Python" keywords = [ "durable", @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", ] -requires-python = ">=3.8" +requires-python = ">=3.9" license = {file = "LICENSE"} readme = "README.md" dependencies = [ diff --git a/requirements.txt b/requirements.txt index 641cee7..af76d88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ autopep8 grpcio grpcio-tools +protobuf pytest pytest-cov \ No newline at end of file diff --git a/tests/test_activity_executor.py b/tests/test_activity_executor.py index b9a4bd4..bfc8eaf 100644 --- a/tests/test_activity_executor.py +++ b/tests/test_activity_executor.py @@ -3,7 +3,7 @@ import json import logging -from typing import Any, Tuple, Union +from typing import Any, Optional, Tuple from durabletask import task, worker @@ -40,7 +40,7 @@ def test_activity(ctx: task.ActivityContext, _): executor, _ = _get_activity_executor(test_activity) - caught_exception: Union[Exception, None] = None + caught_exception: Optional[Exception] = None try: executor.execute(TEST_INSTANCE_ID, "Bogus", TEST_TASK_ID, None) except Exception as ex: diff --git a/tests/test_orchestration_e2e.py b/tests/test_orchestration_e2e.py index 1cfc520..d3d7f0b 100644 --- a/tests/test_orchestration_e2e.py +++ b/tests/test_orchestration_e2e.py @@ -466,4 +466,4 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.runtime_status == client.OrchestrationStatus.COMPLETED assert state.serialized_input is None assert state.serialized_output is None - assert state.serialized_custom_status is "\"foobaz\"" + assert state.serialized_custom_status == "\"foobaz\"" diff --git a/tests/test_orchestration_executor.py b/tests/test_orchestration_executor.py index 95eab0b..cb77c81 100644 --- a/tests/test_orchestration_executor.py +++ b/tests/test_orchestration_executor.py @@ -4,7 +4,6 @@ import json import logging from datetime import datetime, timedelta -from typing import List import pytest @@ -1184,7 +1183,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert str(ex) in complete_action.failureDetails.errorMessage -def get_and_validate_single_complete_orchestration_action(actions: List[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction: +def get_and_validate_single_complete_orchestration_action(actions: list[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction: assert len(actions) == 1 assert type(actions[0]) is pb.OrchestratorAction assert actions[0].HasField("completeOrchestration")