Skip to content

Commit

Permalink
Only transform the literal that matches the map task index (#2484)
Browse files Browse the repository at this point in the history
* Only transform the literal that matches the map task index

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* kevin's update

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* make it work locally

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* lint

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* add a test

Signed-off-by: Kevin Su <[email protected]>

* check (#2486)

Signed-off-by: Kevin Su <[email protected]>

* Only fail for unbound batch pickled inputs for map tasks

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* Unit tests

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* lint

Signed-off-by: Haytham Abuelfutuh <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
2 people authored and fiedlerNr9 committed Jul 25, 2024
1 parent a319d49 commit a1a9235
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 14 deletions.
57 changes: 43 additions & 14 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import math
import os # TODO: use flytekit logger
import typing
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Union, cast

Expand All @@ -13,14 +14,18 @@
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.type_engine import TypeEngine, is_annotated
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.array_job import ArrayJob
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.interface import Variable
from flytekit.models.task import Container, K8sPod, Sql, Task
from flytekit.tools.module_loader import load_object_from_module
from flytekit.types.pickle import pickle
from flytekit.types.pickle.pickle import FlytePickleTransformer


class ArrayNodeMapTask(PythonTask):
Expand Down Expand Up @@ -57,6 +62,16 @@ def __init__(
if not (isinstance(actual_task, PythonFunctionTask) or isinstance(actual_task, PythonInstanceTask)):
raise ValueError("Only PythonFunctionTask and PythonInstanceTask are supported in map tasks.")

for k, v in actual_task.python_interface.inputs.items():
if bound_inputs and k in bound_inputs:
continue
transformer = TypeEngine.get_transformer(v)
if isinstance(transformer, FlytePickleTransformer):
if is_annotated(v):
for annotation in typing.get_args(v)[1:]:
if isinstance(annotation, pickle.BatchSize):
raise ValueError("Choosing a BatchSize for map tasks inputs is not supported.")

n_outputs = len(actual_task.python_interface.outputs)
if n_outputs > 1:
raise ValueError("Only tasks with a single output are supported in map tasks.")
Expand Down Expand Up @@ -208,24 +223,38 @@ def __call__(self, *args, **kwargs):
kwargs = {**self._partial.keywords, **kwargs}
return super().__call__(*args, **kwargs)

def _literal_map_to_python_input(
self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext
) -> Dict[str, Any]:
ctx = FlyteContextManager.current_context()
inputs_interface = self.python_interface.inputs
inputs_map = literal_map
# If we run locally, we will need to process all of the inputs. If we are running in a remote task execution
# environment, then we should process/download/extract only the inputs that are needed for the current task.
if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
map_task_inputs = {}
task_index = self._compute_array_job_index()
inputs_interface = self._run_task.python_interface.inputs
for k in self.interface.inputs.keys():
v = literal_map.literals[k]

if k not in self.bound_inputs:
# assert that v.collection is not None
if not v.collection or not isinstance(v.collection.literals, list):
raise ValueError(f"Expected a list of literals for {k}")
map_task_inputs[k] = v.collection.literals[task_index]
else:
map_task_inputs[k] = v
inputs_map = _literal_models.LiteralMap(literals=map_task_inputs)
return TypeEngine.literal_map_to_kwargs(ctx, inputs_map, inputs_interface)

def execute(self, **kwargs) -> Any:
ctx = FlyteContextManager.current_context()
if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
return self._execute_map_task(ctx, **kwargs)
return exception_scopes.user_entry_point(self.python_function_task.execute)(**kwargs)

return self._raw_execute(**kwargs)

def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any:
task_index = self._compute_array_job_index()
map_task_inputs = {}
for k in self.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
map_task_inputs[k] = v[task_index]
else:
map_task_inputs[k] = v
return exception_scopes.user_entry_point(self.python_function_task.execute)(**map_task_inputs)

@staticmethod
def _compute_array_job_index() -> int:
"""
Expand Down Expand Up @@ -276,8 +305,8 @@ def _raw_execute(self, **kwargs) -> Any:
outputs = []

mapped_tasks_count = 0
if self._run_task.interface.inputs.items():
for k in self._run_task.interface.inputs.keys():
if self.python_function_task.interface.inputs.items():
for k in self.python_function_task.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
mapped_tasks_count = len(v)
Expand Down
36 changes: 36 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

from flytekit import map_task, task, workflow
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
from flytekit.core.task import TaskMetadata
from flytekit.core.type_engine import TypeEngine
from flytekit.tools.translator import get_serializable
from flytekit.types.pickle import BatchSize


@pytest.fixture
Expand Down Expand Up @@ -54,6 +57,39 @@ def wf() -> List[str]:
assert wf() == ["hello hello earth!!", "hello hello mars!!"]


def test_remote_execution(serialization_settings):
@task
def say_hello(name: str) -> str:
return f"hello {name}!"

ctx = context_manager.FlyteContextManager.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
)
) as ctx:
t = map_task(say_hello)
lm = TypeEngine.dict_to_literal_map(ctx, {"name": ["earth", "mars"]}, type_hints={"name": typing.List[str]})
res = t.dispatch_execute(ctx, lm)
assert len(res.literals) == 1
assert res.literals["o0"].scalar.primitive.string_value == "hello earth!"


def test_map_task_with_pickle():
@task
def say_hello(name: typing.Annotated[typing.Any, BatchSize(10)]) -> str:
return f"hello {name}!"

with pytest.raises(ValueError, match="Choosing a BatchSize for map tasks inputs is not supported."):
map_task(say_hello)(name=["abc", "def"])

@task
def say_hello(name: typing.Any) -> str:
return f"hello {name}!"

map_task(say_hello)(name=["abc", "def"])


def test_serialization(serialization_settings):
@task
def t1(a: int) -> int:
Expand Down

0 comments on commit a1a9235

Please sign in to comment.