Skip to content

Commit

Permalink
Improve workflow decorator type hints with overload (#1635)
Browse files Browse the repository at this point in the history
Previously, the workflow decorator is hinted as always returning a WorkflowBase, which is not true when _workflow_function is None; similar to #1631, we propose using typing.overload to differentiate the return type of workflow based on the value of _workflow_function

Signed-off-by: Matthew Hoffman <[email protected]>
  • Loading branch information
ringohoffman authored and eapolinario committed May 16, 2023
1 parent fc065db commit ded15a8
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from enum import Enum
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, overload

from typing_extensions import get_args

Expand Down Expand Up @@ -650,9 +650,9 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver):

def __init__(
self,
workflow_function: Callable,
metadata: Optional[WorkflowMetadata],
default_metadata: Optional[WorkflowMetadataDefaults],
workflow_function: Callable[..., Any],
metadata: WorkflowMetadata,
default_metadata: WorkflowMetadataDefaults,
docstring: Optional[Docstring] = None,
docs: Optional[Documentation] = None,
):
Expand Down Expand Up @@ -774,12 +774,32 @@ def execute(self, **kwargs):
return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)


@overload
def workflow(
_workflow_function=None,
_workflow_function: None = ...,
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]:
...


@overload
def workflow(
_workflow_function: Callable[..., Any],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> PythonFunctionWorkflow:
...


def workflow(
_workflow_function: Optional[Callable[..., Any]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
docs: Optional[Documentation] = None,
) -> WorkflowBase:
) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down Expand Up @@ -810,7 +830,7 @@ def workflow(
:param docs: Description entity for the workflow
"""

def wrapper(fn):
def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand All @@ -825,7 +845,7 @@ def wrapper(fn):
update_wrapper(workflow_instance, fn)
return workflow_instance

if _workflow_function:
if _workflow_function is not None:
return wrapper(_workflow_function)
else:
return wrapper
Expand Down

0 comments on commit ded15a8

Please sign in to comment.