Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offload literals #2872

Merged
merged 28 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
12f9edf
wip - Implement offloading of literals
eapolinario Oct 18, 2024
5fefa85
Fix use of metadata bucket prefix
eapolinario Oct 18, 2024
1668fda
Fix repeated use of uri
eapolinario Oct 18, 2024
853df62
Add temporary representation for offloaded literal
eapolinario Oct 18, 2024
5e53a1b
Add one unit test
eapolinario Oct 22, 2024
177368d
Add another test
eapolinario Oct 22, 2024
5fc2e84
Stylistic changes to the two tests
eapolinario Oct 22, 2024
db48d18
Add test for min offloading threshold set to 1MB
eapolinario Oct 28, 2024
6884ee0
Pick a unique engine-dir for tests
eapolinario Oct 28, 2024
5a6423c
s/new_outputs/literal_map_copy/
eapolinario Oct 28, 2024
dbfea93
Remove unused constant
eapolinario Oct 28, 2024
adeed34
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Oct 28, 2024
25908d9
Use output_prefix in definition of offloaded literals
eapolinario Oct 29, 2024
e827693
Add initial version of pbhash.py
eapolinario Nov 6, 2024
e0e2016
Add tests to verify that overriding the hash is carried over to offlo…
eapolinario Nov 7, 2024
b284492
Add a few more tests
eapolinario Nov 7, 2024
b579b83
Always import ParamSpec from `typing_extensions`
eapolinario Nov 7, 2024
8c2336e
Fix lint warnings
eapolinario Nov 7, 2024
c28d537
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Nov 7, 2024
37b2bb4
Set inferred_type using the task type interface
eapolinario Nov 8, 2024
e25496d
Add comment about offloaded literals files and how they are uploaded …
eapolinario Nov 8, 2024
9276e96
Add offloading_enabled
eapolinario Nov 18, 2024
a8bdbca
Add more unit tests including a negative test
eapolinario Nov 19, 2024
fe822b9
Merge remote-tracking branch 'origin' into offload-literals
eapolinario Nov 19, 2024
a4fcfab
Fix bad merge
eapolinario Nov 19, 2024
b3a1b0d
Incorporate feedback.
eapolinario Nov 19, 2024
32c5896
Fix image name (unrelated to this PR - just a nice-to-have to decreas…
eapolinario Nov 21, 2024
12e194a
Add `is_map_task` to `_dispatch_execute`
eapolinario Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import uuid
import warnings
from sys import exit
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional

import click
from flyteidl.core import literals_pb2 as _literals_pb2
Expand Down Expand Up @@ -55,6 +55,7 @@
from flytekit.models.core import identifier as _identifier
from flytekit.tools.fast_registration import download_distribution as _download_distribution
from flytekit.tools.module_loader import load_object_from_module
from flytekit.utils.pbhash import compute_hash_string


def get_version_message():
Expand Down Expand Up @@ -135,6 +136,7 @@ def _dispatch_execute(
load_task: Callable[[], PythonTask],
inputs_path: str,
output_prefix: str,
is_map_task: bool = False,
):
"""
Dispatches execute to PythonTask
Expand All @@ -144,6 +146,12 @@ def _dispatch_execute(
a: [Optional] Record outputs to output_prefix
b: OR if IgnoreOutputs is raised, then ignore uploading outputs
c: OR if an unhandled exception is retrieved - record it as an errors.pb

:param ctx: FlyteContext
:param load_task: Callable[[], PythonTask]
:param inputs_path: Where to read inputs
:param output_prefix: Where to write primitive outputs
:param is_map_task: Whether this task is executing as part of a map task
"""
error_file_name = _build_error_file_name()
worker_name = _get_worker_name()
Expand Down Expand Up @@ -179,7 +187,59 @@ def _dispatch_execute(
logger.warning("Task produces no outputs")
output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})}
elif isinstance(outputs, _literal_models.LiteralMap):
output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs}
# The keys in this map hold the filenames to the offloaded proto literals.
offloaded_literals: Dict[str, _literal_models.Literal] = {}
literal_map_copy = {}

offloading_enabled = os.environ.get("_F_L_MIN_SIZE_MB", None) is not None
min_offloaded_size = -1
max_offloaded_size = -1
if offloading_enabled:
min_offloaded_size = int(os.environ.get("_F_L_MIN_SIZE_MB", "10")) * 1024 * 1024
max_offloaded_size = int(os.environ.get("_F_L_MAX_SIZE_MB", "1000")) * 1024 * 1024

# Go over each output and create a separate offloaded literal in case its size is too large
for k, v in outputs.literals.items():
literal_map_copy[k] = v

if not offloading_enabled:
continue

lit = v.to_flyte_idl()
if max_offloaded_size != -1 and lit.ByteSize() >= max_offloaded_size:
raise ValueError(
f"Literal {k} is too large to be offloaded. Max literal size is {max_offloaded_size} whereas the literal size is {lit.ByteSize()} bytes"
)

if min_offloaded_size != -1 and lit.ByteSize() >= min_offloaded_size:
logger.debug(f"Literal {k} is too large to be inlined, offloading to metadata bucket")
inferred_type = task_def.interface.outputs[k].type

# In the case of map tasks we need to use the type wrapped by the collection as the inferred type of the
# offloaded literal. This is done because the map task interface present in the task template contains
# the (correct) type for the entire map task, not the single node execution.
# For that reason we "unwrap" the collection type and use it as the inferred type of the offloaded literal.
if is_map_task:
inferred_type = inferred_type.collection_type

# This file will hold the offloaded literal and will be written to the output prefix
# alongside the regular outputs.pb, deck.pb, etc.
# N.B.: by construction `offloaded_filename` is guaranteed to be unique
offloaded_filename = f"{k}_offloaded_metadata.pb"
offloaded_literal = _literal_models.Literal(
offloaded_metadata=_literal_models.LiteralOffloadedMetadata(
uri=f"{output_prefix}/{offloaded_filename}",
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
size_bytes=lit.ByteSize(),
# TODO: remove after https://github.com/flyteorg/flyte/pull/5909 is merged
inferred_type=inferred_type,
),
hash=v.hash if v.hash is not None else compute_hash_string(lit),
)
literal_map_copy[k] = offloaded_literal
offloaded_literals[offloaded_filename] = v
outputs = _literal_models.LiteralMap(literals=literal_map_copy)

output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs, **offloaded_literals}
elif isinstance(outputs, _dynamic_job.DynamicJobSpec):
output_file_dict = {_constants.FUTURES_FILE_NAME: outputs}
else:
Expand Down Expand Up @@ -588,7 +648,7 @@ def load_task():
)
return

_dispatch_execute(ctx, load_task, inputs, output_prefix)
_dispatch_execute(ctx, load_task, inputs, output_prefix, is_map_task=True)


def normalize_inputs(
Expand Down
8 changes: 2 additions & 6 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
from functools import update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

from flytekit.core.utils import str2bool

try:
from typing import ParamSpec
except ImportError:
from typing_extensions import ParamSpec # type: ignore
from typing_extensions import ParamSpec # type: ignore

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
Expand All @@ -20,6 +15,7 @@
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
from flytekit.core.utils import str2bool
from flytekit.deck import DeckField
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageSpec
Expand Down
6 changes: 1 addition & 5 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@
from functools import update_wrapper
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

from typing_extensions import ParamSpec # type: ignore
from typing_inspect import is_optional_type

try:
from typing import ParamSpec
except ImportError:
from typing_extensions import ParamSpec # type: ignore

from flytekit.core import constants as _common_constants
from flytekit.core import launch_plan as _annotated_launch_plan
from flytekit.core.base_task import PythonTask, Task
Expand Down
3 changes: 3 additions & 0 deletions flytekit/interaction/string_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def literal_string_repr(lit: Literal) -> typing.Any:
return [literal_string_repr(i) for i in lit.collection.literals]
if lit.map:
return {k: literal_string_repr(v) for k, v in lit.map.literals.items()}
if lit.offloaded_metadata:
# TODO: load literal from offloaded literal?
return f"Offloaded literal metadata: {lit.offloaded_metadata}"
Comment on lines +64 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an example we can try to see if we need to load the literal here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is relevant for the pyflyte fetch command:

❯ pyflyte --config ~/.flyte/config-sandbox.yaml fetch flyte://v1/flytesnacks/development/asvskwn766f5v492pzgt/n0-0-n1/o
Fetching data from flyte://v1/flytesnacks/development/asvskwn766f5v492pzgt/n0-0-n1/o...
╭───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ {                                                                                                                                                                                                                                                                                             │
│     'o0': [                                                                                                                                                                                                                                                                                   │
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015',                                                                                                │
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015',                                                                                                │
...
│         'Offloaded literal metadata: Flyte Serialized object (LiteralOffloadedMetadata):\n  uri: s3://my-s3-bucket/metadata/propeller/flytesnacks-development- [...]\n  size_bytes: 39936015'                                                                                                 │
│     ]                                                                                                                                                                                                                                                                                         │
│ }                                                                                                                                                                                                                                                                                             │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

raise ValueError(f"Unknown literal type {lit}")


Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def to_flyte_idl(self):
map=self.map.to_flyte_idl() if self.map is not None else None,
hash=self.hash,
metadata=self.metadata,
offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None,
offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata is not None else None,
)

@classmethod
Expand Down
39 changes: 39 additions & 0 deletions flytekit/utils/pbhash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# This is a module that provides hashing utilities for Protobuf objects.
import base64
import hashlib
import json

from google.protobuf import json_format
from google.protobuf.message import Message


def compute_hash(pb: Message) -> bytes:
    """
    Computes a deterministic hash in bytes for the Protobuf object.

    The message is converted to a dict with ``json_format.MessageToDict``, serialized to JSON with
    sorted keys and compact separators, and the UTF-8 encoding of that string is hashed with SHA-256.

    :param pb: Protobuf message to hash
    :return: The SHA-256 digest (32 bytes) of the JSON form of the message
    :raises ValueError: If the message cannot be marshalled to JSON or the JSON cannot be hashed. The
        underlying exception is chained as ``__cause__``.
    """
    try:
        pb_dict = json_format.MessageToDict(pb)
        # json.dumps with sorted keys to ensure stability
        stable_json_str = json.dumps(
            pb_dict, sort_keys=True, separators=(",", ":")
        )  # separators to ensure no extra spaces
    except Exception as e:
        # Chain explicitly so the traceback marks the original failure as the direct cause.
        raise ValueError(f"Failed to marshal Protobuf object {pb} to JSON with error: {e}") from e

    try:
        # Deterministically hash the JSON object to a byte array. Using SHA-256 for hashing here,
        # assuming it provides a consistent hash output.
        hash_obj = hashlib.sha256(stable_json_str.encode("utf-8"))
    except Exception as e:
        raise ValueError(f"Failed to hash JSON for Protobuf object {pb} with error: {e}") from e

    # The digest is guaranteed to be 32 bytes long
    return hash_obj.digest()


def compute_hash_string(pb: Message) -> str:
    """
    Returns the hash of the Protobuf object, as produced by ``compute_hash``, encoded as a base64 string.

    :param pb: Protobuf message to hash
    :return: base64 text of the digest returned by ``compute_hash``
    """
    digest_b64 = base64.b64encode(compute_hash(pb))
    # b64encode yields bytes; decode them into a str for the caller.
    return str(digest_b64, "utf-8")
Loading
Loading