Skip to content

Commit

Permalink
[BUG] support setting extended resources for array node map tasks (fl…
Browse files Browse the repository at this point in the history
  • Loading branch information
pvditt authored and Mecoli1219 committed Jul 27, 2024
1 parent 4114601 commit 8a45055
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
5 changes: 5 additions & 0 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Union, cast

from flyteidl.core import tasks_pb2

from flytekit.configuration import SerializationSettings
from flytekit.core import tracker
from flytekit.core.base_task import PythonTask, TaskResolverMixin
Expand Down Expand Up @@ -152,6 +154,9 @@ def python_function_task(self) -> Union[PythonFunctionTask, PythonInstanceTask]:
def bound_inputs(self) -> Set[str]:
return self._bound_inputs

def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
return self.python_function_task.get_extended_resources(settings)

@contextmanager
def prepare_target(self):
"""
Expand Down
21 changes: 21 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
from flytekit.core.task import TaskMetadata
from flytekit.core.type_engine import TypeEngine
from flytekit.extras.accelerators import GPUAccelerator
from flytekit.tools.translator import get_serializable
from flytekit.types.pickle import BatchSize

Expand Down Expand Up @@ -381,3 +382,23 @@ def wf(x: typing.List[int]):
task_spec = od[arraynode_maptask]
assert task_spec.template.metadata.retries.retries == 2
assert task_spec.template.metadata.interruptible


def test_serialization_extended_resources(serialization_settings):
@task(
accelerator=GPUAccelerator("test_gpu"),
)
def t1(a: int) -> int:
return a + 1

arraynode_maptask = map_task(t1)

@workflow
def wf(x: typing.List[int]):
return arraynode_maptask(a=x)

od = OrderedDict()
get_serializable(od, serialization_settings, wf)
task_spec = od[arraynode_maptask]

assert task_spec.template.extended_resources.gpu_accelerator.device == "test_gpu"

0 comments on commit 8a45055

Please sign in to comment.