Skip to content

Commit

Permalink
Offload literals (#2872)
Browse files Browse the repository at this point in the history
* wip - Implement offloading of literals

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix use of metadata bucket prefix

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix repeated use of uri

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add temporary representation for offloaded literal

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add one unit test

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add another test

Signed-off-by: Eduardo Apolinario <[email protected]>

* Stylistic changes to the two tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add test for min offloading threshold set to 1MB

Signed-off-by: Eduardo Apolinario <[email protected]>

* Pick a unique engine-dir for tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* s/new_outputs/literal_map_copy/

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove unused constant

Signed-off-by: Eduardo Apolinario <[email protected]>

* Use output_prefix in definition of offloaded literals

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add initial version of pbhash.py

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add tests to verify that overriding the hash is carried over to offloaded literals

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add a few more tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Always import ParamSpec from `typing_extensions`

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix lint warnings

Signed-off-by: Eduardo Apolinario <[email protected]>

* Set inferred_type using the task type interface

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add comment about offloaded literals files and how they are uploaded to the metadata bucket

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add offloading_enabled

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add more unit tests including a negative test

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix bad merge

Signed-off-by: Eduardo Apolinario <[email protected]>

* Incorporate feedback.

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix image name (unrelated to this PR - just a nice-to-have to decrease flakiness)

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add `is_map_task` to `_dispatch_execute`

Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario and eapolinario authored Nov 22, 2024
1 parent 2e40e76 commit 40c9540
Show file tree
Hide file tree
Showing 9 changed files with 717 additions and 18 deletions.
66 changes: 63 additions & 3 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import uuid
import warnings
from sys import exit
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional

import click
from flyteidl.core import literals_pb2 as _literals_pb2
Expand Down Expand Up @@ -55,6 +55,7 @@
from flytekit.models.core import identifier as _identifier
from flytekit.tools.fast_registration import download_distribution as _download_distribution
from flytekit.tools.module_loader import load_object_from_module
from flytekit.utils.pbhash import compute_hash_string


def get_version_message():
Expand Down Expand Up @@ -135,6 +136,7 @@ def _dispatch_execute(
load_task: Callable[[], PythonTask],
inputs_path: str,
output_prefix: str,
is_map_task: bool = False,
):
"""
Dispatches execute to PythonTask
Expand All @@ -144,6 +146,12 @@ def _dispatch_execute(
a: [Optional] Record outputs to output_prefix
b: OR if IgnoreOutputs is raised, then ignore uploading outputs
c: OR if an unhandled exception is retrieved - record it as an errors.pb
:param ctx: FlyteContext
:param load_task: Callable[[], PythonTask]
:param inputs_path: Where to read inputs
:param output_prefix: Where to write primitive outputs
:param is_map_task: Whether this task is executing as part of a map task
"""
error_file_name = _build_error_file_name()
worker_name = _get_worker_name()
Expand Down Expand Up @@ -179,7 +187,59 @@ def _dispatch_execute(
logger.warning("Task produces no outputs")
output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})}
elif isinstance(outputs, _literal_models.LiteralMap):
output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs}
# The keys in this map hold the filenames to the offloaded proto literals.
offloaded_literals: Dict[str, _literal_models.Literal] = {}
literal_map_copy = {}

offloading_enabled = os.environ.get("_F_L_MIN_SIZE_MB", None) is not None
min_offloaded_size = -1
max_offloaded_size = -1
if offloading_enabled:
min_offloaded_size = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024
max_offloaded_size = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024

# Go over each output and create a separate offloaded literal in case its size is too large
for k, v in outputs.literals.items():
literal_map_copy[k] = v

if not offloading_enabled:
continue

lit = v.to_flyte_idl()
if max_offloaded_size != -1 and lit.ByteSize() >= max_offloaded_size:
raise ValueError(
f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes"
)

if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size:
logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket")
inferred_type = task_def.interface.outputs[k].type

# In the case of map tasks we need to use the collection's inner type as the inferred type
# of the offloaded literal. This is done because the map task interface present in
# the task template contains the (correct) type for the entire map task, not the single node execution.
# For that reason we "unwrap" the collection type and use it as the inferred type of the offloaded literal.
if is_map_task:
inferred_type = inferred_type.collection_type

# This file will hold the offloaded literal and will be written to the output prefix
# alongside the regular outputs.pb, deck.pb, etc.
# N.B.: by construction `offloaded_filename` is guaranteed to be unique
offloaded_filename = f"{k}_offloaded_metadata.pb"
offloaded_literal = _literal_models.Literal(
offloaded_metadata=_literal_models.LiteralOffloadedMetadata(
uri=f"{output_prefix}/{offloaded_filename}",
size_bytes=lit.ByteSize(),
# TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged
inferred_type=inferred_type,
),
hash=v.hash if v.hash is not None else compute_hash_string(lit),
)
literal_map_copy[k] = offloaded_literal
offloaded_literals[offloaded_filename] = v
outputs = _literal_models.LiteralMap(literals=literal_map_copy)

output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals}
elif isinstance(outputs, _dynamic_job.DynamicJobSpec):
output_file_dict = {_constants.FUTURES_FILE_NAME: outputs}
else:
Expand Down Expand Up @@ -588,7 +648,7 @@ def load_task():
)
return

_dispatch_execute(ctx, load_task, inputs, output_prefix)
_dispatch_execute(ctx, load_task, inputs, output_prefix, is_map_task=True)


def normalize_inputs(
Expand Down
8 changes: 2 additions & 6 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
from functools import update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

from flytekit.core.utils import str2bool

try:
from typing import ParamSpec
except ImportError:
from typing_extensions import ParamSpec # type: ignore
from typing_extensions import ParamSpec # type: ignore

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
Expand All @@ -20,6 +15,7 @@
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
from flytekit.core.utils import str2bool
from flytekit.deck import DeckField
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageSpec
Expand Down
6 changes: 1 addition & 5 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@
from functools import update_wrapper
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

from typing_extensions import ParamSpec # type: ignore
from typing_inspect import is_optional_type

try:
from typing import ParamSpec
except ImportError:
from typing_extensions import ParamSpec # type: ignore

from flytekit.core import constants as _common_constants
from flytekit.core import launch_plan as _annotated_launch_plan
from flytekit.core.base_task import PythonTask, Task
Expand Down
3 changes: 3 additions & 0 deletions flytekit/interaction/string_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def literal_string_repr(lit: Literal) -> typing.Any:
return [literal_string_repr(i) for i in lit.collection.literals]
if lit.map:
return {k: literal_string_repr(v) for k, v in lit.map.literals.items()}
if lit.offloaded_metadata:
# TODO: load literal from offloaded literal?
return f"Offloaded literal metadata: {lit.offloaded_metadata}"
raise ValueError(f"Unknown literal type {lit}")


Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def to_flyte_idl(self):
map=self.map.to_flyte_idl() if self.map is not None else None,
hash=self.hash,
metadata=self.metadata,
offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None,
offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata is not None else None,
)

@classmethod
Expand Down
39 changes: 39 additions & 0 deletions flytekit/utils/pbhash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# This is a module that provides hashing utilities for Protobuf objects.
import base64
import hashlib
import json

from google.protobuf import json_format
from google.protobuf.message import Message


def compute_hash(pb: Message) -> bytes:
    """
    Computes a deterministic hash in bytes for the Protobuf object.

    The message is converted to a dict with ``json_format.MessageToDict``, serialized to
    JSON with sorted keys and compact separators, and the UTF-8 encoding of that JSON
    string is hashed with SHA-256.

    :param pb: The Protobuf message to hash.
    :return: The SHA-256 digest (32 bytes) of the stable JSON representation.
    :raises ValueError: If the message cannot be marshalled to JSON, or if hashing fails.
        The original exception is attached as the cause.
    """
    try:
        pb_dict = json_format.MessageToDict(pb)
        # Sorted keys keep the output stable; explicit separators ensure no extra spaces.
        stable_json_str = json.dumps(pb_dict, sort_keys=True, separators=(",", ":"))
    except Exception as e:
        # Chain explicitly so the root cause is preserved as __cause__ in tracebacks.
        raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}") from e

    try:
        # Deterministically hash the JSON string to a byte array using SHA-256.
        hash_obj = hashlib.sha256(stable_json_str.encode("utf-8"))
    except Exception as e:
        raise ValueError(f"Failed to hash JSON for Protobuf object {pb} with error: {e}") from e

    # The digest is guaranteed to be 32 bytes long
    return hash_obj.digest()


def compute_hash_string(pb: Message) -> str:
    """
    Return the deterministic hash of the Protobuf object as a base64 string.

    The raw digest comes from ``compute_hash``; it is base64-encoded and the
    resulting bytes are decoded as UTF-8 text.

    :param pb: The Protobuf message to hash.
    :return: Base64-encoded hash of ``pb``.
    """
    encoded_digest = base64.b64encode(compute_hash(pb))
    return encoded_digest.decode("utf-8")
Loading

0 comments on commit 40c9540

Please sign in to comment.