Skip to content

Commit

Permalink
test(default-args): Add more tests according to Yee's recommendation
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5321
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed May 26, 2024
1 parent cb7d72e commit 545adbc
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 47 deletions.
51 changes: 35 additions & 16 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import typing
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Type, Union, cast, get_args, get_origin

from google.protobuf import struct_pb2 as _struct
from typing_extensions import Protocol, get_args
from typing_extensions import Protocol

from flytekit.core import constants as _common_constants
from flytekit.core import context_manager as _flyte_context
Expand Down Expand Up @@ -37,6 +37,33 @@
from flytekit.models.types import SimpleType


def _is_none_valid_instance_of(t: Type) -> bool:
"""
Returns whether None is a valid instance of t.
>>> _is_none_valid_instance_of(type(None))
True
>>> _is_none_valid_instance_of(type(int))
False
>>> _is_none_valid_instance_of(Union[int, str])
False
>>> _is_none_valid_instance_of(Optional[int])
True
>>> _is_none_valid_instance_of(Optional[List[int]])
True
>>> _is_none_valid_instance_of(Union[str, Optional[int]])
True
"""
if get_origin(t) is None:
return t is type(None)
return any(_is_none_valid_instance_of(arg) for arg in get_args(t))


def translate_inputs_to_literals(
ctx: FlyteContext,
incoming_values: Dict[str, Any],
Expand Down Expand Up @@ -1058,27 +1085,19 @@ def create_and_link_node(

for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
if var.type.simple == SimpleType.NONE:
raise TypeError("Arguments do not have type annotation or the type annotation is None")
if k not in kwargs:
is_optional = False
if var.type.union_type:
for variant in var.type.union_type.variants:
if variant.simple == SimpleType.NONE:
val, _default = interface.inputs_with_defaults[k]
if _default is not None:
raise ValueError(
f"The default value for the optional type must be None, but got {_default}"
)
is_optional = True
if is_optional:
continue
if k in interface.inputs_with_defaults and interface.inputs_with_defaults[k][1] is not None:
if k in interface.inputs_with_defaults and (
interface.inputs_with_defaults[k][1] is not None
or _is_none_valid_instance_of(interface.inputs_with_defaults[k][0])
):
default_val = interface.inputs_with_defaults[k][1]
if not isinstance(default_val, typing.Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
kwargs[k] = default_val
else:
error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"

raise _user_exceptions.FlyteAssertion(error_msg)
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
Expand Down
14 changes: 0 additions & 14 deletions tests/flytekit/unit/core/test_composition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Dict, List, NamedTuple, Optional, Union

import pytest

from flytekit.core import launch_plan
from flytekit.core.task import task
from flytekit.core.workflow import workflow
Expand Down Expand Up @@ -186,15 +184,3 @@ def wf(a: Optional[int] = 1) -> Optional[int]:
return t2(a=a)

assert wf() is None

with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"):

@task()
def t3(c: Optional[int] = 3) -> Optional[int]:
...

@workflow
def wf():
return t3()

wf()
36 changes: 36 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,42 @@ def ranged_int_to_str(a: int) -> typing.List[str]:
assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"]


@pytest.mark.parametrize(
"input_val,output_val",
[
(4, 0),
(5, 5),
],
)
def test_dynamic_local_default_args_task(input_val, output_val):
@task
def t1(a: int = 0) -> int:
return a

@dynamic
def dt(a: int) -> int:
if a % 2 == 0:
return t1()
return t1(a=a)

assert dt(a=input_val) == output_val

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
) as ctx:
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": input_val})
dynamic_job_spec = dt.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec.nodes) == 1
assert len(dynamic_job_spec.tasks) == 1


def test_nested_dynamic_local():
@task
def t1(a: int) -> str:
Expand Down
Loading

0 comments on commit 545adbc

Please sign in to comment.