diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index 1923d5e3f3125..345b124be926c 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -5,6 +5,7 @@ import random import sys import time +from typing import List, Tuple from unittest.mock import patch import grpc @@ -32,6 +33,14 @@ def start_ray_and_proxy_manager(n_ports=2): return pm, free_ports +def get_dummy_metadata(req_id: int) -> List[Tuple[str, str]]: + """ + Get mock request metadata for mutating RPC stubs to avoid caching logic. + """ + return [("client_id", "dummy_client_id"), ("thread_id", "dummy_thread_id"), + ("req_id", str(req_id))] + + @pytest.mark.skipif( sys.platform == "win32", reason="PSUtil does not work the same on windows.") @@ -335,12 +344,14 @@ def make_internal_kv_calls(): # otherwise the SpecificServer will attempt to use the cached # response from previous calls response = task_servicer.KVPut( - ray_client_pb2.KVPutRequest(req_id=0, key=b"key", value=b"val")) + ray_client_pb2.KVPutRequest(key=b"key", value=b"val"), + metadata=get_dummy_metadata(0)) assert isinstance(response, ray_client_pb2.KVPutResponse) assert not response.already_exists response = task_servicer.KVPut( - ray_client_pb2.KVPutRequest(req_id=1, key=b"key", value=b"val2")) + ray_client_pb2.KVPutRequest(key=b"key", value=b"val2"), + metadata=get_dummy_metadata(1)) assert isinstance(response, ray_client_pb2.KVPutResponse) assert response.already_exists @@ -350,7 +361,8 @@ def make_internal_kv_calls(): response = task_servicer.KVPut( ray_client_pb2.KVPutRequest( - req_id=2, key=b"key", value=b"val2", overwrite=True)) + key=b"key", value=b"val2", overwrite=True), + metadata=get_dummy_metadata(2)) assert isinstance(response, ray_client_pb2.KVPutResponse) assert response.already_exists diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 56eb49b473591..4659732d7a8af 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -23,8 +23,7 @@ import inspect import json from ray.util.client.common import (ClientServerHandle, GRPC_OPTIONS, - CLIENT_SERVER_MAX_THREADS, - _get_client_id_from_context, ResponseCache) + CLIENT_SERVER_MAX_THREADS, ResponseCache) from ray.util.client.server.proxier import serve_proxier from ray.util.client.server.server_pickler import convert_from_arg from ray.util.client.server.server_pickler import dumps_from_server @@ -51,10 +50,16 @@ def _use_response_cache(func): @functools.wraps(func) def wrapper(self, request, context): + metadata = {k: v for k, v in context.invocation_metadata()} + expected_ids = ("client_id", "thread_id", "req_id") + if any(i not in metadata for i in expected_ids): + # Missing IDs, skip caching and call underlying stub directly + return func(request, context) + # Get relevant IDs to check cache - client_id = _get_client_id_from_context(context, logger) - thread_id = request.thread_id - req_id = request.req_id + client_id = metadata["client_id"] + thread_id = metadata["thread_id"] + req_id = metadata["req_id"] # Check if response already cached response_cache = self.response_caches[client_id] diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 988879bf7fa15..f453d76180ec0 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -204,8 +204,8 @@ def _connect_channel(self, reconnecting=False) -> None: # Ray is not ready yet, wait a timeout. time.sleep(timeout) # Fallthrough, backoff, and retry at the top of the loop - logger.warning("Waiting for Ray to become ready on the server, " - f"retry in {timeout}s...") + logger.info("Waiting for Ray to become ready on the server, " + f"retry in {timeout}s...") if not reconnecting: # Don't increase backoff when trying to reconnect -- # we already know the server exists, attempt to reconnect @@ -274,21 +274,22 @@ def _call_stub(self, stub_name: str, *args, **kwargs) -> Any: continue raise ConnectionError("Client is shutting down.") - def _add_ids_to_request(self, request: Any): + def _add_ids_to_metadata(self, metadata: Any): """ Adds a unique req_id and the current thread's identifier to the - request. These values are useful for preventing mutating operations + metadata. These values are useful for preventing mutating operations from being replayed on the server side in the event that the client must retry a requsest. Args: request - A gRPC message to add the thread and request IDs to """ - request.thread_id = threading.get_ident() + thread_id = str(threading.get_ident()) with self._req_id_lock: self._req_id += 1 if self._req_id > INT32_MAX: self._req_id = 1 - request.req_id = self._req_id + req_id = str(self._req_id) + return metadata + [("thread_id", thread_id), ("req_id", req_id)] def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity): logger.debug(f"client gRPC channel state change: {conn_state}") @@ -449,9 +450,9 @@ def _call_schedule_for_task( self, task: ray_client_pb2.ClientTask) -> List[bytes]: logger.debug("Scheduling %s" % task) task.client_id = self._client_id - self._add_ids_to_request(task) + metadata = self._add_ids_to_metadata(self.metadata) try: - ticket = self._call_stub("Schedule", task, metadata=self.metadata) + ticket = self._call_stub("Schedule", task, metadata=metadata) except grpc.RpcError as e: raise decode_exception(e) if not ticket.valid: @@ -541,8 +542,8 @@ def terminate_actor(self, actor: ClientActorHandle, try: term = ray_client_pb2.TerminateRequest(actor=term_actor) term.client_id = self._client_id - self._add_ids_to_request(term) - self._call_stub("Terminate", term, metadata=self.metadata) + metadata = self._add_ids_to_metadata(self.metadata) + self._call_stub("Terminate", term, metadata=metadata) except grpc.RpcError as e: raise decode_exception(e) @@ -559,8 +560,8 @@ def terminate_task(self, obj: ClientObjectRef, force: bool, try: term = ray_client_pb2.TerminateRequest(task_object=term_object) term.client_id = self._client_id - self._add_ids_to_request(term) - self._call_stub("Terminate", term, metadata=self.metadata) + metadata = self._add_ids_to_metadata(self.metadata) + self._call_stub("Terminate", term, metadata=metadata) except grpc.RpcError as e: raise decode_exception(e) @@ -593,14 +594,14 @@ def internal_kv_put(self, key: bytes, value: bytes, overwrite: bool) -> bool: req = ray_client_pb2.KVPutRequest( key=key, value=value, overwrite=overwrite) - self._add_ids_to_request(req) - resp = self._call_stub("KVPut", req, metadata=self.metadata) + metadata = self._add_ids_to_metadata(self.metadata) + resp = self._call_stub("KVPut", req, metadata=metadata) return resp.already_exists def internal_kv_del(self, key: bytes) -> None: req = ray_client_pb2.KVDelRequest(key=key) - self._add_ids_to_request(req) - self._call_stub("KVDel", req, metadata=self.metadata) + metadata = self._add_ids_to_metadata(self.metadata) + self._call_stub("KVDel", req, metadata=metadata) def internal_kv_list(self, prefix: bytes) -> bytes: req = ray_client_pb2.KVListRequest(prefix=prefix) diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index c8e0507d52280..9ccbd2ac7e51f 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -76,10 +76,6 @@ message ClientTask { TaskOptions options = 7; // Options passed to create the default remote task excution environment. TaskOptions baseline_options = 8; - // Thread that scheduled this task - int64 thread_id = 10; - // Identifier for this ticket - int32 req_id = 11; } message ClientTaskTicket { @@ -212,10 +208,6 @@ message TerminateRequest { ActorTerminate actor = 2; TaskObjectTerminate task_object = 3; } - // Thread that made this request - int64 thread_id = 4; - // Identifier for this request - int32 req_id = 5; } message TerminateResponse { @@ -242,10 +234,6 @@ message KVPutRequest { bytes key = 1; bytes value = 2; bool overwrite = 3; - // Thread that made this request - int64 thread_id = 4; - // Identifier for this request - int32 req_id = 5; } message KVPutResponse { @@ -254,10 +242,6 @@ message KVPutResponse { message KVDelRequest { bytes key = 1; - // Thread that made this request - int64 thread_id = 2; - // Identifier for this request - int32 req_id = 3; } message KVDelResponse {