[Workflow] Serialization cleanup #18328

Merged
65 commits
c86e683 (Aug 25, 2021): notes
15116f3 (Aug 25, 2021): notes
47f04d0 (Aug 25, 2021): .
4bd7115 (Aug 26, 2021): seems to work?
ca6a12d (Aug 26, 2021): .
1f56f2f (Aug 27, 2021): seems to work
f58c6e0 (Aug 27, 2021): needs tests
a1ba90f (Aug 27, 2021): needs tests
cf69b60 (Aug 28, 2021): parallelize uploads
c550c93 (Aug 28, 2021): fixed
7027113 (Aug 28, 2021): fixed
1e3c9da (Aug 31, 2021): .
7512214 (Aug 31, 2021): .
9103dd3 (Aug 31, 2021): dumb test
c534eed (Aug 31, 2021): .
1685ce4 (Aug 31, 2021): .
f270cdd (Sep 1, 2021): fix festsg
1ea1583 (Sep 1, 2021): .
3b0a1ef (Sep 2, 2021): works
35b2130 (Sep 2, 2021): .:
f29cf18 (Sep 2, 2021): .
0628188 (Sep 2, 2021): .
507fc76 (Sep 2, 2021): .
c394fff (Sep 3, 2021): Update common.py
1192a7c (Sep 3, 2021): .
9963141 (Sep 3, 2021): almost removed special case for inputs
3807479 (Sep 3, 2021): lint
b31a20f (Sep 3, 2021): lint
3280a6c (Sep 3, 2021): Merge branch 'workflow_objectref' of github.com:wuisawesome/ray into …
51e66cc (Sep 3, 2021): .
8405911 (Sep 3, 2021): lint
dfc104a (Sep 3, 2021): Merge branch 'master' of github.com:ray-project/ray into workflow_ref…
028864d (Sep 3, 2021): handle edge case
77c19ce (Sep 3, 2021): Merge branch 'master' into workflow_refactor_serialization
0ec0f34 (Sep 3, 2021): .
a57bda6 (Sep 7, 2021): .
b5b5daf (Sep 7, 2021): Merge branch 'master' of github.com:ray-project/ray into workflow_ref…
d55e1d2 (Sep 14, 2021): Merge branch 'master' of github.com:ray-project/ray into workflow_ref…
de14d43 (Sep 14, 2021): lint
188028e (Sep 15, 2021): needs dedupe
4b82146 (Sep 15, 2021): needs dedupe
84a4a83 (Sep 15, 2021): still need to not leak cache
36ca196 (Sep 15, 2021): still need to not leak cache
6a5a5a2 (Sep 15, 2021): probably fails edge cases?
58e4350 (Sep 15, 2021): probably fails edge cases?
1513f07 (Sep 16, 2021): works?
ae27e8f (Sep 17, 2021): cleanup
cb42e1e (Sep 17, 2021): passes test?
ea521e3 (Sep 17, 2021): ???
573f4e0 (Sep 18, 2021): done?
85f48bb (Sep 21, 2021): may work?
fc26b54 (Sep 21, 2021): may work?
040a0e5 (Sep 21, 2021): .
6aee406 (Sep 21, 2021): .
4f4ccc5 (Sep 21, 2021): Revert "."
391b4fe (Sep 21, 2021): Revert "."
79a568d (Sep 21, 2021): Revert "may work?"
32ce01d (Sep 21, 2021): Revert "may work?"
968043e (Sep 21, 2021): Revert "done?"
d02fe2e (Sep 22, 2021): passs tests
3cc26cb (Sep 22, 2021): lint
e01ec86 (Sep 22, 2021): cleanup
9ba5ce7 (Sep 22, 2021): bug fix
f08272c (Sep 22, 2021): bug fix
173d8fd (Sep 23, 2021): print
22 changes: 0 additions & 22 deletions python/ray/workflow/common.py
@@ -80,8 +80,6 @@ class StepType(str, Enum):
class WorkflowInputs:
# The object ref of the input arguments.
args: ObjectRef
# The object refs in the arguments.
object_refs: List[ObjectRef]
# TODO(suquark): maybe later we can replace it with WorkflowData.
# The workflows in the arguments.
workflows: "List[Workflow]"
@@ -96,16 +94,6 @@ def _hash(obj: Any) -> bytes:
return m.digest()


def calculate_identifiers(object_refs: List[ObjectRef]) -> List[str]:
"""
Calculate identifiers for an object ref based on the contents. (i.e. a hash
of the contents).
"""
hashes = ray.get([_hash.remote(obj) for obj in object_refs])
encoded = map(base64.urlsafe_b64encode, hashes)
return [encoded_hash.decode("ascii") for encoded_hash in encoded]


@ray.remote
def calculate_identifier(obj: Any) -> str:
"""Calculate a url-safe identifier for an object."""
@@ -140,21 +128,11 @@ class WorkflowData:
# name of the step
name: str

# Cache the intended locations of object refs. These are expensive
# calculations since they require computing the hash over a large value.
_cached_refs: List[ObjectRef] = None
_cached_locs: List[str] = None

def to_metadata(self) -> Dict[str, Any]:
if self._cached_refs != self.inputs.object_refs:
self._cached_refs = self.inputs.object_refs
self._cached_locs = calculate_identifiers(self._cached_refs)

f = self.func_body
metadata = {
"name": get_module(f) + "." + get_qualname(f),
"step_type": self.step_type,
"object_refs": self._cached_locs,
"workflows": [w.step_id for w in self.inputs.workflows],
"max_retries": self.max_retries,
"workflow_refs": [wr.step_id for wr in self.inputs.workflow_refs],
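Note on the diff above: the batched calculate_identifiers helper and the cached object-ref locations in WorkflowData are removed, while the per-object calculate_identifier task stays as the basis for content-addressed checkpoints. A rough standalone sketch of that identifier scheme follows (not the PR's code; content_identifier is a made-up name and the hash choice is an assumption): hash the serialized bytes and use a URL-safe encoding of the digest as a storage key.

import base64
import hashlib
from typing import Any

from ray import cloudpickle


def content_identifier(obj: Any) -> str:
    # Hash the serialized bytes, then encode the digest in a URL-safe form
    # so it can double as a storage key. (sha256 is an assumption here.)
    digest = hashlib.sha256(cloudpickle.dumps(obj)).digest()
    return base64.urlsafe_b64encode(digest).decode("ascii")

Equal objects then map to the same identifier and hence the same storage path, which is what lets the upload manager in serialization.py avoid re-uploading duplicates.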
96 changes: 32 additions & 64 deletions python/ray/workflow/recovery.py
@@ -1,8 +1,8 @@
import asyncio
from typing import List, Any, Union, Dict, Callable, Tuple, Optional

import ray
from ray.workflow import workflow_context
from ray.workflow import serialization
from ray.workflow.common import (Workflow, StepID, WorkflowRef,
WorkflowExecutionResult)
from ray.workflow import storage
@@ -28,42 +28,31 @@ def __init__(self, workflow_id: str):


@WorkflowStepFunction
def _recover_workflow_step(input_object_refs: List[ray.ObjectRef],
def _recover_workflow_step(args: List[Any], kwargs: Dict[str, Any],
input_workflows: List[Any],
input_workflow_refs: List[WorkflowRef],
instant_workflow_inputs: Dict[int, StepID]):
input_workflow_refs: List[WorkflowRef]):
"""A workflow step that recovers the output of an unfinished step.

Args:
input_object_refs: The object refs in the argument of
the (original) step.
args: The positional arguments for the step function.
kwargs: The keyword args for the step function.
input_workflows: The workflows in the argument of the (original) step.
They are resolved into physical objects (i.e. the output of the
workflows) here. They come from other recover workflows we
construct recursively.
instant_workflow_inputs: Same as 'input_workflows', but they come
point to workflow steps that have output checkpoints. They override
corresponding workflows in 'input_workflows'.

Returns:
The output of the recovered step.
"""
reader = workflow_storage.get_workflow_storage()
for index, _step_id in instant_workflow_inputs.items():
# override input workflows with instant workflows
input_workflows[index] = reader.load_step_output(_step_id)

step_id = workflow_context.get_current_step_id()
func: Callable = reader.load_step_func_body(step_id)
args, kwargs = reader.load_step_args(
step_id, input_workflows, input_object_refs, input_workflow_refs)
return func(*args, **kwargs)


def _construct_resume_workflow_from_step(
reader: workflow_storage.WorkflowStorage,
step_id: StepID,
objectref_cache: Dict[str, Any] = None) -> Union[Workflow, StepID]:
step_id: StepID) -> Union[Workflow, StepID]:
"""Try to construct a workflow (step) that recovers the workflow step.
If the workflow step already has an output checkpointing file, we return
the workflow step id instead.
@@ -76,61 +65,40 @@ def _construct_resume_workflow_from_step(
A workflow that recovers the step, or a ID of a step
that contains the output checkpoint file.
"""
if objectref_cache is None:
objectref_cache = {}
result: workflow_storage.StepInspectResult = reader.inspect_step(step_id)
if result.output_object_valid:
# we already have the output
return step_id
if isinstance(result.output_step_id, str):
return _construct_resume_workflow_from_step(
reader, result.output_step_id, objectref_cache=objectref_cache)
return _construct_resume_workflow_from_step(reader,
result.output_step_id)
# output does not exists or not valid. try to reconstruct it.
if not result.is_recoverable():
raise WorkflowStepNotRecoverableError(step_id)
input_workflows = []
instant_workflow_outputs: Dict[int, str] = {}
for i, _step_id in enumerate(result.workflows):
r = _construct_resume_workflow_from_step(
reader, _step_id, objectref_cache=objectref_cache)
if isinstance(r, Workflow):
input_workflows.append(r)
else:
input_workflows.append(None)
instant_workflow_outputs[i] = r
workflow_refs = list(map(WorkflowRef, result.workflow_refs))

# TODO (Alex): Refactor to remove this special case handling of object refs
resolved_object_refs = []
identifiers_to_await = []
promises_to_await = []

for identifier in result.object_refs:
if identifier not in objectref_cache:
paths = reader._key_step_args(identifier)
promise = reader._get(paths)
promises_to_await.append(promise)
identifiers_to_await.append(identifier)

loop = asyncio.get_event_loop()
object_refs_to_cache = loop.run_until_complete(
asyncio.gather(*promises_to_await))

for identifier, object_ref in zip(identifiers_to_await,
object_refs_to_cache):
objectref_cache[identifier] = object_ref

for identifier in result.object_refs:
resolved_object_refs.append(objectref_cache[identifier])

recovery_workflow: Workflow = _recover_workflow_step.options(
max_retries=result.max_retries,
catch_exceptions=result.catch_exceptions,
**result.ray_options).step(resolved_object_refs, input_workflows,
workflow_refs, instant_workflow_outputs)
recovery_workflow._step_id = step_id
recovery_workflow.data.step_type = result.step_type
return recovery_workflow

with serialization.objectref_cache():
input_workflows = []
for i, _step_id in enumerate(result.workflows):
r = _construct_resume_workflow_from_step(reader, _step_id)
if isinstance(r, Workflow):
input_workflows.append(r)
else:
assert isinstance(r, StepID)
# TODO (Alex): We should consider caching these outputs too.
input_workflows.append(reader.load_step_output(r))
workflow_refs = list(map(WorkflowRef, result.workflow_refs))

args, kwargs = reader.load_step_args(step_id, input_workflows,
workflow_refs)

recovery_workflow: Workflow = _recover_workflow_step.options(
max_retries=result.max_retries,
catch_exceptions=result.catch_exceptions,
**result.ray_options).step(args, kwargs, input_workflows,
workflow_refs)
recovery_workflow._step_id = step_id
recovery_workflow.data.step_type = result.step_type
return recovery_workflow


@ray.remote(num_returns=2)
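Note on the diff above: _construct_resume_workflow_from_step no longer threads an objectref_cache dict through its recursion; the whole reconstruction now runs inside serialization.objectref_cache(), a reentrant caching context. A minimal standalone sketch of that pattern is below (illustrative only; reentrant_cache and cached_fetch are not the PR's names).

import contextlib
from typing import Callable, Dict, Optional

_cache: Optional[Dict[str, object]] = None


@contextlib.contextmanager
def reentrant_cache():
    # Only the outermost `with` creates the cache and clears it on exit,
    # so nested calls share it and nothing outlives the top-level block.
    global _cache
    created_here = _cache is None
    if created_here:
        _cache = {}
    try:
        yield
    finally:
        if created_here:
            _cache = None


def cached_fetch(key: str, fetch: Callable[[str], object]) -> object:
    # Outside any cache context, always fetch; inside one, memoize by key.
    if _cache is None:
        return fetch(key)
    if key not in _cache:
        _cache[key] = fetch(key)
    return _cache[key]

This mirrors how _load_object_ref consults the module-level _object_cache in serialization.py, so repeated references to the same checkpointed object are loaded only once per resume.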
158 changes: 139 additions & 19 deletions python/ray/workflow/serialization.py
@@ -1,12 +1,16 @@
import asyncio
import contextlib
from dataclasses import dataclass
import logging
import ray
from ray import cloudpickle
from ray.types import ObjectRef
from ray.workflow import common
from ray.workflow import storage
from ray.workflow import workflow_storage
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING

from collections import ChainMap
import io

if TYPE_CHECKING:
from ray.actor import ActorHandle
@@ -50,17 +54,6 @@ class Upload:
upload_task: ObjectRef[None]


@ray.remote(num_cpus=0)
def _put_helper(identifier: str, obj: Any,
wf_storage: "workflow_storage.WorkflowStorage") -> None:
if isinstance(obj, ray.ObjectRef):
raise NotImplementedError("Workflow does not support checkpointing "
"nested object references yet.")
paths = wf_storage._key_obj_id(identifier)
asyncio.get_event_loop().run_until_complete(
wf_storage._put(paths, obj, update=False))


@ray.remote(num_cpus=0)
class Manager:
"""
@@ -90,23 +83,150 @@ async def save_objectref(
A pair. The first element is the paths the ref will be uploaded to.
The second is an object reference to the upload task.
"""
wf_storage = workflow_storage.WorkflowStorage(workflow_id,
self._storage)
ref, = ref_tuple
# Use the hex as the key to avoid holding a reference to the object.
key = ref.hex()
key = (ref.hex(), workflow_id)

if key not in self._uploads:
# TODO(Alex): We should probably eventually free these refs.
identifier_ref = common.calculate_identifier.remote(ref)
upload_task = _put_helper.remote(identifier_ref, ref, wf_storage)
self._uploads[key] = Upload(identifier_ref, upload_task)
upload_task = _put_helper.remote(identifier_ref, ref, workflow_id,
self._storage)
self._uploads[key] = Upload(
identifier_ref=identifier_ref, upload_task=upload_task)
self._num_uploads += 1

info = self._uploads[key]
identifer = await info.identifier_ref
paths = wf_storage._key_obj_id(identifer)
paths = obj_id_to_paths(workflow_id, identifer)
return paths, info.upload_task

async def export_stats(self) -> Dict[str, Any]:
return {"num_uploads": self._num_uploads}


OBJECTS_DIR = "objects"


def obj_id_to_paths(workflow_id: str, object_id: str) -> List[str]:
return [workflow_id, OBJECTS_DIR, object_id]


@ray.remote(num_cpus=0)
def _put_helper(identifier: str, obj: Any, workflow_id: str,
storage: storage.Storage) -> None:
# TODO (Alex): This check isn't sufficient, it only works for directly
# nested object refs.
if isinstance(obj, ray.ObjectRef):
raise NotImplementedError("Workflow does not support checkpointing "
"nested object references yet.")
paths = obj_id_to_paths(workflow_id, identifier)
promise = dump_to_storage(
paths, obj, workflow_id, storage, update_existing=False)
return asyncio.get_event_loop().run_until_complete(promise)


def _reduce_objectref(workflow_id: str, storage: storage.Storage,
obj_ref: ObjectRef, tasks: List[ObjectRef]):
manager = get_or_create_manager()
paths, task = ray.get(
manager.save_objectref.remote((obj_ref, ), workflow_id))

assert task
tasks.append(task)

return _load_object_ref, (paths, storage)


async def dump_to_storage(paths: List[str],
obj: Any,
workflow_id: str,
storage: storage.Storage,
update_existing=True) -> None:
"""Serializes and puts arbitrary object, handling references. The object will
be uploaded at `paths`. Any object references will be uploaded to their
global, remote storage.

Args:
paths: The location to put the object.
obj: The object to serialize. If it contains object references, those
will be serialized too.
workflow_id: The workflow id.
storage: The storage to use. If obj contains object references,
`storage.put` will be called on them individually.
update_existing: If False, the object will not be uploaded if the path
exists.
"""
if not update_existing:
prefix = storage.make_key(*paths[:-1])
scan_result = await storage.scan_prefix(prefix)
if paths[-1] in scan_result:
return

tasks = []

# NOTE: Cloudpickle doesn't support private dispatch tables, so we extend
# the cloudpickler instead to avoid changing cloudpickle's global dispatch
# table which is shared with `ray.put`. See
# https://github.com/cloudpipe/cloudpickle/issues/437
class ObjectRefPickler(cloudpickle.CloudPickler):
_object_ref_reducer = {
ray.ObjectRef: lambda ref: _reduce_objectref(
workflow_id, storage, ref, tasks)
}
dispatch_table = ChainMap(_object_ref_reducer,
cloudpickle.CloudPickler.dispatch_table)
dispatch = dispatch_table

key = storage.make_key(*paths)

# TODO(Alex): We should be able to do this without the extra buffer.
with io.BytesIO() as f:
pickler = ObjectRefPickler(f)
pickler.dump(obj)
f.seek(0)
task = storage.put(key, f.read())
tasks.append(task)

await asyncio.gather(*tasks)


@ray.remote
def _load_ref_helper(key: str, storage: storage.Storage):
# TODO(Alex): We should stream the data directly into `cloudpickle.load`.
serialized = asyncio.get_event_loop().run_until_complete(storage.get(key))
return cloudpickle.loads(serialized)


# TODO (Alex): We should use weakrefs here instead requiring a context manager.
_object_cache: Optional[Dict[str, ray.ObjectRef]] = None


def _load_object_ref(paths: List[str],
storage: storage.Storage) -> ray.ObjectRef:
global _object_cache
key = storage.make_key(*paths)
if _object_cache is None:
return _load_ref_helper.remote(key, storage)

if key not in _object_cache:
_object_cache[key] = _load_ref_helper.remote(key, storage)

return _object_cache[key]


@contextlib.contextmanager
def objectref_cache() -> Generator:
""" A reentrant caching context for object refs."""
global _object_cache
clear_cache = _object_cache is None
if clear_cache:
_object_cache = {}
try:
yield
finally:
if clear_cache:
_object_cache = None
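Note on the diff above: dump_to_storage avoids mutating cloudpickle's global dispatch table (which ray.put shares) by subclassing CloudPickler with a ChainMap-ed dispatch_table, per https://github.com/cloudpipe/cloudpickle/issues/437. A rough sketch of that subclassing trick in isolation, with a made-up Handle type standing in for ObjectRef:

import io
from collections import ChainMap

from ray import cloudpickle


class Handle:
    def __init__(self, value: str):
        self.value = value


def _reduce_handle(handle: Handle):
    # Rebuild the handle from its value at load time instead of pickling
    # whatever state sits behind it.
    return Handle, (handle.value, )


class HandlePickler(cloudpickle.CloudPickler):
    # Layer a private reducer over cloudpickle's dispatch table; the global
    # table is untouched, so ray.put keeps its normal behavior.
    _handle_reducer = {Handle: _reduce_handle}
    dispatch_table = ChainMap(_handle_reducer,
                              cloudpickle.CloudPickler.dispatch_table)


with io.BytesIO() as f:
    HandlePickler(f).dump({"ref": Handle("abc")})
    payload = f.getvalue()

restored = cloudpickle.loads(payload)  # restored["ref"] is a fresh Handle("abc")

At unpickling time the reducer's callable and arguments rebuild the object, which is how _reduce_objectref swaps an ObjectRef for a deferred _load_object_ref call.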