From ded15a8733ac835cc42bbe62839e6ef6382c8ddd Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Fri, 12 May 2023 13:07:09 -0700 Subject: [PATCH] Improve workflow decorator type hints with overload (#1635) Previously, the workflow decorator is hinted as always returning a WorkflowBase, which is not true when _workflow_function is None; similar to #1631, we propose using typing.overload to differentiate the return type of workflow based on the value of _workflow_function Signed-off-by: Matthew Hoffman --- flytekit/core/workflow.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index b5ca155f1a..39712dc326 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from enum import Enum from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, overload from typing_extensions import get_args @@ -650,9 +650,9 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver): def __init__( self, - workflow_function: Callable, - metadata: Optional[WorkflowMetadata], - default_metadata: Optional[WorkflowMetadataDefaults], + workflow_function: Callable[..., Any], + metadata: WorkflowMetadata, + default_metadata: WorkflowMetadataDefaults, docstring: Optional[Docstring] = None, docs: Optional[Documentation] = None, ): @@ -774,12 +774,32 @@ def execute(self, **kwargs): return exception_scopes.user_entry_point(self._workflow_function)(**kwargs) +@overload def workflow( - _workflow_function=None, + _workflow_function: None = ..., + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]: + ... + + +@overload +def workflow( + _workflow_function: Callable[..., Any], + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> PythonFunctionWorkflow: + ... + + +def workflow( + _workflow_function: Optional[Callable[..., Any]] = None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, docs: Optional[Documentation] = None, -) -> WorkflowBase: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. @@ -810,7 +830,7 @@ def workflow( :param docs: Description entity for the workflow """ - def wrapper(fn): + def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow: workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) @@ -825,7 +845,7 @@ def wrapper(fn): update_wrapper(workflow_instance, fn) return workflow_instance - if _workflow_function: + if _workflow_function is not None: return wrapper(_workflow_function) else: return wrapper