diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index e621a786b4..a3455a9fb1 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -23,6 +23,7 @@ from flytekit.models.interface import Variable from flytekit.models.task import Container, K8sPod, Sql, Task from flytekit.tools.module_loader import load_object_from_module +from flytekit.types.pickle.pickle import FlytePickleTransformer class ArrayNodeMapTask(PythonTask): @@ -55,6 +56,11 @@ def __init__( else: actual_task = python_function_task + for _, v in actual_task.python_interface.inputs.items(): + transformer = TypeEngine.get_transformer(v) + if isinstance(transformer, FlytePickleTransformer): + raise ValueError("Pickle transformers are not supported in map tasks.") + # TODO: add support for other Flyte entities if not (isinstance(actual_task, PythonFunctionTask) or isinstance(actual_task, PythonInstanceTask)): raise ValueError("Only PythonFunctionTask and PythonInstanceTask are supported in map tasks.") @@ -224,6 +230,7 @@ def _literal_map_to_python_input( inputs_interface = self._run_task.python_interface.inputs for k in self.interface.inputs.keys(): v = literal_map.literals[k] + if k not in self.bound_inputs: # assert that v.collection is not None if not v.collection or not isinstance(v.collection.literals, list): diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 9a994d8eee..ff428b5c6d 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -74,6 +74,15 @@ def say_hello(name: str) -> str: assert res.literals["o0"].scalar.primitive.string_value == "hello earth!" +def test_map_task_with_pickle(): + @task + def say_hello(name: typing.Any) -> str: + return f"hello {name}!" + + with pytest.raises(ValueError, match="Pickle transformers are not supported in map tasks"): + map_task(say_hello)(name=["abc", "def"]) + + def test_serialization(serialization_settings): @task def t1(a: int) -> int: