Skip to content

Commit

Permalink
Add is_map_task to _dispatch_execute
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Nov 21, 2024
1 parent 32c5896 commit 12e194a
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 2 deletions.
19 changes: 17 additions & 2 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _dispatch_execute(
load_task: Callable[[], PythonTask],
inputs_path: str,
output_prefix: str,
is_map_task: bool = False,
):
"""
Dispatches execute to PythonTask
Expand All @@ -145,6 +146,12 @@ def _dispatch_execute(
a: [Optional] Record outputs to output_prefix
b: OR if IgnoreOutputs is raised, then ignore uploading outputs
c: OR if an unhandled exception is retrieved - record it as an errors.pb
:param ctx: FlyteContext
:param load_task: Callable[[], PythonTask]
:param inputs: Where to read inputs
:param output_prefix: Where to write primitive outputs
:param is_map_task: Whether this task is executing as part of a map task
"""
error_file_name = _build_error_file_name()
worker_name = _get_worker_name()
Expand Down Expand Up @@ -206,6 +213,14 @@ def _dispatch_execute(

if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size:
logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket")
inferred_type = task_def.interface.outputs[k].type

# In the case of map tasks we need to use the type of the collection as inferred type as the task
# typed interface of the offloaded literal. This is done because the map task interface present in
# the task template contains the (correct) type for the entire map task, not the single node execution.
# For that reason we "unwrap" the collection type and use it as the inferred type of the offloaded literal.
if is_map_task:
inferred_type = inferred_type.collection_type

# This file will hold the offloaded literal and will be written to the output prefix
# alongside the regular outputs.pb, deck.pb, etc.
Expand All @@ -216,7 +231,7 @@ def _dispatch_execute(
uri=f"{output_prefix}/{offloaded_filename}",
size_bytes=lit.ByteSize(),
# TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged
inferred_type=task_def.interface.outputs[k].type,
inferred_type=inferred_type,
),
hash=v.hash if v.hash is not None else compute_hash_string(lit),
)
Expand Down Expand Up @@ -633,7 +648,7 @@ def load_task():
)
return

_dispatch_execute(ctx, load_task, inputs, output_prefix)
_dispatch_execute(ctx, load_task, inputs, output_prefix, is_map_task=True)


def normalize_inputs(
Expand Down
90 changes: 90 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from flytekit.bin.entrypoint import _dispatch_execute, get_container_error_timestamp, normalize_inputs, setup_execution, get_traceback_str
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core import mock_stats
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.hash import HashMethod
from flytekit.models.core import identifier as id_models
from flytekit.core import context_manager
Expand Down Expand Up @@ -893,3 +894,92 @@ def t1(a: typing.List[int]) -> typing.List[typing.List[str]]:
assert lit.literals["o0"].HasField("offloaded_metadata") == False
else:
assert False, f"Unexpected file {ff}"



def test_dispatch_execute_offloaded_map_task(tmp_path_factory):
@task
def t1(n: int) -> int:
return n + 1

inputs: typing.List[int] = [1, 2, 3, 4]
for i, v in enumerate(inputs):
inputs_path = tmp_path_factory.mktemp("inputs")
outputs_path = tmp_path_factory.mktemp("outputs")

ctx = context_manager.FlyteContext.current_context()
with get_flyte_context(tmp_path_factory, outputs_path) as ctx:
input_literal_map = _literal_models.LiteralMap(
{
"n": TypeEngine.to_literal(ctx, inputs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])),
}
)

write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb"))

with mock.patch.dict(
os.environ,
{
"_F_L_MIN_SIZE_MB": "0", # Always offload
"BATCH_JOB_ARRAY_INDEX_OFFSET": str(i),
}):
_dispatch_execute(ctx, lambda: ArrayNodeMapTask(python_function_task=t1), str(inputs_path/"inputs.pb"), str(outputs_path.absolute()), is_map_task=True)

assert "error.pb" not in os.listdir(outputs_path)

for ff in os.listdir(outputs_path):
with open(outputs_path/ff, "rb") as f:
if ff == "outputs.pb":
lit = literals_pb2.LiteralMap()
lit.ParseFromString(f.read())
assert len(lit.literals) == 1
assert "o0" in lit.literals
assert lit.literals["o0"].HasField("offloaded_metadata") == True
assert lit.literals["o0"].offloaded_metadata.uri.endswith("/o0_offloaded_metadata.pb")
assert lit.literals["o0"].offloaded_metadata.inferred_type == LiteralType(simple=SimpleType.INTEGER).to_flyte_idl()
elif ff == "o0_offloaded_metadata.pb":
lit = literals_pb2.Literal()
lit.ParseFromString(f.read())
expected_output = v + 1
assert lit == TypeEngine.to_literal(ctx, expected_output, int, TypeEngine.to_literal_type(int)).to_flyte_idl()
else:
assert False, f"Unexpected file {ff}"


def test_dispatch_execute_offloaded_nested_lists_of_literals_offloading_disabled(tmp_path_factory):
@task
def t1(a: typing.List[int]) -> typing.List[typing.List[str]]:
return [[f"string is: {x}" for x in a] for _ in range(len(a))]

inputs_path = tmp_path_factory.mktemp("inputs")
outputs_path = tmp_path_factory.mktemp("outputs")

ctx = context_manager.FlyteContext.current_context()
with get_flyte_context(tmp_path_factory, outputs_path) as ctx:
xs: typing.List[int] = [1, 2, 3]
input_literal_map = _literal_models.LiteralMap(
{
"a": TypeEngine.to_literal(ctx, xs, typing.List[int], TypeEngine.to_literal_type(typing.List[int])),
}
)

write_proto_to_file(input_literal_map.to_flyte_idl(), str(inputs_path/"inputs.pb"))

# Ensure that this is not set by an external source
assert os.environ.get("_F_L_MIN_SIZE_MB") is None

# Notice how we're setting the env var to None, which disables offloading completely
_dispatch_execute(ctx, lambda: t1, str(inputs_path/"inputs.pb"), str(outputs_path.absolute()))

assert "error.pb" not in os.listdir(outputs_path)

for ff in os.listdir(outputs_path):
with open(outputs_path/ff, "rb") as f:
if ff == "outputs.pb":
lit = literals_pb2.LiteralMap()
lit.ParseFromString(f.read())
assert len(lit.literals) == 1
assert "o0" in lit.literals
assert lit.literals["o0"].HasField("offloaded_metadata") == False
else:
assert False, f"Unexpected file {ff}"

0 comments on commit 12e194a

Please sign in to comment.