[serve] Use pickle.dumps for proxy->replica messages #49539

Merged (18 commits) on Jan 6, 2025
63 changes: 40 additions & 23 deletions python/ray/serve/_private/common.py
@@ -1,13 +1,13 @@
import json
import pickle
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Optional
from typing import Awaitable, Callable, List, Optional

from starlette.types import Scope

import ray
from ray.actor import ActorHandle
from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME
from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME, SERVE_NAMESPACE
from ray.serve.generated.serve_pb2 import DeploymentStatus as DeploymentStatusProto
from ray.serve.generated.serve_pb2 import (
DeploymentStatusInfo as DeploymentStatusInfoProto,
@@ -623,33 +623,50 @@ def is_grpc_request(self) -> bool:
return self._request_protocol == RequestProtocol.GRPC


@dataclass
class StreamingHTTPRequest:
"""Sent from the HTTP proxy to replicas on the streaming codepath."""

asgi_scope: Scope
# Takes request metadata, returns a pickled list of ASGI messages.
receive_asgi_messages: Callable[[RequestMetadata], Awaitable[bytes]]
def __init__(
self,
asgi_scope: Scope,
*,
proxy_actor_name: Optional[str] = None,
receive_asgi_messages: Optional[
Callable[[RequestMetadata], Awaitable[bytes]]
] = None,
):
self._asgi_scope: Scope = asgi_scope

def __getstate__(self) -> Dict[str, Any]:
"""Custom serializer to use vanilla `pickle` for the ASGI scope.
if proxy_actor_name is None and receive_asgi_messages is None:
raise ValueError(
"Either proxy_actor_name or receive_asgi_messages must be provided."
)

This is possible because we know the scope is a dictionary containing
only Python primitive types. Vanilla `pickle` is much faster than cloudpickle.
"""
return {
"pickled_asgi_scope": pickle.dumps(self.asgi_scope),
"receive_asgi_messages": self.receive_asgi_messages,
}
# If receive_asgi_messages is passed, it'll be called directly.
# If proxy_actor_name is passed, the actor will be fetched and its
# `receive_asgi_messages` method will be called.
self._proxy_actor_name: Optional[str] = proxy_actor_name
# Need to keep the actor handle cached to avoid "lost reference to actor" error.
self._cached_proxy_actor: Optional[ActorHandle] = None
self._receive_asgi_messages: Optional[
Callable[[RequestMetadata], Awaitable[bytes]]
] = receive_asgi_messages

def __setstate__(self, state: Dict[str, Any]):
"""Custom deserializer to use vanilla `pickle` for the ASGI scope.
@property
def asgi_scope(self) -> Scope:
return self._asgi_scope

This is possible because we know the scope is a dictionary containing
only Python primitive types. Vanilla `pickle` is much faster than cloudpickle.
"""
self.asgi_scope = pickle.loads(state["pickled_asgi_scope"])
self.receive_asgi_messages = state["receive_asgi_messages"]
@property
def receive_asgi_messages(self) -> Callable[[RequestMetadata], Awaitable[bytes]]:
if self._receive_asgi_messages is None:
self._cached_proxy_actor = ray.get_actor(
self._proxy_actor_name, namespace=SERVE_NAMESPACE
)
self._receive_asgi_messages = (
self._cached_proxy_actor.receive_asgi_messages.remote
)

return self._receive_asgi_messages


class TargetCapacityDirection(str, Enum):
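The key change in `common.py`: instead of carrying a bound actor-method callable (which forces cloudpickle), `StreamingHTTPRequest` now carries only the proxy actor's name and resolves the handle lazily via `ray.get_actor` on first use, caching it. Below is a minimal sketch of that lazy-lookup pattern, not the Serve implementation; the class name, actor name, and namespace are illustrative.

```python
import pickle
from typing import Optional

import ray
from ray.actor import ActorHandle


class LazyProxyRef:
    """Carry only the actor's name across the wire; resolve the handle on demand."""

    def __init__(self, proxy_actor_name: str, namespace: str = "serve"):
        self._proxy_actor_name = proxy_actor_name
        self._namespace = namespace
        # Cache the handle so repeated calls don't re-fetch the actor.
        self._cached_proxy_actor: Optional[ActorHandle] = None

    def get(self) -> ActorHandle:
        if self._cached_proxy_actor is None:
            self._cached_proxy_actor = ray.get_actor(
                self._proxy_actor_name, namespace=self._namespace
            )
        return self._cached_proxy_actor


# Because it holds only strings (and a None cache slot), the object
# round-trips with vanilla pickle; no cloudpickle needed.
ref = pickle.loads(pickle.dumps(LazyProxyRef("proxy-example")))
```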
27 changes: 18 additions & 9 deletions python/ray/serve/_private/proxy.py
@@ -19,7 +19,6 @@

import ray
from ray._private.utils import get_or_create_event_loop
from ray.actor import ActorHandle
from ray.exceptions import RayActorError, RayTaskError
from ray.serve._private.common import (
DeploymentID,
@@ -708,7 +707,9 @@ async def send_request_to_replica(
) -> ResponseGenerator:
handle_arg = proxy_request.request_object()
response_generator = ProxyResponseGenerator(
handle.remote(handle_arg),
# NOTE(edoakes): it's important that the request is sent as raw bytes to
# skip the Ray cloudpickle serialization codepath for performance.
handle.remote(pickle.dumps(handle_arg)),
timeout_s=self.request_timeout_s,
)

@@ -765,8 +766,8 @@ def __init__(
node_ip_address: str,
is_head: bool,
proxy_router: ProxyRouter,
self_actor_name: str,
request_timeout_s: Optional[float] = None,
proxy_actor: Optional[ActorHandle] = None,
):
super().__init__(
node_id,
@@ -775,7 +776,7 @@
proxy_router,
request_timeout_s=request_timeout_s,
)
self.self_actor_handle = proxy_actor or ray.get_runtime_context().current_actor
self.self_actor_name = self_actor_name
self.asgi_receive_queues: Dict[str, MessageQueue] = dict()

@property
@@ -948,13 +949,16 @@ async def send_request_to_replica(
the status code.
"""
if app_is_cross_language:
handle_arg = await self._format_handle_arg_for_java(proxy_request)
handle_arg_bytes = await self._format_handle_arg_for_java(proxy_request)
# Response is returned as raw bytes, convert it to ASGI messages.
result_callback = convert_object_to_asgi_messages
else:
self_actor_handle = self.self_actor_handle
handle_arg = proxy_request.request_object(
receive_asgi_messages=self_actor_handle.receive_asgi_messages.remote
# NOTE(edoakes): it's important that the request is sent as raw bytes to
# skip the Ray cloudpickle serialization codepath for performance.
handle_arg_bytes = pickle.dumps(
proxy_request.request_object(
proxy_actor_name=self.self_actor_name,
)
)
# Messages are returned as pickled dictionaries.
result_callback = pickle.loads
@@ -969,7 +973,7 @@
)

response_generator = ProxyResponseGenerator(
handle.remote(handle_arg),
handle.remote(handle_arg_bytes),
timeout_s=self.request_timeout_s,
disconnected_task=proxy_asgi_receive_task,
result_callback=result_callback,
@@ -1218,6 +1222,7 @@ def __init__(
node_id=node_id,
node_ip_address=node_ip_address,
is_head=is_head,
self_actor_name=ray.get_runtime_context().get_actor_name(),
proxy_router=self.proxy_router,
request_timeout_s=(
request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S
@@ -1459,6 +1464,10 @@ async def check_health(self):
"""
logger.debug("Received health check.", extra={"log_to_stderr": False})

def pong(self):
"""Called by the replica to initialize its handle to the proxy."""
pass

async def receive_asgi_messages(self, request_metadata: RequestMetadata) -> bytes:
"""Get ASGI messages for the provided `request_metadata`.

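In `proxy.py`, the handle argument is now pre-pickled with vanilla `pickle` before `handle.remote(...)`, so the payload passes through Ray as raw bytes and skips the cloudpickle codepath. A rough micro-benchmark of that underlying assumption follows: for a primitives-only dict such as an ASGI scope, `pickle` is typically much faster than `cloudpickle`. This is an illustration, not a measurement from the PR; it assumes the standalone `cloudpickle` package is installed (Ray vendors its own copy as `ray.cloudpickle`), and exact numbers will vary.

```python
import pickle
import timeit

import cloudpickle

# A small dict of Python primitives, roughly the shape of an ASGI scope.
scope = {
    "type": "http",
    "method": "GET",
    "path": "/app",
    "headers": [(b"host", b"localhost:8000"), (b"accept", b"*/*")],
    "query_string": b"x=1",
}

n = 10_000
t_pickle = timeit.timeit(lambda: pickle.dumps(scope), number=n)
t_cloud = timeit.timeit(lambda: cloudpickle.dumps(scope), number=n)
print(f"pickle:      {t_pickle:.3f}s for {n} dumps")
print(f"cloudpickle: {t_cloud:.3f}s for {n} dumps")
```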
7 changes: 4 additions & 3 deletions python/ray/serve/_private/proxy_request_response.py
@@ -2,7 +2,7 @@
import pickle
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncIterator, Awaitable, Callable, List, Tuple, Union
from typing import Any, AsyncIterator, List, Tuple, Union

import grpc
from starlette.types import Receive, Scope, Send
@@ -96,11 +96,12 @@ def set_root_path(self, root_path: str):
self.scope["root_path"] = root_path

def request_object(
self, receive_asgi_messages: Callable[[str], Awaitable[bytes]]
self,
proxy_actor_name: str,
) -> StreamingHTTPRequest:
return StreamingHTTPRequest(
asgi_scope=self.scope,
receive_asgi_messages=receive_asgi_messages,
proxy_actor_name=proxy_actor_name,
)


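With this change, `request_object` needs only the proxy actor's name, so the resulting `StreamingHTTPRequest` contains nothing but plain data (an ASGI scope dict and a name string) and round-trips with vanilla `pickle`. A small check of that, assuming the classes as changed in this PR; the actor name here is made up and no Ray cluster is needed because the handle is never resolved.

```python
import pickle

from ray.serve._private.common import StreamingHTTPRequest

scope = {"type": "http", "method": "GET", "path": "/app", "headers": []}
request = StreamingHTTPRequest(asgi_scope=scope, proxy_actor_name="fake-proxy-name")

# Plain pickle suffices because the object holds only primitives and strings.
restored = pickle.loads(pickle.dumps(request))
assert restored.asgi_scope["path"] == "/app"
```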
27 changes: 23 additions & 4 deletions python/ray/serve/_private/replica.py
@@ -840,7 +840,9 @@ async def __init__(
)

def push_proxy_handle(self, handle: ActorHandle):
pass
# NOTE(edoakes): it's important to call a method on the proxy handle to
# initialize its state in the C++ core worker.
handle.pong.remote()

def get_num_ongoing_requests(self) -> int:
"""Fetch the number of ongoing requests at this replica (queue length).
@@ -896,14 +898,27 @@ async def reconfigure(
await self._replica_impl.reconfigure(deployment_config)
return self._replica_impl.get_metadata()

def _preprocess_request_args(
self,
pickled_request_metadata: bytes,
request_args: Tuple[Any],
) -> Tuple[RequestMetadata, Tuple[Any]]:
request_metadata = pickle.loads(pickled_request_metadata)
if request_metadata.is_http_request or request_metadata.is_grpc_request:
request_args = (pickle.loads(request_args[0]),)

return request_metadata, request_args

async def handle_request(
self,
pickled_request_metadata: bytes,
*request_args,
**request_kwargs,
) -> Tuple[bytes, Any]:
"""Entrypoint for `stream=False` calls."""
request_metadata = pickle.loads(pickled_request_metadata)
request_metadata, request_args = self._preprocess_request_args(
pickled_request_metadata, request_args
)
return await self._replica_impl.handle_request(
request_metadata, *request_args, **request_kwargs
)
@@ -915,7 +930,9 @@ async def handle_request_streaming(
**request_kwargs,
) -> AsyncGenerator[Any, None]:
"""Generator that is the entrypoint for all `stream=True` handle calls."""
request_metadata = pickle.loads(pickled_request_metadata)
request_metadata, request_args = self._preprocess_request_args(
pickled_request_metadata, request_args
)
async for result in self._replica_impl.handle_request_streaming(
request_metadata, *request_args, **request_kwargs
):
@@ -939,7 +956,9 @@ async def handle_request_with_rejection(
For streaming requests, the subsequent messages will be the results of the
user request handler (which must be a generator).
"""
request_metadata = pickle.loads(pickled_request_metadata)
request_metadata, request_args = self._preprocess_request_args(
pickled_request_metadata, request_args
)
async for result in self._replica_impl.handle_request_with_rejection(
request_metadata, *request_args, **request_kwargs
):
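On the replica side, `_preprocess_request_args` unpickles the request metadata and, for HTTP/gRPC requests, the single pre-pickled positional argument. Here is a standalone sketch of that shape, using a stand-in dataclass rather than the real Serve `RequestMetadata`.

```python
import pickle
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeRequestMetadata:
    """Stand-in for Serve's RequestMetadata, just enough for the sketch."""

    request_id: str
    is_http_request: bool = False
    is_grpc_request: bool = False


def preprocess_request_args(
    pickled_request_metadata: bytes, request_args: Tuple[Any, ...]
) -> Tuple[FakeRequestMetadata, Tuple[Any, ...]]:
    metadata = pickle.loads(pickled_request_metadata)
    if metadata.is_http_request or metadata.is_grpc_request:
        # Proxy-originated requests carry exactly one pre-pickled argument.
        request_args = (pickle.loads(request_args[0]),)
    return metadata, request_args


metadata = FakeRequestMetadata(request_id="abc", is_http_request=True)
args = (pickle.dumps({"type": "http", "path": "/"}),)
meta, (scope,) = preprocess_request_args(pickle.dumps(metadata), args)
assert scope["path"] == "/"
```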
16 changes: 3 additions & 13 deletions python/ray/serve/tests/unit/test_proxy.py
@@ -56,16 +56,6 @@ def cancel(self):
pass


class FakeActorHandle:
@property
def receive_asgi_messages(self):
class FakeReceiveASGIMessagesActorMethod:
def remote(self, request_id):
return FakeRef()

return FakeReceiveASGIMessagesActorMethod()


class FakeGrpcHandle:
def __init__(self, streaming: bool, grpc_context: RayServegRPCContext):
self.deployment_id = DeploymentID(
@@ -447,7 +437,7 @@ def create_http_proxy(self, is_head: bool = False):
node_ip_address=node_ip_address,
is_head=is_head,
proxy_router=FakeProxyRouter(),
proxy_actor=FakeActorHandle(),
self_actor_name="fake-proxy-name",
)

@pytest.mark.asyncio
@@ -713,7 +703,7 @@ def get_handle_override(endpoint, info):
# proxy is on head node
is_head=True,
proxy_router=ProxyRouter(get_handle_override),
proxy_actor=FakeActorHandle(),
self_actor_name="fake-proxy-name",
)
proxy_request = FakeProxyRequest(
request_type="http",
@@ -757,7 +747,7 @@ async def test_worker_http_unhealthy_until_replicas_populated():
# proxy is on worker node
is_head=False,
proxy_router=ProxyRouter(lambda *args: handle),
proxy_actor=FakeActorHandle(),
self_actor_name="fake-proxy-name",
)
proxy_request = FakeProxyRequest(
request_type="http",