diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index a7b35bc34cc..575654b57df 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -8,6 +8,8 @@ from contextlib import contextmanager from typing import Any, Dict, List, Optional, Set, Union, cast +from flyteidl.core import tasks_pb2 + from flytekit.configuration import SerializationSettings from flytekit.core import tracker from flytekit.core.base_task import PythonTask, TaskResolverMixin @@ -152,6 +154,9 @@ def python_function_task(self) -> Union[PythonFunctionTask, PythonInstanceTask]: def bound_inputs(self) -> Set[str]: return self._bound_inputs + def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]: + return self.python_function_task.get_extended_resources(settings) + @contextmanager def prepare_target(self): """ diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 9b0144096e4..a8ab3a6d388 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -11,6 +11,7 @@ from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver from flytekit.core.task import TaskMetadata from flytekit.core.type_engine import TypeEngine +from flytekit.extras.accelerators import GPUAccelerator from flytekit.tools.translator import get_serializable from flytekit.types.pickle import BatchSize @@ -381,3 +382,23 @@ def wf(x: typing.List[int]): task_spec = od[arraynode_maptask] assert task_spec.template.metadata.retries.retries == 2 assert task_spec.template.metadata.interruptible + + +def test_serialization_extended_resources(serialization_settings): + @task( + accelerator=GPUAccelerator("test_gpu"), + ) + def t1(a: int) -> int: + return a + 1 + + arraynode_maptask = map_task(t1) + + @workflow + def wf(x: typing.List[int]): + return arraynode_maptask(a=x) + + od = OrderedDict() + get_serializable(od, serialization_settings, wf) + task_spec = od[arraynode_maptask] + + assert task_spec.template.extended_resources.gpu_accelerator.device == "test_gpu"