From 48e71268cb0d25f9c2dc3f45f130d6f3565ecbbe Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Thu, 5 Oct 2023 09:05:33 +0900
Subject: [PATCH] [Core] Fix segfault when cancel and generator is used with a high load #40083 (#40122) (#40126)

Signed-off-by: Edward Oakes
Co-authored-by: Edward Oakes
Co-authored-by: SangBin Cho
---
 python/ray/_raylet.pyx                        | 18 ++--
 .../ray/tests/test_streaming_generator_3.py   | 94 +++++++++++++++++++
 2 files changed, 104 insertions(+), 8 deletions(-)

diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index 5f1dd37415c7e..483441f686bec 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -4258,22 +4258,24 @@ cdef class CoreWorker:
             function_descriptor, specified_cgname)
 
         async def async_func():
-            if task_id:
-                async_task_id.set(task_id)
+            try:
+                if task_id:
+                    async_task_id.set(task_id)
 
-            if inspect.isawaitable(func_or_coro):
-                coroutine = func_or_coro
-            else:
-                coroutine = func_or_coro(*args, **kwargs)
+                if inspect.isawaitable(func_or_coro):
+                    coroutine = func_or_coro
+                else:
+                    coroutine = func_or_coro(*args, **kwargs)
 
-            return await coroutine
+                return await coroutine
+            finally:
+                event.Notify()
 
         future = asyncio.run_coroutine_threadsafe(async_func(), eventloop)
         if task_id:
             with self._task_id_to_future_lock:
                 self._task_id_to_future[task_id] = future
 
-        future.add_done_callback(lambda _: event.Notify())
         with nogil:
             (CCoreWorkerProcess.GetCoreWorker()
                 .YieldCurrentFiber(event))
diff --git a/python/ray/tests/test_streaming_generator_3.py b/python/ray/tests/test_streaming_generator_3.py
index 7190a22edbb7e..411c634a16014 100644
--- a/python/ray/tests/test_streaming_generator_3.py
+++ b/python/ray/tests/test_streaming_generator_3.py
@@ -5,9 +5,16 @@
 import time
 
 from collections import Counter
+from starlette.responses import StreamingResponse
+from starlette.requests import Request
+from fastapi import FastAPI
+from ray import serve
+from pydantic import BaseModel
 
 import ray
 from ray._raylet import StreamingObjectRefGenerator
+from ray._private.test_utils import run_string_as_driver_nonblocking
+from ray.util.state import list_actors
 
 
 def test_threaded_actor_generator(shutdown_only):
@@ -231,6 +238,93 @@ def g(sleep_time):
     assert result[10] == 4
 
 
+def test_streaming_generator_load(shutdown_only):
+    app = FastAPI()
+
+    @serve.deployment(max_concurrent_queries=1000)
+    @serve.ingress(app)
+    class Router:
+        def __init__(self, handle) -> None:
+            self._h = handle.options(stream=True)
+            self.total_recieved = 0
+
+        @app.get("/")
+        def stream_hi(self, request: Request) -> StreamingResponse:
+            async def consume_obj_ref_gen():
+                obj_ref_gen = await self._h.hi_gen.remote()
+                start = time.time()
+                num_recieved = 0
+                async for chunk in obj_ref_gen:
+                    chunk = await chunk
+                    num_recieved += 1
+                    yield str(chunk.json())
+                delta = time.time() - start
+                print(f"**request throughput: {num_recieved / delta}")
+
+            return StreamingResponse(consume_obj_ref_gen(), media_type="text/plain")
+
+    @serve.deployment(max_concurrent_queries=1000)
+    class SimpleGenerator:
+        async def hi_gen(self):
+            start = time.time()
+            for i in range(100):
+                # await asyncio.sleep(0.001)
+                time.sleep(0.001)  # With async sleep the crash does not reproduce.
+
+                class Model(BaseModel):
+                    msg = "a" * 56
+
+                yield Model()
+            delta = time.time() - start
+            print(f"**model throughput: {100 / delta}")
+
+    serve.run(Router.bind(SimpleGenerator.bind()))
+
+    client_script = """
+import requests
+import time
+import io
+
+def send_serve_requests():
+    request_meta = {
+        "request_type": "InvokeEndpoint",
+        "name": "Streamtest",
+        "start_time": time.time(),
+        "response_length": 0,
+        "response": None,
+        "context": {},
+        "exception": None,
+    }
+    start_perf_counter = time.perf_counter()
+    #r = self.client.get("/", stream=True)
+    r = requests.get("http://localhost:8000", stream=True)
+    if r.status_code != 200:
+        print(r)
+    else:
+        for i, chunk in enumerate(r.iter_content(chunk_size=None, decode_unicode=True)):
+            pass
+        request_meta["response_time"] = (
+            time.perf_counter() - start_perf_counter
+        ) * 1000
+        # events.request.fire(**request_meta)
+
+from concurrent.futures import ThreadPoolExecutor
+with ThreadPoolExecutor(max_workers=10) as executor:
+    while True:
+        futs = [executor.submit(send_serve_requests) for _ in range(100)]
+        for f in futs:
+            f.result()
+"""
+    for _ in range(5):
+        print("submit a new clients!")
+        proc = run_string_as_driver_nonblocking(client_script)
+        # Let the client generate load for a while before terminating it.
+        time.sleep(5)
+        proc.terminate()
+        for actor in list_actors():
+            assert actor.state != "DEAD"
+
+
 if __name__ == "__main__":
     import os
 