Skip to content

Commit

Permalink
[Core] Fix segfault when cancel and generator is used with a high load
Browse files Browse the repository at this point in the history
…#40083 (#40122) (#40126)

Signed-off-by: Edward Oakes <[email protected]>
Co-authored-by: Edward Oakes <[email protected]>
Co-authored-by: SangBin Cho <[email protected]>
  • Loading branch information
3 people authored Oct 5, 2023
1 parent cb4c3e4 commit 48e7126
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 8 deletions.
18 changes: 10 additions & 8 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4258,22 +4258,24 @@ cdef class CoreWorker:
function_descriptor, specified_cgname)

async def async_func():
if task_id:
async_task_id.set(task_id)
try:
if task_id:
async_task_id.set(task_id)

if inspect.isawaitable(func_or_coro):
coroutine = func_or_coro
else:
coroutine = func_or_coro(*args, **kwargs)
if inspect.isawaitable(func_or_coro):
coroutine = func_or_coro
else:
coroutine = func_or_coro(*args, **kwargs)

return await coroutine
return await coroutine
finally:
event.Notify()

future = asyncio.run_coroutine_threadsafe(async_func(), eventloop)
if task_id:
with self._task_id_to_future_lock:
self._task_id_to_future[task_id] = future

future.add_done_callback(lambda _: event.Notify())
with nogil:
(CCoreWorkerProcess.GetCoreWorker()
.YieldCurrentFiber(event))
Expand Down
94 changes: 94 additions & 0 deletions python/ray/tests/test_streaming_generator_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@
import time

from collections import Counter
from starlette.responses import StreamingResponse
from starlette.requests import Request
from fastapi import FastAPI
from ray import serve

from pydantic import BaseModel
import ray
from ray._raylet import StreamingObjectRefGenerator
from ray._private.test_utils import run_string_as_driver_nonblocking
from ray.util.state import list_actors


def test_threaded_actor_generator(shutdown_only):
Expand Down Expand Up @@ -231,6 +238,93 @@ def g(sleep_time):
assert result[10] == 4


def test_streaming_generator_load(shutdown_only):
app = FastAPI()

@serve.deployment(max_concurrent_queries=1000)
@serve.ingress(app)
class Router:
def __init__(self, handle) -> None:
self._h = handle.options(stream=True)
self.total_recieved = 0

@app.get("/")
def stream_hi(self, request: Request) -> StreamingResponse:
async def consume_obj_ref_gen():
obj_ref_gen = await self._h.hi_gen.remote()
start = time.time()
num_recieved = 0
async for chunk in obj_ref_gen:
chunk = await chunk
num_recieved += 1
yield str(chunk.json())
delta = time.time() - start
print(f"**request throughput: {num_recieved / delta}")

return StreamingResponse(consume_obj_ref_gen(), media_type="text/plain")

@serve.deployment(max_concurrent_queries=1000)
class SimpleGenerator:
async def hi_gen(self):
start = time.time()
for i in range(100):
# await asyncio.sleep(0.001)
time.sleep(0.001) # if change to async sleep, i don't see crash.

class Model(BaseModel):
msg = "a" * 56

yield Model()
delta = time.time() - start
print(f"**model throughput: {100 / delta}")

serve.run(Router.bind(SimpleGenerator.bind()))

client_script = """
import requests
import time
import io
def send_serve_requests():
request_meta = {
"request_type": "InvokeEndpoint",
"name": "Streamtest",
"start_time": time.time(),
"response_length": 0,
"response": None,
"context": {},
"exception": None,
}
start_perf_counter = time.perf_counter()
#r = self.client.get("/", stream=True)
r = requests.get("http://localhost:8000", stream=True)
if r.status_code != 200:
print(r)
else:
for i, chunk in enumerate(r.iter_content(chunk_size=None, decode_unicode=True)):
pass
request_meta["response_time"] = (
time.perf_counter() - start_perf_counter
) * 1000
# events.request.fire(**request_meta)
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
while True:
futs = [executor.submit(send_serve_requests) for _ in range(100)]
for f in futs:
f.result()
"""
for _ in range(5):
print("submit a new clients!")
proc = run_string_as_driver_nonblocking(client_script)
# Wait sufficient time.
time.sleep(5)
proc.terminate()
for actor in list_actors():
assert actor.state != "DEAD"


if __name__ == "__main__":
import os

Expand Down

0 comments on commit 48e7126

Please sign in to comment.