diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index de4b94f9f8..ae73643e67 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -5,10 +5,10 @@ import typing from copy import deepcopy from enum import Enum -from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast +from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Type, Union, cast, get_args, get_origin from google.protobuf import struct_pb2 as _struct -from typing_extensions import Protocol, get_args +from typing_extensions import Protocol from flytekit.core import constants as _common_constants from flytekit.core import context_manager as _flyte_context @@ -37,6 +37,33 @@ from flytekit.models.types import SimpleType +def _is_none_valid_instance_of(t: Type) -> bool: + """ + Returns whether None is a valid instance of t. + + >>> _is_none_valid_instance_of(type(None)) + True + + >>> _is_none_valid_instance_of(type(int)) + False + + >>> _is_none_valid_instance_of(Union[int, str]) + False + + >>> _is_none_valid_instance_of(Optional[int]) + True + + >>> _is_none_valid_instance_of(Optional[List[int]]) + True + + >>> _is_none_valid_instance_of(Union[str, Optional[int]]) + True + """ + if get_origin(t) is None: + return t is type(None) + return any(_is_none_valid_instance_of(arg) for arg in get_args(t)) + + def translate_inputs_to_literals( ctx: FlyteContext, incoming_values: Dict[str, Any], @@ -1058,27 +1085,19 @@ def create_and_link_node( for k in sorted(interface.inputs): var = typed_interface.inputs[k] + if var.type.simple == SimpleType.NONE: + raise TypeError("Arguments do not have type annotation or the type annotation is None") if k not in kwargs: - is_optional = False - if var.type.union_type: - for variant in var.type.union_type.variants: - if variant.simple == SimpleType.NONE: - val, _default = interface.inputs_with_defaults[k] - if _default is not None: - raise ValueError( - f"The default value for the optional type must be None, but got {_default}" - ) - is_optional = True - if is_optional: - continue - if k in interface.inputs_with_defaults and interface.inputs_with_defaults[k][1] is not None: + if k in interface.inputs_with_defaults and ( + interface.inputs_with_defaults[k][1] is not None + or _is_none_valid_instance_of(interface.inputs_with_defaults[k][0]) + ): default_val = interface.inputs_with_defaults[k][1] if not isinstance(default_val, typing.Hashable): raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument") kwargs[k] = default_val else: error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}" - raise _user_exceptions.FlyteAssertion(error_msg) v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 6fe2b01e61..0073baec53 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -1,7 +1,5 @@ from typing import Dict, List, NamedTuple, Optional, Union -import pytest - from flytekit.core import launch_plan from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -186,15 +184,3 @@ def wf(a: Optional[int] = 1) -> Optional[int]: return t2(a=a) assert wf() is None - - with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"): - - @task() - def t3(c: Optional[int] = 3) -> Optional[int]: - ... - - @workflow - def wf(): - return t3() - - wf() diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index 7964548674..bda4703065 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -95,6 +95,42 @@ def ranged_int_to_str(a: int) -> typing.List[str]: assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"] +@pytest.mark.parametrize( + "input_val,output_val", + [ + (4, 0), + (5, 5), + ], +) +def test_dynamic_local_default_args_task(input_val, output_val): + @task + def t1(a: int = 0) -> int: + return a + + @dynamic + def dt(a: int) -> int: + if a % 2 == 0: + return t1() + return t1(a=a) + + assert dt(a=input_val) == output_val + + with context_manager.FlyteContextManager.with_context( + context_manager.FlyteContextManager.current_context().with_serialization_settings(settings) + ) as ctx: + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=ExecutionState.Mode.TASK_EXECUTION, + ) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": input_val}) + dynamic_job_spec = dt.dispatch_execute(ctx, input_literal_map) + assert len(dynamic_job_spec.nodes) == 1 + assert len(dynamic_job_spec.tasks) == 1 + + def test_nested_dynamic_local(): @task def t1(a: int) -> str: diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index ba6cb0588d..a7e6fc72b8 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -1,8 +1,10 @@ import os import typing from collections import OrderedDict +from copy import deepcopy import mock +import pandas as pd import pytest import flytekit.configuration @@ -14,7 +16,23 @@ from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion from flytekit.models.admin.workflow import WorkflowSpec -from flytekit.models.types import SimpleType +from flytekit.models.core import types as _core_types +from flytekit.models.literals import ( + BindingData, + BindingDataCollection, + BindingDataMap, + Blob, + BlobMetadata, + Literal, + Primitive, + Scalar, + StructuredDataset, + StructuredDatasetMetadata, + StructuredDatasetType, + Union, + Void, +) +from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.tools.translator import get_serializable from flytekit.types.error.error import FlyteError @@ -28,6 +46,25 @@ ) +def remove_uri(obj: typing.Any): + """ + Set the nested "_uri" attribute in an object to empty string. + """ + + def _remove_uri(obj: typing.Any): + if type(obj).__module__ == "builtins": + return obj + for k, v in obj.__dict__.items(): + if k == "_uri": + new_v = "" + else: + new_v = _remove_uri(v) + setattr(obj, k, new_v) + return obj + + return _remove_uri(deepcopy(obj)) + + def test_serialization(): square = ContainerTask( name="square", @@ -86,36 +123,284 @@ def raw_container_wf(val1: int, val2: int) -> int: @pytest.mark.parametrize( - "arg_type,default_arg,input_arg,error", + "arg_type,default_arg,input_arg,with_input_error,no_input_error,default_arg_binding,input_arg_binding,output_type,custom_assert_equal", [ - (int | None, None, 100, None), - (int | None, 100, 200, ValueError), - (int, 0, 5, None), - (str, "", "a", None), - (list[int], [], [1, 2, 3], FlyteAssertion), - (dict[str, int], {}, {"a": 1}, FlyteAssertion), + ( + typing.Optional[int], + None, + 100, + None, + None, + Scalar( + union=Union( + value=Literal( + scalar=Scalar( + none_type=Void(), + ), + ), + stored_type=LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ), + ), + Scalar(primitive=Primitive(integer=100)), + LiteralType( + union_type=UnionType( + [ + LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ), + None, + ), + ( + typing.Optional[int], + 100, + 200, + None, + None, + Scalar(primitive=Primitive(integer=100)), + Scalar(primitive=Primitive(integer=200)), + LiteralType( + union_type=UnionType( + [ + LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ), + None, + ), + ( + int, + 0, + 5, + None, + None, + Scalar(primitive=Primitive(integer=0)), + Scalar(primitive=Primitive(integer=5)), + LiteralType(simple=SimpleType.INTEGER), + None, + ), + ( + str, + "", + "a", + None, + None, + Scalar(primitive=Primitive(string_value="")), + Scalar(primitive=Primitive(string_value="a")), + LiteralType(simple=SimpleType.STRING), + None, + ), + ( + int, + "five", + "great", + AssertionError, + AssertionError, + Scalar(primitive=Primitive(string_value="five")), + Scalar(primitive=Primitive(string_value="a")), + LiteralType(simple=SimpleType.STRING), + None, + ), + ( + list[int], + [], + [1, 2], + None, + FlyteAssertion, + BindingDataCollection(bindings=[]), + BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + ], + ), + LiteralType(collection_type=LiteralType(simple=SimpleType.INTEGER)), + None, + ), + ( + dict[str, int], + {}, + {"a": 1}, + None, + FlyteAssertion, + BindingDataMap(bindings={}), + BindingDataMap( + bindings={ + "a": BindingData( + scalar=Scalar(primitive=Primitive(integer=1)), + ), + }, + ), + LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER)), + None, + ), + ( + typing.Any, # test FlytePickle + 12, + "hello", + None, + None, + Scalar( + blob=Blob( + metadata=BlobMetadata( + type=_core_types.BlobType( + format="PythonPickle", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ), + ), + uri="", + ), + ), + Scalar( + blob=Blob( + metadata=BlobMetadata( + type=_core_types.BlobType( + format="PythonPickle", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ), + ), + uri="", + ), + ), + LiteralType( + blob=_core_types.BlobType( + format="PythonPickle", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ), + metadata={ + "python_class_name": "typing.Any", + }, + ), + None, + ), + ( + pd.DataFrame, + pd.DataFrame(), + pd.DataFrame( + { + "strs": ["a", "b", "c"], + "ints": [1, 2, 3], + } + ), + None, + FlyteAssertion, + Scalar( + structured_dataset=StructuredDataset( + uri="", + metadata=StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType( + format="parquet", + ), + ), + ), + ), + Scalar( + structured_dataset=StructuredDataset( + uri="", + metadata=StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType( + format="parquet", + ), + ), + ), + ), + LiteralType(structured_dataset_type=StructuredDatasetType()), + pd.testing.assert_frame_equal, + ), ], ) -def test_default_args_task(arg_type, default_arg, input_arg, error): +def test_default_args_task( + arg_type, + default_arg, + input_arg, + with_input_error, + no_input_error, + default_arg_binding, + input_arg_binding, + output_type, + custom_assert_equal, +): @task def t1(a: arg_type = default_arg) -> arg_type: return a - @workflow - def wf_no_input() -> arg_type: - return t1() - + # Test task with input parameter @workflow def wf_with_input() -> arg_type: return t1(a=input_arg) - if error: - with pytest.raises(error): + if with_input_error: + with pytest.raises(with_input_error): + wf_with_input() + else: + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + assert remove_uri(wf_with_input_spec.template.nodes[0].inputs[0].binding.value) == input_arg_binding + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + if custom_assert_equal: + custom_assert_equal(wf_with_input(), input_arg) + else: + assert wf_with_input() == input_arg + + # Test task with no input parameter (use the value in the task's default argument) + @workflow + def wf_no_input() -> arg_type: + return t1() + + if no_input_error: + with pytest.raises(no_input_error): wf_no_input() else: - assert wf_no_input() == default_arg + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + assert remove_uri(wf_no_input_spec.template.nodes[0].inputs[0].binding.value) == default_arg_binding + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + if custom_assert_equal: + custom_assert_equal(wf_no_input(), default_arg) + else: + assert wf_no_input() == default_arg + + # Test sub-workflow + @workflow + def wf() -> tuple[arg_type, arg_type]: + return (wf_no_input(), wf_with_input()) + + if not no_input_error: + assert wf() == (default_arg, input_arg) + + +def test_default_args_task_no_type_hint(): + @task + def t1(a=0) -> int: + return a + + @workflow + def wf_with_input() -> int: + return t1(a=2) + + @workflow + def wf_no_input() -> int: + return t1() - assert wf_with_input() == input_arg + with pytest.raises(TypeError): + wf_with_input() + with pytest.raises(TypeError): + wf_no_input() def test_serialization_branch_complex():