diff --git a/python/ray/workflow/common.py b/python/ray/workflow/common.py index 8108f7019128d..e883cdfacd0b4 100644 --- a/python/ray/workflow/common.py +++ b/python/ray/workflow/common.py @@ -80,8 +80,6 @@ class StepType(str, Enum): class WorkflowInputs: # The object ref of the input arguments. args: ObjectRef - # The object refs in the arguments. - object_refs: List[ObjectRef] # TODO(suquark): maybe later we can replace it with WorkflowData. # The workflows in the arguments. workflows: "List[Workflow]" @@ -96,16 +94,6 @@ def _hash(obj: Any) -> bytes: return m.digest() -def calculate_identifiers(object_refs: List[ObjectRef]) -> List[str]: - """ - Calculate identifiers for an object ref based on the contents. (i.e. a hash - of the contents). - """ - hashes = ray.get([_hash.remote(obj) for obj in object_refs]) - encoded = map(base64.urlsafe_b64encode, hashes) - return [encoded_hash.decode("ascii") for encoded_hash in encoded] - - @ray.remote def calculate_identifier(obj: Any) -> str: """Calculate a url-safe identifier for an object.""" @@ -140,21 +128,11 @@ class WorkflowData: # name of the step name: str - # Cache the intended locations of object refs. These are expensive - # calculations since they require computing the hash over a large value. - _cached_refs: List[ObjectRef] = None - _cached_locs: List[str] = None - def to_metadata(self) -> Dict[str, Any]: - if self._cached_refs != self.inputs.object_refs: - self._cached_refs = self.inputs.object_refs - self._cached_locs = calculate_identifiers(self._cached_refs) - f = self.func_body metadata = { "name": get_module(f) + "." + get_qualname(f), "step_type": self.step_type, - "object_refs": self._cached_locs, "workflows": [w.step_id for w in self.inputs.workflows], "max_retries": self.max_retries, "workflow_refs": [wr.step_id for wr in self.inputs.workflow_refs], diff --git a/python/ray/workflow/recovery.py b/python/ray/workflow/recovery.py index 304689e427c47..58902b4419681 100644 --- a/python/ray/workflow/recovery.py +++ b/python/ray/workflow/recovery.py @@ -1,8 +1,8 @@ -import asyncio from typing import List, Any, Union, Dict, Callable, Tuple, Optional import ray from ray.workflow import workflow_context +from ray.workflow import serialization from ray.workflow.common import (Workflow, StepID, WorkflowRef, WorkflowExecutionResult) from ray.workflow import storage @@ -28,42 +28,31 @@ def __init__(self, workflow_id: str): @WorkflowStepFunction -def _recover_workflow_step(input_object_refs: List[ray.ObjectRef], +def _recover_workflow_step(args: List[Any], kwargs: Dict[str, Any], input_workflows: List[Any], - input_workflow_refs: List[WorkflowRef], - instant_workflow_inputs: Dict[int, StepID]): + input_workflow_refs: List[WorkflowRef]): """A workflow step that recovers the output of an unfinished step. Args: - input_object_refs: The object refs in the argument of - the (original) step. + args: The positional arguments for the step function. + kwargs: The keyword args for the step function. input_workflows: The workflows in the argument of the (original) step. They are resolved into physical objects (i.e. the output of the workflows) here. They come from other recover workflows we construct recursively. - instant_workflow_inputs: Same as 'input_workflows', but they come - point to workflow steps that have output checkpoints. They override - corresponding workflows in 'input_workflows'. Returns: The output of the recovered step. 
""" reader = workflow_storage.get_workflow_storage() - for index, _step_id in instant_workflow_inputs.items(): - # override input workflows with instant workflows - input_workflows[index] = reader.load_step_output(_step_id) - step_id = workflow_context.get_current_step_id() func: Callable = reader.load_step_func_body(step_id) - args, kwargs = reader.load_step_args( - step_id, input_workflows, input_object_refs, input_workflow_refs) return func(*args, **kwargs) def _construct_resume_workflow_from_step( reader: workflow_storage.WorkflowStorage, - step_id: StepID, - objectref_cache: Dict[str, Any] = None) -> Union[Workflow, StepID]: + step_id: StepID) -> Union[Workflow, StepID]: """Try to construct a workflow (step) that recovers the workflow step. If the workflow step already has an output checkpointing file, we return the workflow step id instead. @@ -76,61 +65,40 @@ def _construct_resume_workflow_from_step( A workflow that recovers the step, or a ID of a step that contains the output checkpoint file. """ - if objectref_cache is None: - objectref_cache = {} result: workflow_storage.StepInspectResult = reader.inspect_step(step_id) if result.output_object_valid: # we already have the output return step_id if isinstance(result.output_step_id, str): - return _construct_resume_workflow_from_step( - reader, result.output_step_id, objectref_cache=objectref_cache) + return _construct_resume_workflow_from_step(reader, + result.output_step_id) # output does not exists or not valid. try to reconstruct it. if not result.is_recoverable(): raise WorkflowStepNotRecoverableError(step_id) - input_workflows = [] - instant_workflow_outputs: Dict[int, str] = {} - for i, _step_id in enumerate(result.workflows): - r = _construct_resume_workflow_from_step( - reader, _step_id, objectref_cache=objectref_cache) - if isinstance(r, Workflow): - input_workflows.append(r) - else: - input_workflows.append(None) - instant_workflow_outputs[i] = r - workflow_refs = list(map(WorkflowRef, result.workflow_refs)) - - # TODO (Alex): Refactor to remove this special case handling of object refs - resolved_object_refs = [] - identifiers_to_await = [] - promises_to_await = [] - - for identifier in result.object_refs: - if identifier not in objectref_cache: - paths = reader._key_step_args(identifier) - promise = reader._get(paths) - promises_to_await.append(promise) - identifiers_to_await.append(identifier) - - loop = asyncio.get_event_loop() - object_refs_to_cache = loop.run_until_complete( - asyncio.gather(*promises_to_await)) - - for identifier, object_ref in zip(identifiers_to_await, - object_refs_to_cache): - objectref_cache[identifier] = object_ref - - for identifier in result.object_refs: - resolved_object_refs.append(objectref_cache[identifier]) - - recovery_workflow: Workflow = _recover_workflow_step.options( - max_retries=result.max_retries, - catch_exceptions=result.catch_exceptions, - **result.ray_options).step(resolved_object_refs, input_workflows, - workflow_refs, instant_workflow_outputs) - recovery_workflow._step_id = step_id - recovery_workflow.data.step_type = result.step_type - return recovery_workflow + + with serialization.objectref_cache(): + input_workflows = [] + for i, _step_id in enumerate(result.workflows): + r = _construct_resume_workflow_from_step(reader, _step_id) + if isinstance(r, Workflow): + input_workflows.append(r) + else: + assert isinstance(r, StepID) + # TODO (Alex): We should consider caching these outputs too. 
+ input_workflows.append(reader.load_step_output(r)) + workflow_refs = list(map(WorkflowRef, result.workflow_refs)) + + args, kwargs = reader.load_step_args(step_id, input_workflows, + workflow_refs) + + recovery_workflow: Workflow = _recover_workflow_step.options( + max_retries=result.max_retries, + catch_exceptions=result.catch_exceptions, + **result.ray_options).step(args, kwargs, input_workflows, + workflow_refs) + recovery_workflow._step_id = step_id + recovery_workflow.data.step_type = result.step_type + return recovery_workflow @ray.remote(num_returns=2) diff --git a/python/ray/workflow/serialization.py b/python/ray/workflow/serialization.py index fe3c4c7aef9ce..92b1577dfe5f8 100644 --- a/python/ray/workflow/serialization.py +++ b/python/ray/workflow/serialization.py @@ -1,12 +1,16 @@ import asyncio +import contextlib from dataclasses import dataclass import logging import ray +from ray import cloudpickle from ray.types import ObjectRef from ray.workflow import common from ray.workflow import storage -from ray.workflow import workflow_storage -from typing import Any, Dict, List, Tuple, TYPE_CHECKING +from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING + +from collections import ChainMap +import io if TYPE_CHECKING: from ray.actor import ActorHandle @@ -50,17 +54,6 @@ class Upload: upload_task: ObjectRef[None] -@ray.remote(num_cpus=0) -def _put_helper(identifier: str, obj: Any, - wf_storage: "workflow_storage.WorkflowStorage") -> None: - if isinstance(obj, ray.ObjectRef): - raise NotImplementedError("Workflow does not support checkpointing " - "nested object references yet.") - paths = wf_storage._key_obj_id(identifier) - asyncio.get_event_loop().run_until_complete( - wf_storage._put(paths, obj, update=False)) - - @ray.remote(num_cpus=0) class Manager: """ @@ -90,23 +83,150 @@ async def save_objectref( A pair. The first element is the paths the ref will be uploaded to. The second is an object reference to the upload task. """ - wf_storage = workflow_storage.WorkflowStorage(workflow_id, - self._storage) ref, = ref_tuple # Use the hex as the key to avoid holding a reference to the object. - key = ref.hex() + key = (ref.hex(), workflow_id) if key not in self._uploads: # TODO(Alex): We should probably eventually free these refs. identifier_ref = common.calculate_identifier.remote(ref) - upload_task = _put_helper.remote(identifier_ref, ref, wf_storage) - self._uploads[key] = Upload(identifier_ref, upload_task) + upload_task = _put_helper.remote(identifier_ref, ref, workflow_id, + self._storage) + self._uploads[key] = Upload( + identifier_ref=identifier_ref, upload_task=upload_task) self._num_uploads += 1 info = self._uploads[key] identifer = await info.identifier_ref - paths = wf_storage._key_obj_id(identifer) + paths = obj_id_to_paths(workflow_id, identifer) return paths, info.upload_task async def export_stats(self) -> Dict[str, Any]: return {"num_uploads": self._num_uploads} + + +OBJECTS_DIR = "objects" + + +def obj_id_to_paths(workflow_id: str, object_id: str) -> List[str]: + return [workflow_id, OBJECTS_DIR, object_id] + + +@ray.remote(num_cpus=0) +def _put_helper(identifier: str, obj: Any, workflow_id: str, + storage: storage.Storage) -> None: + # TODO (Alex): This check isn't sufficient, it only works for directly + # nested object refs. 
+ if isinstance(obj, ray.ObjectRef): + raise NotImplementedError("Workflow does not support checkpointing " + "nested object references yet.") + paths = obj_id_to_paths(workflow_id, identifier) + promise = dump_to_storage( + paths, obj, workflow_id, storage, update_existing=False) + return asyncio.get_event_loop().run_until_complete(promise) + + +def _reduce_objectref(workflow_id: str, storage: storage.Storage, + obj_ref: ObjectRef, tasks: List[ObjectRef]): + manager = get_or_create_manager() + paths, task = ray.get( + manager.save_objectref.remote((obj_ref, ), workflow_id)) + + assert task + tasks.append(task) + + return _load_object_ref, (paths, storage) + + +async def dump_to_storage(paths: List[str], + obj: Any, + workflow_id: str, + storage: storage.Storage, + update_existing=True) -> None: + """Serializes and puts arbitrary object, handling references. The object will + be uploaded at `paths`. Any object references will be uploaded to their + global, remote storage. + + Args: + paths: The location to put the object. + obj: The object to serialize. If it contains object references, those + will be serialized too. + workflow_id: The workflow id. + storage: The storage to use. If obj contains object references, + `storage.put` will be called on them individually. + update_existing: If False, the object will not be uploaded if the path + exists. + """ + if not update_existing: + prefix = storage.make_key(*paths[:-1]) + scan_result = await storage.scan_prefix(prefix) + if paths[-1] in scan_result: + return + + tasks = [] + + # NOTE: Cloudpickle doesn't support private dispatch tables, so we extend + # the cloudpickler instead to avoid changing cloudpickle's global dispatch + # table which is shared with `ray.put`. See + # https://github.com/cloudpipe/cloudpickle/issues/437 + class ObjectRefPickler(cloudpickle.CloudPickler): + _object_ref_reducer = { + ray.ObjectRef: lambda ref: _reduce_objectref( + workflow_id, storage, ref, tasks) + } + dispatch_table = ChainMap(_object_ref_reducer, + cloudpickle.CloudPickler.dispatch_table) + dispatch = dispatch_table + + key = storage.make_key(*paths) + + # TODO(Alex): We should be able to do this without the extra buffer. + with io.BytesIO() as f: + pickler = ObjectRefPickler(f) + pickler.dump(obj) + f.seek(0) + task = storage.put(key, f.read()) + tasks.append(task) + + await asyncio.gather(*tasks) + + +@ray.remote +def _load_ref_helper(key: str, storage: storage.Storage): + # TODO(Alex): We should stream the data directly into `cloudpickle.load`. + serialized = asyncio.get_event_loop().run_until_complete(storage.get(key)) + return cloudpickle.loads(serialized) + + +# TODO (Alex): We should use weakrefs here instead requiring a context manager. 
+_object_cache: Optional[Dict[str, ray.ObjectRef]] = None
+
+
+def _load_object_ref(paths: List[str],
+                     storage: storage.Storage) -> ray.ObjectRef:
+    global _object_cache
+    key = storage.make_key(*paths)
+    if _object_cache is None:
+        return _load_ref_helper.remote(key, storage)
+
+    if key not in _object_cache:
+        _object_cache[key] = _load_ref_helper.remote(key, storage)
+
+    return _object_cache[key]
+
+
+@contextlib.contextmanager
+def objectref_cache() -> Generator:
+    """ A reentrant caching context for object refs."""
+    global _object_cache
+    clear_cache = _object_cache is None
+    if clear_cache:
+        _object_cache = {}
+    try:
+        yield
+    finally:
+        if clear_cache:
+            _object_cache = None
diff --git a/python/ray/workflow/serialization_context.py b/python/ray/workflow/serialization_context.py
index f2bb354a6c977..d442e888a7672 100644
--- a/python/ray/workflow/serialization_context.py
+++ b/python/ray/workflow/serialization_context.py
@@ -12,18 +12,13 @@ def _resolve_workflow_outputs(index: int) -> Any:
     raise ValueError("There is no context for resolving workflow outputs.")


-def _resolve_objectrefs(index: int) -> ray.ObjectRef:
-    raise ValueError("There is no context for resolving object refs.")
-
-
 def _resolve_workflow_refs(index: int) -> Any:
     raise ValueError("There is no context for resolving workflow refs.")


 @contextlib.contextmanager
 def workflow_args_serialization_context(
-        workflows: List[Workflow], object_refs: List[ray.ObjectRef],
-        workflow_refs: List[WorkflowRef]) -> None:
+        workflows: List[Workflow], workflow_refs: List[WorkflowRef]) -> None:
     """
     This serialization context reduces workflow input arguments to three
     parts:
@@ -49,10 +44,9 @@ def workflow_args_serialization_context(

     Args:
         workflows: Workflow list output.
-        object_refs: ObjectRef list output.
+        workflow_refs: Workflow reference output list.
     """
     workflow_deduplicator: Dict[Workflow, int] = {}
-    objectref_deduplicator: Dict[ray.ObjectRef, int] = {}
     workflowref_deduplicator: Dict[WorkflowRef, int] = {}

     def workflow_serializer(workflow):
@@ -68,24 +62,6 @@ def workflow_serializer(workflow):
         serializer=workflow_serializer,
         deserializer=_resolve_workflow_outputs)

-    def objectref_serializer(obj_ref):
-        if obj_ref in objectref_deduplicator:
-            return objectref_deduplicator[obj_ref]
-        i = len(object_refs)
-        object_refs.append(obj_ref)
-        objectref_deduplicator[obj_ref] = i
-        return i
-
-    # override the default ObjectRef serializer
-    # TODO(suquark): We are using Ray internal APIs to access serializers.
-    # This is only a workaround. We need alternatives later.
-    ray_objectref_reducer_backup = ray.cloudpickle.CloudPickler.dispatch[
-        ray.ObjectRef]
-    register_serializer(
-        ray.ObjectRef,
-        serializer=objectref_serializer,
-        deserializer=_resolve_objectrefs)
-
     def workflow_ref_serializer(workflow_ref):
         if workflow_ref in workflowref_deduplicator:
             return workflowref_deduplicator[workflow_ref]
@@ -104,15 +80,11 @@ def workflow_ref_serializer(workflow_ref):
     finally:
         # we do not want to serialize Workflow objects in other places.
deregister_serializer(Workflow) - # restore original dispatch - ray.cloudpickle.CloudPickler.dispatch[ - ray.ObjectRef] = ray_objectref_reducer_backup deregister_serializer(WorkflowRef) @contextlib.contextmanager def workflow_args_resolving_context(workflow_output_mapping: List[Any], - objectref_mapping: List[ray.ObjectRef], workflow_ref_mapping: List[Any]) -> None: """ This context resolves workflows and objectrefs inside workflow @@ -122,21 +94,18 @@ def workflow_args_resolving_context(workflow_output_mapping: List[Any], workflow_output_mapping: List of workflow outputs. objectref_mapping: List of object refs. """ - global _resolve_workflow_outputs, _resolve_objectrefs + global _resolve_workflow_outputs global _resolve_workflow_refs _resolve_workflow_outputs_bak = _resolve_workflow_outputs - _resolve_objectrefs_bak = _resolve_objectrefs _resolve_workflow_refs_bak = _resolve_workflow_refs _resolve_workflow_outputs = workflow_output_mapping.__getitem__ - _resolve_objectrefs = objectref_mapping.__getitem__ _resolve_workflow_refs = workflow_ref_mapping.__getitem__ try: yield finally: _resolve_workflow_outputs = _resolve_workflow_outputs_bak - _resolve_objectrefs = _resolve_objectrefs_bak _resolve_workflow_refs = _resolve_workflow_refs_bak @@ -148,14 +117,6 @@ def __reduce__(self): return _resolve_workflow_outputs, (self._index, ) -class _KeepObjectRefs: - def __init__(self, index: int): - self._index = index - - def __reduce__(self): - return _resolve_objectrefs, (self._index, ) - - class _KeepWorkflowRefs: def __init__(self, index: int): self._index = index @@ -167,51 +128,42 @@ def __reduce__(self): @contextlib.contextmanager def workflow_args_keeping_context() -> None: """ - This context only read workflow arguments. Workflows and objectrefs inside + This context only read workflow arguments. Workflows inside are untouched and can be serialized again properly. """ - global _resolve_workflow_outputs, _resolve_objectrefs + global _resolve_workflow_outputs global _resolve_workflow_refs _resolve_workflow_outputs_bak = _resolve_workflow_outputs - _resolve_objectrefs_bak = _resolve_objectrefs _resolve_workflow_refs_bak = _resolve_workflow_refs # we must capture the old functions to prevent self-referencing. def _keep_workflow_outputs(index: int): return _KeepWorkflowOutputs(index) - def _keep_objectrefs(index: int): - return _KeepObjectRefs(index) - def _keep_workflow_refs(index: int): return _KeepWorkflowRefs(index) _resolve_workflow_outputs = _keep_workflow_outputs - _resolve_objectrefs = _keep_objectrefs _resolve_workflow_refs = _keep_workflow_refs try: yield finally: _resolve_workflow_outputs = _resolve_workflow_outputs_bak - _resolve_objectrefs = _resolve_objectrefs_bak _resolve_workflow_refs = _resolve_workflow_refs_bak def make_workflow_inputs(args_list: List[Any]) -> WorkflowInputs: workflows: List[Workflow] = [] - object_refs: List[ray.ObjectRef] = [] workflow_refs: List[WorkflowRef] = [] - with workflow_args_serialization_context(workflows, object_refs, - workflow_refs): + with workflow_args_serialization_context(workflows, workflow_refs): # NOTE: When calling 'ray.put', we trigger python object # serialization. Under our serialization context, - # Workflows and ObjectRefs are separated from the arguments, + # Workflows are separated from the arguments, # leaving a placeholder object with all other python objects. # Then we put the placeholder object to object store, # so it won't be mutated later. This guarantees correct # semantics. 
See "tests/test_variable_mutable.py" as # an example. input_placeholder: ray.ObjectRef = ray.put(args_list) - return WorkflowInputs(input_placeholder, object_refs, workflows, - workflow_refs) + return WorkflowInputs(input_placeholder, workflows, workflow_refs) diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py index 4f2a6137da2fd..878c7b40bf451 100644 --- a/python/ray/workflow/step_executor.py +++ b/python/ray/workflow/step_executor.py @@ -1,3 +1,4 @@ +import asyncio from dataclasses import dataclass import logging from typing import (List, Tuple, Any, Dict, Callable, Optional, TYPE_CHECKING, @@ -9,15 +10,23 @@ from ray.workflow import workflow_context from ray.workflow import recovery from ray.workflow.workflow_context import get_step_status_info +from ray.workflow import serialization from ray.workflow import serialization_context from ray.workflow import workflow_storage from ray.workflow.workflow_access import (get_or_create_management_actor, get_management_actor) -from ray.workflow.common import (Workflow, WorkflowStatus, WorkflowOutputType, - WorkflowExecutionResult, StepType) +from ray.workflow.common import ( + Workflow, + WorkflowStatus, + WorkflowOutputType, + WorkflowExecutionResult, + StepType, + StepID, + WorkflowData, +) if TYPE_CHECKING: - from ray.workflow.common import (StepID, WorkflowRef, WorkflowInputs) + from ray.workflow.common import (WorkflowRef, WorkflowInputs) StepInputTupleToResolve = Tuple[ObjectRef, List[ObjectRef], List[ObjectRef]] @@ -114,7 +123,7 @@ def _resolve_step_inputs( step_inputs.workflow_refs) with serialization_context.workflow_args_resolving_context( - objects_mapping, step_inputs.object_refs, workflow_ref_mapping): + objects_mapping, workflow_ref_mapping): # reconstruct input arguments under correct serialization context flattened_args: List[Any] = ray.get(step_inputs.args) @@ -149,8 +158,7 @@ def execute_workflow( Args: workflow: The workflow to be executed. outer_most_step_id: The ID of the outer most workflow. None if it - does not exists. See "step_executor.execute_workflow" for detailed - explanation. + does not exists. last_step_of_workflow: The step that generates the output of the workflow (including nested steps). Returns: @@ -181,6 +189,30 @@ def execute_workflow( return result +async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage, + step_id: StepID, inputs: WorkflowData) -> None: + """Save workflow inputs.""" + metadata = inputs.to_metadata() + with serialization_context.workflow_args_keeping_context(): + # TODO(suquark): in the future we should write to storage directly + # with plasma store object in memory. + args_obj = ray.get(inputs.inputs.args) + + workflow_id = wf_storage._workflow_id + storage = wf_storage._storage + save_tasks = [ + # TODO (Alex): Handle the json case better? 
+ wf_storage._put( + wf_storage._key_step_input_metadata(step_id), metadata, True), + serialization.dump_to_storage( + wf_storage._key_step_function_body(step_id), inputs.func_body, + workflow_id, storage), + serialization.dump_to_storage( + wf_storage._key_step_args(step_id), args_obj, workflow_id, storage) + ] + await asyncio.gather(*save_tasks) + + def commit_step(store: workflow_storage.WorkflowStorage, step_id: "StepID", ret: Union["Workflow", Any], @@ -197,7 +229,13 @@ def commit_step(store: workflow_storage.WorkflowStorage, """ from ray.workflow.common import Workflow if isinstance(ret, Workflow): - store.save_subworkflow(ret) + assert not ret.executed + tasks = [ + _write_step_inputs(store, w.step_id, w.data) + for w in ret._iter_workflows_in_dag() + ] + asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) + store.save_step_output(step_id, ret, exception, outer_most_step_id) @@ -373,7 +411,6 @@ class _BakedWorkflowInputs: and their outputs (ObjectRefs) replace the original workflows.""" args: "ObjectRef" workflow_outputs: "List[ObjectRef]" - object_refs: "List[ObjectRef]" workflow_refs: "List[WorkflowRef]" @classmethod @@ -381,12 +418,11 @@ def from_workflow_inputs(cls, inputs: "WorkflowInputs"): workflow_outputs = [ execute_workflow(w).persisted_output for w in inputs.workflows ] - return cls(inputs.args, workflow_outputs, inputs.object_refs, - inputs.workflow_refs) + return cls(inputs.args, workflow_outputs, inputs.workflow_refs) def __reduce__(self): return _BakedWorkflowInputs, (self.args, self.workflow_outputs, - self.object_refs, self.workflow_refs) + self.workflow_refs) def _record_step_status(step_id: "StepID", diff --git a/python/ray/workflow/tests/test_serialization.py b/python/ray/workflow/tests/test_serialization.py index dfe96564f8426..7ea7ef35d3d6f 100644 --- a/python/ray/workflow/tests/test_serialization.py +++ b/python/ray/workflow/tests/test_serialization.py @@ -3,6 +3,8 @@ import ray from ray import workflow from ray.workflow import serialization +from ray.workflow import storage +from ray.workflow import workflow_storage from ray._private.test_utils import run_string_as_driver_nonblocking from ray.tests.conftest import * # noqa import subprocess @@ -82,6 +84,26 @@ def test_dedupe_serialization_2(workflow_start_regular_shared): assert get_num_uploads() == 2 +def test_same_object_many_workflows(workflow_start_regular_shared): + """Ensure that when we dedupe uploads, we upload the object once per workflow, + since different workflows shouldn't look in each others object directories. + """ + + @ray.workflow.step + def f(a): + return [a[0]] + + x = {0: ray.put(10)} + + result1 = f.step(x).run() + result2 = f.step(x).run() + print(result1) + print(result2) + + assert ray.get(*result1) == 10 + assert ray.get(*result2) == 10 + + def test_dedupe_cluster_failure(reset_workflow, tmp_path): ray.shutdown() """ @@ -143,6 +165,36 @@ def foo(objrefs): ray.shutdown() +def test_embedded_objectrefs(workflow_start_regular): + workflow_id = test_embedded_objectrefs.__name__ + base_storage = storage.get_global_storage() + + class ObjectRefsWrapper: + def __init__(self, refs): + self.refs = refs + + url = base_storage.storage_url + + wrapped = ObjectRefsWrapper([ray.put(1), ray.put(2)]) + + promise = serialization.dump_to_storage(["key"], wrapped, workflow_id, + base_storage) + workflow_storage.asyncio_run(promise) + + # Be extremely explicit about shutting down. We want to make sure the + # `_get` call deserializes the full object and puts it in the object store. 
+ # Shutting down the cluster should guarantee we don't accidently get the + # old object and pass the test. + ray.shutdown() + subprocess.check_output("ray stop --force", shell=True) + + workflow.init(url) + storage2 = workflow_storage.get_workflow_storage(workflow_id) + + result = workflow_storage.asyncio_run(storage2._get(["key"])) + assert ray.get(result.refs) == [1, 2] + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/workflow/tests/test_storage.py b/python/ray/workflow/tests/test_storage.py index c426fb5c79060..419fc6edfd01f 100644 --- a/python/ray/workflow/tests/test_storage.py +++ b/python/ray/workflow/tests/test_storage.py @@ -2,13 +2,10 @@ import ray from ray._private import signature from ray.tests.conftest import * # noqa -from ray import workflow from ray.workflow import workflow_storage from ray.workflow import storage -from ray.workflow.workflow_storage import asyncio_run, \ - get_workflow_storage +from ray.workflow.workflow_storage import asyncio_run from ray.workflow.common import StepType -import subprocess def some_func(x): @@ -47,7 +44,6 @@ def test_workflow_storage(workflow_start_regular): input_metadata = { "name": "test_basic_workflows.append1", "step_type": StepType.FUNCTION, - "object_refs": ["abc"], "workflows": ["def"], "workflow_refs": ["some_ref"], "max_retries": 1, @@ -87,7 +83,7 @@ def test_workflow_storage(workflow_start_regular): asyncio_run(wf_storage._put(wf_storage._key_step_output(step_id), output)) assert wf_storage.load_step_output(step_id) == output - assert wf_storage.load_step_args(step_id, [], [], []) == args + assert wf_storage.load_step_args(step_id, [], []) == args assert wf_storage.load_step_func_body(step_id)(33) == 34 assert ray.get(wf_storage.load_object_ref( obj_ref.hex())) == object_resolved @@ -131,7 +127,6 @@ def test_workflow_storage(workflow_start_regular): step_type=StepType.FUNCTION, args_valid=True, func_body_valid=True, - object_refs=input_metadata["object_refs"], workflows=input_metadata["workflows"], workflow_refs=input_metadata["workflow_refs"], ray_options={}) @@ -149,7 +144,6 @@ def test_workflow_storage(workflow_start_regular): assert inspect_result == workflow_storage.StepInspectResult( step_type=StepType.FUNCTION, func_body_valid=True, - object_refs=input_metadata["object_refs"], workflows=input_metadata["workflows"], workflow_refs=input_metadata["workflow_refs"], ray_options={}) @@ -163,7 +157,6 @@ def test_workflow_storage(workflow_start_regular): inspect_result = wf_storage.inspect_step(step_id) assert inspect_result == workflow_storage.StepInspectResult( step_type=StepType.FUNCTION, - object_refs=input_metadata["object_refs"], workflows=input_metadata["workflows"], workflow_refs=input_metadata["workflow_refs"], ray_options={}) @@ -176,35 +169,6 @@ def test_workflow_storage(workflow_start_regular): assert not inspect_result.is_recoverable() -def test_embedded_objectrefs(workflow_start_regular): - workflow_id = test_workflow_storage.__name__ - - class ObjectRefsWrapper: - def __init__(self, refs): - self.refs = refs - - wf_storage = workflow_storage.WorkflowStorage(workflow_id, - storage.get_global_storage()) - url = storage.get_global_storage().storage_url - - wrapped = ObjectRefsWrapper([ray.put(1), ray.put(2)]) - - asyncio_run(wf_storage._put(["key"], wrapped)) - - # Be extremely explicit about shutting down. We want to make sure the - # `_get` call deserializes the full object and puts it in the object store. 
- # Shutting down the cluster should guarantee we don't accidently get the - # old object and pass the test. - ray.shutdown() - subprocess.check_output("ray stop --force", shell=True) - - workflow.init(url) - storage2 = get_workflow_storage(workflow_id) - - result = asyncio_run(storage2._get(["key"])) - assert ray.get(result.refs) == [1, 2] - - if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/workflow/workflow_access.py b/python/ray/workflow/workflow_access.py index b58eaab634f79..0524637cf08da 100644 --- a/python/ray/workflow/workflow_access.py +++ b/python/ray/workflow/workflow_access.py @@ -115,7 +115,7 @@ class LatestWorkflowOutput: # TODO(suquark): we may use an actor pool in the future if too much # concurrent workflow access blocks the actor. -@ray.remote +@ray.remote(num_cpus=0) class WorkflowManagementActor: """Keep the ownership and manage the workflow output.""" @@ -179,7 +179,8 @@ def run_or_resume(self, workflow_id: str, ignore_existing: bool = False latest_output = LatestWorkflowOutput(result.persisted_output, workflow_id, step_id) self._workflow_outputs[workflow_id] = latest_output - print("run_or_resume: ", workflow_id, step_id, result.persisted_output) + logger.info(f"run_or_resume: {workflow_id}, {step_id}," + f"{result.persisted_output}") self._step_output_cache[(workflow_id, step_id)] = latest_output wf_store.save_workflow_meta( @@ -225,7 +226,6 @@ def update_step_status(self, workflow_id: str, step_id: str, common.WorkflowMetaData(common.WorkflowStatus.FAILED)) self._step_status.pop(workflow_id) else: - # remaining = 0 wf_store.save_workflow_meta( common.WorkflowMetaData(common.WorkflowStatus.SUCCESSFUL)) self._step_status.pop(workflow_id) @@ -372,10 +372,12 @@ def init_management_actor() -> None: except ValueError: logger.info("Initializing workflow manager...") # the actor does not exist - WorkflowManagementActor.options( + actor = WorkflowManagementActor.options( name=common.MANAGEMENT_ACTOR_NAME, namespace=common.MANAGEMENT_ACTOR_NAMESPACE, lifetime="detached").remote(store) + # No-op to ensure the actor is created before the driver exits. + ray.get(actor.get_storage_url.remote()) def get_management_actor() -> "ActorHandle": diff --git a/python/ray/workflow/workflow_storage.py b/python/ray/workflow/workflow_storage.py index aea72428e56de..bf18f471483de 100644 --- a/python/ray/workflow/workflow_storage.py +++ b/python/ray/workflow/workflow_storage.py @@ -4,20 +4,18 @@ """ import asyncio -from collections import ChainMap from typing import Dict, List, Optional, Any, Callable, Tuple, Union from dataclasses import dataclass -import io import logging import ray from ray import cloudpickle from ray._private import signature from ray.workflow import storage -from ray.workflow.common import (Workflow, WorkflowData, StepID, - WorkflowMetaData, WorkflowStatus, WorkflowRef, - StepType) +from ray.workflow.common import (Workflow, StepID, WorkflowMetaData, + WorkflowStatus, WorkflowRef, StepType) from ray.workflow import workflow_context +from ray.workflow import serialization from ray.workflow import serialization_context from ray.workflow.storage import (DataLoadError, DataSaveError, KeyNotFoundError) @@ -63,8 +61,6 @@ class StepInspectResult: args_valid: bool = False # The step function body checkpoint exists and valid. func_body_valid: bool = False - # The object refs in the inputs of the workflow. - object_refs: Optional[List[str]] = None # The workflows in the inputs of the workflow. 
workflows: Optional[List[str]] = None # The dynamically referenced workflows in the input of the workflow. @@ -82,8 +78,8 @@ class StepInspectResult: def is_recoverable(self) -> bool: return (self.output_object_valid or self.output_step_id - or (self.args_valid and self.object_refs is not None - and self.workflows is not None and self.func_body_valid)) + or (self.args_valid and self.workflows is not None + and self.func_body_valid)) class WorkflowStorage: @@ -151,7 +147,11 @@ def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], if exception is None: # This workflow step returns a object. ret = ray.get(ret) if isinstance(ret, ray.ObjectRef) else ret - tasks.append(self._put(self._key_step_output(step_id), ret)) + promise = serialization.dump_to_storage( + self._key_step_output(step_id), ret, self._workflow_id, + self._storage) + tasks.append(promise) + # tasks.append(self._put(self._key_step_output(step_id), ret)) dynamic_output_id = step_id # TODO (yic): Delete exception file @@ -164,8 +164,12 @@ def save_step_output(self, step_id: StepID, ret: Union[Workflow, Any], dynamic_output_id)) else: assert ret is None - tasks.append( - self._put(self._key_step_exception(step_id), exception)) + promise = serialization.dump_to_storage( + self._key_step_exception(step_id), exception, + self._workflow_id, self._storage) + tasks.append(promise) + # tasks.append( + # self._put(self._key_step_exception(step_id), exception)) asyncio_run(asyncio.gather(*tasks)) @@ -195,7 +199,6 @@ async def _gen_step_id(): def load_step_args( self, step_id: StepID, workflows: List[Any], - object_refs: List[ray.ObjectRef], workflow_refs: List[WorkflowRef]) -> Tuple[List, Dict[str, Any]]: """Load the input arguments of the workflow step. This must be done under a serialization context, otherwise the arguments would @@ -211,7 +214,7 @@ def load_step_args( Args and kwargs. """ with serialization_context.workflow_args_resolving_context( - workflows, object_refs, workflow_refs): + workflows, workflow_refs): flattened_args = asyncio_run( self._get(self._key_step_args(step_id))) # dereference arguments like Ray remote functions @@ -332,7 +335,6 @@ async def _inspect_step(self, step_id: StepID) -> StepInspectResult: return StepInspectResult( args_valid=(STEP_ARGS in keys), func_body_valid=(STEP_FUNC_BODY in keys), - object_refs=metadata["object_refs"], workflows=metadata["workflows"], workflow_refs=metadata["workflow_refs"], max_retries=metadata.get("max_retries"), @@ -349,46 +351,9 @@ async def _inspect_step(self, step_id: StepID) -> StepInspectResult: ) async def _save_object_ref(self, identifier: str, obj_ref: ray.ObjectRef): - # TODO (Alex): We should do this in a remote task to exploit locality. data = await obj_ref await self._put(self._key_obj_id(identifier), data) - async def _write_step_inputs(self, step_id: StepID, - inputs: WorkflowData) -> None: - """Save workflow inputs.""" - metadata = inputs.to_metadata() - - with serialization_context.workflow_args_keeping_context(): - # TODO(suquark): in the future we should write to storage directly - # with plasma store object in memory. 
-            args_obj = ray.get(inputs.inputs.args)
-        save_tasks = [
-            self._put(self._key_step_input_metadata(step_id), metadata, True),
-            self._put(self._key_step_function_body(step_id), inputs.func_body),
-            self._put(self._key_step_args(step_id), args_obj)
-        ]
-
-        for identifier, obj_ref in zip(metadata["object_refs"],
-                                       inputs.inputs.object_refs):
-            paths = self._key_step_args(identifier)
-            save_tasks.append(self._put(paths, obj_ref))
-
-        await asyncio.gather(*save_tasks)
-
-    def save_subworkflow(self, workflow: Workflow) -> None:
-        """Save the DAG and inputs of the sub-workflow.
-
-        Args:
-            workflow: A sub-workflow. Could be a nested workflow inside
-                a workflow step.
-        """
-        assert not workflow.executed
-        tasks = [
-            self._write_step_inputs(w.step_id, w.data)
-            for w in workflow._iter_workflows_in_dag()
-        ]
-        asyncio_run(asyncio.gather(*tasks))
-
     def load_actor_class_body(self) -> type:
         """Load the class body of the virtual actor.

@@ -479,23 +444,8 @@ def get_latest_progress(self) -> "StepID":
         return asyncio_run(self._get(self._key_workflow_progress(),
                                      True))["step_id"]

-    def _reduce_objectref(self, obj_ref: ObjectRef,
-                          upload_tasks: List[ObjectRef]):
-        from ray.workflow import serialization
-        manager = serialization.get_or_create_manager()
-        paths, task = ray.get(
-            manager.save_objectref.remote((obj_ref, ), self._workflow_id))
-
-        assert task
-        upload_tasks.append(task)
-
-        return _load_object_ref, (paths, self)
-
-    async def _put(self,
-                   paths: List[str],
-                   data: Any,
-                   is_json: bool = False,
-                   update: bool = True) -> str:
+    async def _put(self, paths: List[str], data: Any,
+                   is_json: bool = False) -> str:
         """
         Serialize and put an object in the object store.

@@ -506,46 +456,13 @@ async def _put(self,

-            update: If false, do not upload data when the path already exists.
         """
         key = self._storage.make_key(*paths)
-        if not update:
-            prefix = self._storage.make_key(*paths[:-1])
-            scan_result = await self._storage.scan_prefix(prefix)
-            if paths[-1] in scan_result:
-                return key
         try:
-            upload_tasks: List[ObjectRef] = []
             if not is_json:
-                # Setup our custom serializer.
-                output_buffer = io.BytesIO()
-
-                # Cloudpickle doesn't support private dispatch tables, so we
-                # extend the cloudpickler instead to avoid changing
-                # cloudpickle's global dispatch table which is shared with
-                # `ray.put`. See
-                # https://github.com/cloudpipe/cloudpickle/issues/437
-                class ObjectRefPickler(cloudpickle.CloudPickler):
-                    _object_ref_reducer = {
-                        ray.ObjectRef: lambda ref: self._reduce_objectref(
-                            ref, upload_tasks)
-                    }
-                    dispatch_table = ChainMap(
-                        _object_ref_reducer,
-                        cloudpickle.CloudPickler.dispatch_table)
-                    dispatch = dispatch_table
-
-                pickler = ObjectRefPickler(output_buffer)
-                pickler.dump(data)
-                output_buffer.seek(0)
-                value = output_buffer.read()
+                # `dump_to_storage` serializes the object and uploads any
+                # nested object refs itself, so nothing more to put here.
+                await serialization.dump_to_storage(
+                    paths, data, self._workflow_id, self._storage)
             else:
-                value = data
-
-            await self._storage.put(key, value, is_json=is_json)
-            # The serializer only kicks off the upload tasks, and returns
-            # the location they will be uploaded to in order to allow those
-            # uploads to be parallelized. We should wait for those uploads
-            # to be finished before we consider the object fully
-            # serialized.
-            await asyncio.gather(*upload_tasks)
+                await self._storage.put(key, data, is_json=is_json)
         except Exception as e:
             raise DataSaveError from e
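
The new `dump_to_storage` / `objectref_cache` helpers are exercised end to end by `test_embedded_objectrefs` above; the following minimal sketch (not part of the patch) shows the same flow in isolation. The workflow id `example_workflow` and the key `example_key` are placeholder names, and it assumes a local Ray instance with default workflow storage.

```python
import ray
from ray import workflow
from ray.workflow import serialization, storage, workflow_storage

ray.init()
workflow.init()  # sets up the global workflow storage
base_storage = storage.get_global_storage()
workflow_id = "example_workflow"  # placeholder id for illustration

# dump_to_storage pickles the object and checkpoints any nested ObjectRefs
# individually; repeated refs are deduplicated per workflow by the Manager
# actor in serialization.py.
wrapped = {"nested": ray.put("hello")}
promise = serialization.dump_to_storage(["example_key"], wrapped, workflow_id,
                                        base_storage)
workflow_storage.asyncio_run(promise)

# objectref_cache() makes repeated loads of the same checkpointed ref resolve
# to a single cached ObjectRef, as recovery.py now does during resume.
with serialization.objectref_cache():
    reader = workflow_storage.get_workflow_storage(workflow_id)
    restored = workflow_storage.asyncio_run(reader._get(["example_key"]))
    assert ray.get(restored["nested"]) == "hello"
```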