Skip to content

Commit

Permalink
Run flytekit tasks on remote with local default values passed correct…
Browse files Browse the repository at this point in the history
…ly (#2525)

* Run flytekit tasks on remote with local default values passed correctly

```python
@task
def default_inputs(i: int = 0, f: float = 10.0, e: Color = Color.RED, b: bool = True, j: typing.Optional[int] = None):
  print(i, f, e, b, j)
```

Running this on remote will now work correctly

```bash
pyflyte run --remote exhaustive.py default_inputs
```

Signed-off-by: Ketan Umare <[email protected]>

* fixing lint

Signed-off-by: Ketan Umare <[email protected]>

---------

Signed-off-by: Ketan Umare <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
2 people authored and fiedlerNr9 committed Jul 25, 2024
1 parent d8898bb commit c0613dd
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 4 deletions.
18 changes: 15 additions & 3 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,13 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder:
return ctx_builder.with_file_access(file_access)


def is_optional(_type):
"""
Checks if the given type is Optional Type
"""
return typing.get_origin(_type) is typing.Union and type(None) in typing.get_args(_type)


def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]):
"""
Returns a function that is used to implement WorkflowCommand and execute a flyte workflow.
Expand All @@ -526,8 +533,13 @@ def _run(*args, **kwargs):
click.secho(f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.", fg="cyan")
try:
inputs = {}
for input_name, _ in entity.python_interface.inputs.items():
for input_name, v in entity.python_interface.inputs_with_defaults.items():
processed_click_value = kwargs.get(input_name)
optional_v = False
if processed_click_value is None and isinstance(v, typing.Tuple):
optional_v = is_optional(v[0])
if len(v) == 2:
processed_click_value = v[1]
if isinstance(processed_click_value, ArtifactQuery):
if run_level_params.is_remote:
click.secho(
Expand All @@ -542,7 +554,7 @@ def _run(*args, **kwargs):
raise click.UsageError(
f"Default for '{input_name}' is a query, which must be specified when running locally."
)
if processed_click_value is not None:
if processed_click_value is not None or optional_v:
inputs[input_name] = processed_click_value

if not run_level_params.is_remote:
Expand Down Expand Up @@ -780,7 +792,7 @@ def _create_command(
ctx: click.Context,
entity_name: str,
run_level_params: RunLevelParams,
loaded_entity: typing.Any,
loaded_entity: [PythonTask, WorkflowBase],
is_workflow: bool,
):
"""
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str:
if isinstance(v, tuple):
if v[1]:
if v[1] is not None:
return f"{k}: {v[0]}={v[1]}"
return f"{k}: {v[0]}"
return f"{k}: {v}"
Expand Down
14 changes: 14 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/default_arguments/task_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import enum

from flytekit import task


class Color(enum.Enum):
RED = 'red'
GREEN = 'green'
BLUE = 'blue'


@task
def foo(i: int = 0, j: str = "Hello", c: Color = Color.RED):
print(i, j, c)
20 changes: 20 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,23 @@ def test_envvar_local_execution(envs, envs_argument, expected_output, workflow_f
)
output = result.stdout.strip().split("\n")[-1].strip()
assert output == expected_output


@pytest.mark.parametrize(
"task_path",
[("task_defaults.py")],
)
def test_list_default_arguments(task_path):
runner = CliRunner()
dir_name = os.path.dirname(os.path.realpath(__file__))
result = runner.invoke(
pyflyte.main,
[
"run",
os.path.join(dir_name, "default_arguments", task_path),
"foo",
],
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.stdout == "Running Execution on local.\n0 Hello Color.RED\n\n"
20 changes: 20 additions & 0 deletions tests/flytekit/unit/core/test_python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,23 @@ def t1(i: str):
@task(node_dependency_hints=[t1])
def t2(i: str):
pass


def test_default_inputs():
@task
def foo(x: int = 0, y: str = "Hello") -> int:
return x

assert foo.python_interface.default_inputs_as_kwargs == {"x": 0, "y": "Hello"}

@task
def foo2(x: int, y: str = "Hello") -> int:
return x

assert foo2.python_interface.inputs_with_defaults == {"x": (int, None), "y": (str, "Hello")}

@task
def foo3(x: int, y: str) -> int:
return x

assert foo3.python_interface.inputs_with_defaults == {"x": (int, None), "y": (str, None)}

0 comments on commit c0613dd

Please sign in to comment.