Skip to content

Commit

Permalink
Pyflyte run workflows correctly handles Optional[TYPE] = None (#1849)
Browse files Browse the repository at this point in the history
* fix pyflyte run handling of default None

Signed-off-by: Niels Bantilan <[email protected]>

* update tests

Signed-off-by: Niels Bantilan <[email protected]>

* remote print statements in tests

Signed-off-by: Niels Bantilan <[email protected]>

---------

Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy authored Sep 26, 2023
1 parent 3f87154 commit ae23ddb
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
4 changes: 2 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pathlib
import typing
from dataclasses import dataclass, field, fields
from typing import cast
from typing import cast, get_args

import rich_click as click
from dataclasses_json import DataClassJsonMixin
Expand Down Expand Up @@ -691,7 +691,7 @@ def _create_command(
for input_name, input_type_val in loaded_entity.python_interface.inputs_with_defaults.items():
literal_var = loaded_entity.interface.inputs.get(input_name)
python_type, default_val = input_type_val
required = default_val is None
required = type(None) not in get_args(python_type) and default_val is None
params.append(
to_click_option(
ctx, flyte_ctx, input_name, literal_var, python_type, default_val, get_upload_url_fn, required
Expand Down
8 changes: 7 additions & 1 deletion flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
from flytekit.models import literals
from flytekit.models.literals import LiteralCollection, LiteralMap, Primitive, Union
from flytekit.models.literals import LiteralCollection, LiteralMap, Primitive, Union, Void
from flytekit.models.types import SimpleType
from flytekit.remote import FlyteRemote
from flytekit.tools import script_mode
Expand Down Expand Up @@ -294,6 +294,12 @@ def convert_to_union(
self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any
) -> Literal:
lt = self._literal_type

# handle case where Union type has NoneType and the value is None
has_none_type = any(v.simple == 0 for v in self._literal_type.union_type.variants)
if has_none_type and value is None:
return Literal(scalar=Scalar(none_type=Void()))

for i in range(len(self._literal_type.union_type.variants)):
variant = self._literal_type.union_type.variants[i]
python_type = get_args(self._python_type)[i]
Expand Down
42 changes: 35 additions & 7 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def test_pyflyte_run_cli():
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0


Expand All @@ -133,7 +132,6 @@ def test_union_type1(input):
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0


Expand Down Expand Up @@ -162,7 +160,6 @@ def test_union_type2(input):
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0


Expand All @@ -185,9 +182,18 @@ def test_union_type_with_invalid_input():

def test_get_entities_in_file():
e = get_entities_in_file(WORKFLOW_FILE, False)
assert e.workflows == ["my_wf"]
assert e.tasks == ["get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"]
assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"]
assert e.workflows == ["my_wf", "wf_with_none"]
assert e.tasks == ["get_subset_df", "print_all", "show_sd", "task_with_optional", "test_union1", "test_union2"]
assert e.all() == [
"my_wf",
"wf_with_none",
"get_subset_df",
"print_all",
"show_sd",
"task_with_optional",
"test_union1",
"test_union2",
]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -236,7 +242,6 @@ def test_list_default_arguments(wf_path):
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0


Expand Down Expand Up @@ -329,3 +334,26 @@ def check_image(*args, **kwargs):
mock_remote.register_script.side_effect = check_image

run_command(mock_click_ctx, tk)()


@pytest.mark.parametrize("a_val", ["foo", "1", None])
def test_pyflyte_run_with_none(a_val):
runner = CliRunner()
args = [
"run",
WORKFLOW_FILE,
"wf_with_none",
]
if a_val is not None:
args.extend(["--a", a_val])
result = runner.invoke(
pyflyte.main,
args,
catch_exceptions=False,
)
output = result.stdout.strip().split("\n")[-1].strip()
if a_val is None:
assert output == "default"
else:
assert output == a_val
assert result.exit_code == 0
10 changes: 10 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,13 @@ def my_wf(
show_sd(in_sd=image)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p)
return x


@task
def task_with_optional(a: typing.Optional[str]) -> str:
return "default" if a is None else a


@workflow
def wf_with_none(a: typing.Optional[str] = None) -> str:
return task_with_optional(a=a)

0 comments on commit ae23ddb

Please sign in to comment.