Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] Add FF to run sync methods in a threadpool #48897

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,17 @@
# Enabled only when the environment variable is set to exactly "1".
RAY_SERVE_FORCE_LOCAL_TESTING_MODE = (
    os.getenv("RAY_SERVE_FORCE_LOCAL_TESTING_MODE", "0") == "1"
)

# Feature flag: when the environment variable is exactly "1", sync methods
# defined in the replica are run in a thread pool. Off unless explicitly set.
RAY_SERVE_RUN_SYNC_IN_THREADPOOL = (
    os.getenv("RAY_SERVE_RUN_SYNC_IN_THREADPOOL", "0") == "1"
)

# Warning template for sync methods that still run directly on the asyncio
# loop. Format it with `method_name=<name of the sync method>`.
RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING = (
    "Calling sync method '{method_name}' directly on the asyncio loop. "
    "In a future version, sync methods will be run in a threadpool by "
    "default. Ensure your sync methods are thread safe or keep the "
    "existing behavior by making them `async def`. Opt into the new "
    "behavior by setting RAY_SERVE_RUN_SYNC_IN_THREADPOOL=1."
)
6 changes: 5 additions & 1 deletion python/ray/serve/_private/local_testing_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import ray
from ray import cloudpickle
from ray.serve._private.common import DeploymentID, RequestMetadata
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.constants import (
RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
SERVE_LOGGER_NAME,
)
from ray.serve._private.replica import UserCallableWrapper
from ray.serve._private.replica_result import ReplicaResult
from ray.serve._private.router import Router
Expand Down Expand Up @@ -66,6 +69,7 @@ def make_local_deployment_handle(
deployment.init_args,
deployment.init_kwargs,
deployment_id=deployment_id,
run_sync_methods_in_threadpool=RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
)
try:
logger.info(f"Initializing local replica class for {deployment_id}.")
Expand Down
145 changes: 127 additions & 18 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import asyncio
import concurrent.futures
import functools
import inspect
import logging
import os
import pickle
import threading
import time
import traceback
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import wraps
from importlib import import_module
from typing import (
Any,
Expand All @@ -23,6 +24,7 @@
)

import starlette.responses
from anyio import to_thread
from starlette.types import ASGIApp, Message

import ray
Expand All @@ -47,6 +49,8 @@
HEALTH_CHECK_METHOD,
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_PERIOD_S,
RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING,
RECONFIGURE_METHOD,
SERVE_CONTROLLER_NAME,
SERVE_LOGGER_NAME,
Expand Down Expand Up @@ -274,6 +278,7 @@ def __init__(
init_args,
init_kwargs,
deployment_id=self._deployment_id,
run_sync_methods_in_threadpool=RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
)

# Guards against calling the user's callable constructor multiple times.
Expand Down Expand Up @@ -602,6 +607,11 @@ async def initialize(self, deployment_config: DeploymentConfig):
self._user_callable_initialized = True

if deployment_config:
await asyncio.wrap_future(
self._user_callable_wrapper.set_sync_method_threadpool_limit(
deployment_config.max_ongoing_requests
)
)
await asyncio.wrap_future(
self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
Expand Down Expand Up @@ -635,6 +645,11 @@ async def reconfigure(self, deployment_config: DeploymentConfig):
if logging_config_changed:
self._configure_logger_and_profilers(deployment_config.logging_config)

await asyncio.wrap_future(
self._user_callable_wrapper.set_sync_method_threadpool_limit(
deployment_config.max_ongoing_requests
)
)
if user_config_changed:
await asyncio.wrap_future(
self._user_callable_wrapper.call_reconfigure(
Expand Down Expand Up @@ -990,6 +1005,7 @@ def __init__(
init_kwargs: Dict,
*,
deployment_id: DeploymentID,
run_sync_methods_in_threadpool: bool,
):
if not (inspect.isfunction(deployment_def) or inspect.isclass(deployment_def)):
raise TypeError(
Expand All @@ -1003,6 +1019,8 @@ def __init__(
self._is_function = inspect.isfunction(deployment_def)
self._deployment_id = deployment_id
self._destructor_called = False
self._run_sync_methods_in_threadpool = run_sync_methods_in_threadpool
self._warned_about_sync_method_change = False

# Will be populated in `initialize_callable`.
self._callable = None
Expand Down Expand Up @@ -1033,7 +1051,7 @@ def _run_on_user_code_event_loop(f: Callable) -> Callable:
f
), "_run_on_user_code_event_loop can only be used on coroutine functions."

@wraps(f)
@functools.wraps(f)
def wrapper(self, *args, **kwargs) -> concurrent.futures.Future:
return asyncio.run_coroutine_threadsafe(
f(self, *args, **kwargs),
Expand All @@ -1042,6 +1060,12 @@ def wrapper(self, *args, **kwargs) -> concurrent.futures.Future:

return wrapper

@_run_on_user_code_event_loop
async def set_sync_method_threadpool_limit(self, limit: int):
    """Set the token count of anyio's default thread limiter to `limit`.

    NOTE(edoakes): the limit is thread local, so this must
    be run on the user code event loop.
    """
    limiter = to_thread.current_default_thread_limiter()
    limiter.total_tokens = limit

def _get_user_callable_method(self, method_name: str) -> Callable:
if self._is_function:
return self._callable
Expand Down Expand Up @@ -1082,17 +1106,89 @@ async def _send_user_result_over_asgi(
else:
await Response(result).send(scope, receive, send)

# NOTE(review): this listing appears to be a rendered diff without +/- markers,
# so the single-line signature below and the first `result = callable(...)` /
# `return result` statements look like the pre-change version of this method,
# interleaved with the new keyword-only version — confirm against the real file.
async def _call_func_or_gen(self, callable: Callable, *args, **kwargs) -> Any:
async def _call_func_or_gen(
self,
callable: Callable,
*,
args: Optional[Tuple[Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
request_metadata: Optional[RequestMetadata] = None,
generator_result_callback: Optional[Callable] = None,
run_sync_methods_in_threadpool_override: Optional[bool] = None,
) -> Tuple[Any, bool]:
"""Call the callable with the provided arguments.

This is a convenience wrapper that will work for `def`, `async def`,
generator, and async generator functions.

Args:
    callable: the function or method to invoke as `callable(*args, **kwargs)`.
    args: positional arguments; treated as an empty tuple when None.
    kwargs: keyword arguments; treated as an empty dict when None.
    request_metadata: when provided, its `is_streaming` and `is_grpc_request`
        fields are consulted on the threadpool generator path.
    generator_result_callback: called with each item produced by a sync
        generator that is run in the threadpool.
    run_sync_methods_in_threadpool_override: when not None, takes precedence
        over `self._run_sync_methods_in_threadpool` and also suppresses the
        one-time warning about sync methods running on the asyncio loop.

Returns the result and a boolean indicating if the result was a sync generator
that has already been consumed. When the generator was consumed in the
threadpool, the returned result is None.

Raises:
    TypeError: if a sync generator would be run in the threadpool while
        `request_metadata` is provided and is not streaming.
"""
result = callable(*args, **kwargs)
if inspect.iscoroutine(result):
result = await result
sync_gen_consumed = False
args = args if args is not None else tuple()
kwargs = kwargs if kwargs is not None else dict()
# An explicit override (True or False) wins over the instance-level setting.
run_sync_in_threadpool = (
self._run_sync_methods_in_threadpool
if run_sync_methods_in_threadpool_override is None
else run_sync_methods_in_threadpool_override
)
# Only plain functions and bound methods that are neither coroutine functions
# nor async generator functions count as "sync methods" here.
is_sync_method = (
inspect.isfunction(callable) or inspect.ismethod(callable)
) and not (
inspect.iscoroutinefunction(callable)
or inspect.isasyncgenfunction(callable)
)

return result
if is_sync_method and run_sync_in_threadpool:
is_generator = inspect.isgeneratorfunction(callable)
if is_generator:
# The generator is fully iterated inside the worker thread below, so the
# caller is told it has already been consumed.
sync_gen_consumed = True
if request_metadata and not request_metadata.is_streaming:
# TODO(edoakes): make this check less redundant with the one in
# _handle_user_method_result.
raise TypeError(
f"Method '{callable.__name__}' returned a generator. "
"You must use `handle.options(stream=True)` to call "
"generators on a deployment."
)

def run_callable():
result = callable(*args, **kwargs)
if is_generator:
for r in result:
# TODO(edoakes): make this less redundant with the handling in
# _handle_user_method_result.
if request_metadata and request_metadata.is_grpc_request:
r = (request_metadata.grpc_context, r.SerializeToString())
# NOTE(review): generator_result_callback defaults to None and is called
# unconditionally here; this assumes callers always pass a callback when
# the callable is a sync generator — confirm.
generator_result_callback(r)

# Every item was already handed to the callback; nothing is left to return.
result = None

return result

# NOTE(edoakes): we use anyio.to_thread here because it's what Starlette
# uses (and therefore FastAPI too). The max size of the threadpool is
# set to max_ongoing_requests in the replica wrapper.
# anyio.to_thread propagates ContextVars to the worker thread automatically.
result = await to_thread.run_sync(run_callable)
else:
# Warn at most once per wrapper instance, and only when the caller did not
# explicitly choose the execution mode via the override.
if (
is_sync_method
and not self._warned_about_sync_method_change
and run_sync_methods_in_threadpool_override is None
):
self._warned_about_sync_method_change = True
warnings.warn(
RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING.format(
method_name=callable.__name__,
)
)

# Call directly on the current event loop; await the result if it is a coroutine.
result = callable(*args, **kwargs)
if inspect.iscoroutine(result):
result = await result

return result, sync_gen_consumed

@property
def user_callable(self) -> Optional[Callable]:
Expand Down Expand Up @@ -1129,8 +1225,10 @@ async def initialize_callable(self) -> Optional[ASGIApp]:
self._callable = self._deployment_def.__new__(self._deployment_def)
await self._call_func_or_gen(
self._callable.__init__,
*self._init_args,
**self._init_kwargs,
args=self._init_args,
kwargs=self._init_kwargs,
# Always run the constructor on the main user code thread.
run_sync_methods_in_threadpool_override=False,
)

if isinstance(self._callable, ASGIAppReplicaWrapper):
Expand Down Expand Up @@ -1192,7 +1290,7 @@ async def call_reconfigure(self, user_config: Any):
)
await self._call_func_or_gen(
getattr(self._callable, RECONFIGURE_METHOD),
user_config,
args=(user_config,),
)

def _prepare_args_for_http_request(
Expand Down Expand Up @@ -1264,6 +1362,7 @@ async def _handle_user_method_result(
user_method_name: str,
request_metadata: RequestMetadata,
*,
sync_gen_consumed: bool,
generator_result_callback: Optional[Callable],
is_asgi_app: bool,
asgi_args: Optional[ASGIArgs],
Expand Down Expand Up @@ -1297,7 +1396,7 @@ async def _handle_user_method_result(
# For the FastAPI codepath, the response has already been sent over
# ASGI, but for the vanilla deployment codepath we need to send it.
await self._send_user_result_over_asgi(result, asgi_args)
elif not request_metadata.is_http_request:
elif not request_metadata.is_http_request and not sync_gen_consumed:
# If a unary method is called with stream=True for anything EXCEPT
# an HTTP request, raise an error.
# HTTP requests are always streaming regardless of if the method
Expand Down Expand Up @@ -1382,12 +1481,20 @@ async def call_user_method(
request_args[0], request_metadata, user_method_params
)

result = await self._handle_user_method_result(
await self._call_func_or_gen(
user_method, *request_args, **request_kwargs
),
result, sync_gen_consumed = await self._call_func_or_gen(
user_method,
args=request_args,
kwargs=request_kwargs,
request_metadata=request_metadata,
generator_result_callback=generator_result_callback
if request_metadata.is_streaming
else None,
)
return await self._handle_user_method_result(
result,
user_method_name,
request_metadata,
sync_gen_consumed=sync_gen_consumed,
generator_result_callback=generator_result_callback,
is_asgi_app=is_asgi_app,
asgi_args=asgi_args,
Expand All @@ -1412,8 +1519,6 @@ async def call_user_method(
if receive_task is not None and not receive_task.done():
receive_task.cancel()

return result

@_run_on_user_code_event_loop
async def call_destructor(self):
"""Explicitly call the `__del__` method of the user callable.
Expand All @@ -1437,7 +1542,11 @@ async def call_destructor(self):
try:
if hasattr(self._callable, "__del__"):
# Make sure to accept `async def __del__(self)` as well.
await self._call_func_or_gen(self._callable.__del__)
await self._call_func_or_gen(
self._callable.__del__,
# Always run the destructor on the main user callable thread.
run_sync_methods_in_threadpool_override=False,
)

if hasattr(self._callable, "__serve_multiplex_wrapper"):
await getattr(self._callable, "__serve_multiplex_wrapper").shutdown()
Expand Down
22 changes: 22 additions & 0 deletions python/ray/serve/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,25 @@ py_test_module_list(
"//python/ray/serve:serve_lib",
],
)


# Test the currently off-by-default behavior of running replica sync methods in a threadpool.
# TODO(edoakes): remove this once the feature flag is flipped on by default.
# Declares the listed test files with the RAY_SERVE_RUN_SYNC_IN_THREADPOOL
# environment variable set to "1".
py_test_module_list(
size = "small",
# Presumably exported to the test process so the flag is read as enabled — confirm
# against the py_test_module_list macro.
env = {"RAY_SERVE_RUN_SYNC_IN_THREADPOOL": "1"},
files = [
"test_replica_sync_methods.py",
],
# Presumably appended to the generated target names to keep them distinct from
# any other target built from the same file — confirm against the macro.
name_suffix = "_with_run_sync_in_threadpool",
tags = [
"exclusive",
"no_windows",
"team:serve",
],
deps = [
":common",
":conftest",
"//python/ray/serve:serve_lib",
],
)
Loading