Skip to content

Commit

Permalink
store thread_id and req_id in metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
ckw017 committed Sep 13, 2021
1 parent a9c9455 commit 12bc05c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 40 deletions.
18 changes: 15 additions & 3 deletions python/ray/tests/test_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import random
import sys
import time
from typing import List, Tuple
from unittest.mock import patch

import grpc
Expand Down Expand Up @@ -32,6 +33,14 @@ def start_ray_and_proxy_manager(n_ports=2):
return pm, free_ports


def get_dummy_metadata(req_id: int) -> List[Tuple[str, str]]:
    """
    Get mock request metadata for mutating RPC stubs to avoid caching logic.

    Args:
        req_id: Request identifier; it is converted to a string and stored
            under the "req_id" key.

    Returns:
        A list of (key, value) string pairs containing placeholder
        "client_id" and "thread_id" entries plus the given "req_id".
    """
    return [
        ("client_id", "dummy_client_id"),
        ("thread_id", "dummy_thread_id"),
        ("req_id", str(req_id)),
    ]


@pytest.mark.skipif(
sys.platform == "win32",
reason="PSUtil does not work the same on windows.")
Expand Down Expand Up @@ -335,12 +344,14 @@ def make_internal_kv_calls():
# otherwise the SpecificServer will attempt to use the cached
# response from previous calls
response = task_servicer.KVPut(
ray_client_pb2.KVPutRequest(req_id=0, key=b"key", value=b"val"))
ray_client_pb2.KVPutRequest(key=b"key", value=b"val"),
metadata=get_dummy_metadata(0))
assert isinstance(response, ray_client_pb2.KVPutResponse)
assert not response.already_exists

response = task_servicer.KVPut(
ray_client_pb2.KVPutRequest(req_id=1, key=b"key", value=b"val2"))
ray_client_pb2.KVPutRequest(key=b"key", value=b"val2"),
metadata=get_dummy_metadata(1))
assert isinstance(response, ray_client_pb2.KVPutResponse)
assert response.already_exists

Expand All @@ -350,7 +361,8 @@ def make_internal_kv_calls():

response = task_servicer.KVPut(
ray_client_pb2.KVPutRequest(
req_id=2, key=b"key", value=b"val2", overwrite=True))
key=b"key", value=b"val2", overwrite=True),
metadata=get_dummy_metadata(2))
assert isinstance(response, ray_client_pb2.KVPutResponse)
assert response.already_exists

Expand Down
15 changes: 10 additions & 5 deletions python/ray/util/client/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
import inspect
import json
from ray.util.client.common import (ClientServerHandle, GRPC_OPTIONS,
CLIENT_SERVER_MAX_THREADS,
_get_client_id_from_context, ResponseCache)
CLIENT_SERVER_MAX_THREADS, ResponseCache)
from ray.util.client.server.proxier import serve_proxier
from ray.util.client.server.server_pickler import convert_from_arg
from ray.util.client.server.server_pickler import dumps_from_server
Expand All @@ -51,10 +50,16 @@ def _use_response_cache(func):

@functools.wraps(func)
def wrapper(self, request, context):
metadata = {k: v for k, v in context.invocation_metadata()}
expected_ids = ("client_id", "thread_id", "req_id")
if any(i not in metadata for i in expected_ids):
# Missing IDs, skip caching and call underlying stub directly
return func(request, context)

# Get relevant IDs to check cache
client_id = _get_client_id_from_context(context, logger)
thread_id = request.thread_id
req_id = request.req_id
client_id = metadata["client_id"]
thread_id = metadata["thread_id"]
req_id = metadata["req_id"]

# Check if response already cached
response_cache = self.response_caches[client_id]
Expand Down
33 changes: 17 additions & 16 deletions python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def _connect_channel(self, reconnecting=False) -> None:
# Ray is not ready yet, wait a timeout.
time.sleep(timeout)
# Fallthrough, backoff, and retry at the top of the loop
logger.warning("Waiting for Ray to become ready on the server, "
f"retry in {timeout}s...")
logger.info("Waiting for Ray to become ready on the server, "
f"retry in {timeout}s...")
if not reconnecting:
# Don't increase backoff when trying to reconnect --
# we already know the server exists, attempt to reconnect
Expand Down Expand Up @@ -274,21 +274,22 @@ def _call_stub(self, stub_name: str, *args, **kwargs) -> Any:
continue
raise ConnectionError("Client is shutting down.")

def _add_ids_to_request(self, request: Any):
def _add_ids_to_metadata(self, metadata: Any):
"""
Adds a unique req_id and the current thread's identifier to the
request. These values are useful for preventing mutating operations
metadata. These values are useful for preventing mutating operations
from being replayed on the server side in the event that the client
must retry a request.
Args:
request - A gRPC message to add the thread and request IDs to
"""
request.thread_id = threading.get_ident()
thread_id = str(threading.get_ident())
with self._req_id_lock:
self._req_id += 1
if self._req_id > INT32_MAX:
self._req_id = 1
request.req_id = self._req_id
req_id = str(self._req_id)
return metadata + [("thread_id", thread_id), ("req_id", req_id)]

def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
logger.debug(f"client gRPC channel state change: {conn_state}")
Expand Down Expand Up @@ -449,9 +450,9 @@ def _call_schedule_for_task(
self, task: ray_client_pb2.ClientTask) -> List[bytes]:
logger.debug("Scheduling %s" % task)
task.client_id = self._client_id
self._add_ids_to_request(task)
metadata = self._add_ids_to_metadata(self.metadata)
try:
ticket = self._call_stub("Schedule", task, metadata=self.metadata)
ticket = self._call_stub("Schedule", task, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)
if not ticket.valid:
Expand Down Expand Up @@ -541,8 +542,8 @@ def terminate_actor(self, actor: ClientActorHandle,
try:
term = ray_client_pb2.TerminateRequest(actor=term_actor)
term.client_id = self._client_id
self._add_ids_to_request(term)
self._call_stub("Terminate", term, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("Terminate", term, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)

Expand All @@ -559,8 +560,8 @@ def terminate_task(self, obj: ClientObjectRef, force: bool,
try:
term = ray_client_pb2.TerminateRequest(task_object=term_object)
term.client_id = self._client_id
self._add_ids_to_request(term)
self._call_stub("Terminate", term, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("Terminate", term, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)

Expand Down Expand Up @@ -593,14 +594,14 @@ def internal_kv_put(self, key: bytes, value: bytes,
overwrite: bool) -> bool:
req = ray_client_pb2.KVPutRequest(
key=key, value=value, overwrite=overwrite)
self._add_ids_to_request(req)
resp = self._call_stub("KVPut", req, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
resp = self._call_stub("KVPut", req, metadata=metadata)
return resp.already_exists

def internal_kv_del(self, key: bytes) -> None:
req = ray_client_pb2.KVDelRequest(key=key)
self._add_ids_to_request(req)
self._call_stub("KVDel", req, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("KVDel", req, metadata=metadata)

def internal_kv_list(self, prefix: bytes) -> bytes:
req = ray_client_pb2.KVListRequest(prefix=prefix)
Expand Down
16 changes: 0 additions & 16 deletions src/ray/protobuf/ray_client.proto
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ message ClientTask {
TaskOptions options = 7;
// Options passed to create the default remote task execution environment.
TaskOptions baseline_options = 8;
// Thread that scheduled this task
int64 thread_id = 10;
// Identifier for this ticket
int32 req_id = 11;
}

message ClientTaskTicket {
Expand Down Expand Up @@ -212,10 +208,6 @@ message TerminateRequest {
ActorTerminate actor = 2;
TaskObjectTerminate task_object = 3;
}
// Thread that made this request
int64 thread_id = 4;
// Identifier for this request
int32 req_id = 5;
}

message TerminateResponse {
Expand All @@ -242,10 +234,6 @@ message KVPutRequest {
bytes key = 1;
bytes value = 2;
bool overwrite = 3;
// Thread that made this request
int64 thread_id = 4;
// Identifier for this request
int32 req_id = 5;
}

message KVPutResponse {
Expand All @@ -254,10 +242,6 @@ message KVPutResponse {

message KVDelRequest {
bytes key = 1;
// Thread that made this request
int64 thread_id = 2;
// Identifier for this request
int32 req_id = 3;
}

message KVDelResponse {
Expand Down

0 comments on commit 12bc05c

Please sign in to comment.