diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 323988b340..e99f114818 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -509,6 +509,13 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder: return ctx_builder.with_file_access(file_access) +def is_optional(_type): + """ + Checks if the given type is Optional Type + """ + return typing.get_origin(_type) is typing.Union and type(None) in typing.get_args(_type) + + def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): """ Returns a function that is used to implement WorkflowCommand and execute a flyte workflow. @@ -526,8 +533,13 @@ def _run(*args, **kwargs): click.secho(f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.", fg="cyan") try: inputs = {} - for input_name, _ in entity.python_interface.inputs.items(): + for input_name, v in entity.python_interface.inputs_with_defaults.items(): processed_click_value = kwargs.get(input_name) + optional_v = False + if processed_click_value is None and isinstance(v, typing.Tuple): + optional_v = is_optional(v[0]) + if len(v) == 2: + processed_click_value = v[1] if isinstance(processed_click_value, ArtifactQuery): if run_level_params.is_remote: click.secho( @@ -542,7 +554,7 @@ def _run(*args, **kwargs): raise click.UsageError( f"Default for '{input_name}' is a query, which must be specified when running locally." ) - if processed_click_value is not None: + if processed_click_value is not None or optional_v: inputs[input_name] = processed_click_value if not run_level_params.is_remote: @@ -780,7 +792,7 @@ def _create_command( ctx: click.Context, entity_name: str, run_level_params: RunLevelParams, - loaded_entity: typing.Any, + loaded_entity: [PythonTask, WorkflowBase], is_workflow: bool, ): """ diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 13b6af2d4b..7f0af5f8e6 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -25,7 +25,7 @@ def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str: if isinstance(v, tuple): - if v[1]: + if v[1] is not None: return f"{k}: {v[0]}={v[1]}" return f"{k}: {v[0]}" return f"{k}: {v}" diff --git a/tests/flytekit/unit/cli/pyflyte/default_arguments/task_defaults.py b/tests/flytekit/unit/cli/pyflyte/default_arguments/task_defaults.py new file mode 100644 index 0000000000..45c6c7d14c --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/default_arguments/task_defaults.py @@ -0,0 +1,14 @@ +import enum + +from flytekit import task + + +class Color(enum.Enum): + RED = 'red' + GREEN = 'green' + BLUE = 'blue' + + +@task +def foo(i: int = 0, j: str = "Hello", c: Color = Color.RED): + print(i, j, c) diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 6957828743..3bb7697d47 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -516,3 +516,23 @@ def test_envvar_local_execution(envs, envs_argument, expected_output, workflow_f ) output = result.stdout.strip().split("\n")[-1].strip() assert output == expected_output + + +@pytest.mark.parametrize( + "task_path", + [("task_defaults.py")], +) +def test_list_default_arguments(task_path): + runner = CliRunner() + dir_name = os.path.dirname(os.path.realpath(__file__)) + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(dir_name, "default_arguments", task_path), + "foo", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert result.stdout == "Running Execution on local.\n0 Hello Color.RED\n\n" diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 8ba875b664..e775cb922a 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -264,3 +264,23 @@ def t1(i: str): @task(node_dependency_hints=[t1]) def t2(i: str): pass + + +def test_default_inputs(): + @task + def foo(x: int = 0, y: str = "Hello") -> int: + return x + + assert foo.python_interface.default_inputs_as_kwargs == {"x": 0, "y": "Hello"} + + @task + def foo2(x: int, y: str = "Hello") -> int: + return x + + assert foo2.python_interface.inputs_with_defaults == {"x": (int, None), "y": (str, "Hello")} + + @task + def foo3(x: int, y: str) -> int: + return x + + assert foo3.python_interface.inputs_with_defaults == {"x": (int, None), "y": (str, None)}