
[Core] Fix segfault when cancel and generator are used under high load #40083

Closed · wants to merge 14 commits
33 changes: 33 additions & 0 deletions locustfile.py
@@ -0,0 +1,33 @@
import requests
Contributor: remove?

import time
import io


from locust import HttpUser, task, constant, events


class CodeGenClient(HttpUser):

    wait_time = constant(1)

    @task
    def send_serve_requests(self):
        request_meta = {
            "request_type": "InvokeEndpoint",
            "name": "Streamtest",
            "start_time": time.time(),
            "response_length": 0,
            "response": None,
            "context": {},
            "exception": None,
        }
        start_perf_counter = time.perf_counter()
        # r = self.client.get("/", stream=True)
        r = requests.get("http://localhost:8000", stream=True)
        if r.status_code != 200:
            print(r)
        else:
            for i, chunk in enumerate(
                r.iter_content(chunk_size=None, decode_unicode=True)
            ):
                pass
        request_meta["response_time"] = (
            time.perf_counter() - start_perf_counter
        ) * 1000
        events.request.fire(**request_meta)
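How this repro is meant to be driven isn't stated in the diff; since it is a standard Locust file, it would presumably be run with something like `locust -f locustfile.py` while a Serve app is listening on `localhost:8000` (command assumed, not part of the PR).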
96 changes: 61 additions & 35 deletions python/ray/_raylet.pyx
@@ -1043,7 +1043,7 @@ cdef class StreamingGeneratorExecutionContext:
        self.streaming_generator_returns = streaming_generator_returns
        self.is_retryable_error = is_retryable_error
        self.application_error = application_error
-       self.should_retry_exceptions, = should_retry_exceptions,
+       self.should_retry_exceptions = should_retry_exceptions
        return self


@@ -1123,7 +1123,6 @@ cdef report_streaming_generator_output(
        # Del output here so that we can GC the memory
        # usage asap.
        del output_or_exception
-
        context.streaming_generator_returns[0].push_back(
            c_pair[CObjectID, c_bool](
                return_obj.first,
@@ -1177,7 +1176,8 @@ cdef execute_streaming_generator_sync(StreamingGeneratorExecutionContext context


async def execute_streaming_generator_async(
-       context: StreamingGeneratorExecutionContext):
+       context: StreamingGeneratorExecutionContext,
+       coroutine_complete_event: asyncio.Event):
    """Execute a given generator and streaming-report the
    result to the given caller_address.
@@ -1196,36 +1196,43 @@

    Args:
        context: The context to execute streaming generator.
+       coroutine_complete_event: The asyncio.Event to notify the
Contributor: this and the type hint say asyncio.Event, but a threading.Event is passed in.

Contributor (author): ah, I tried the non-blocking approach and it didn't work, so I changed to the blocking way. Let's just go with #40122, as it makes more sense to me.

+           main thread that the coroutine is actually finished.
    """
    assert context.is_initialized()
    # Generator task should only have 1 return object ref,
    # which contains None or exceptions (if system error occurs).
    assert context.return_size == 1

    gen = context.generator
-   while True:
-       try:
-           output_or_exception = await gen.__anext__()
-       except StopAsyncIteration:
-           break
-       except AsyncioActorExit:
-           # The execute_task will handle this case.
-           raise
-       except Exception as e:
-           output_or_exception = e
-
-       loop = asyncio.get_running_loop()
-       worker = ray._private.worker.global_worker
-       # Run it in a separate thread to that we can
-       # avoid blocking the event loop when serializing
-       # the output (which has nogil).
-       done = await loop.run_in_executor(
-           worker.core_worker.get_thread_pool_for_async_event_loop(),
-           report_streaming_generator_output,
-           output_or_exception,
-           context)
-       if done:
-           break
+   try:
+       while True:
+           try:
+               output_or_exception = await gen.__anext__()
+           except StopAsyncIteration:
+               break
+           except AsyncioActorExit:
+               # The execute_task will handle this case.
+               raise
+           except Exception as e:
+               output_or_exception = e
+
+           loop = asyncio.get_running_loop()
+           worker = ray._private.worker.global_worker
+
+           # Run it in a separate thread so that we can
+           # avoid blocking the event loop when serializing
+           # the output (which has nogil).
+           done = await loop.run_in_executor(
+               worker.core_worker.get_thread_pool_for_async_event_loop(),
+               report_streaming_generator_output,
+               output_or_exception,
+               context)
+           if done:
+               break
+   finally:
+       # Notify that the coroutine is actually finished.
+       coroutine_complete_event.set()


cdef create_generator_return_obj(
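The loop above hands the blocking, nogil serialization to a worker thread via `loop.run_in_executor` so the event loop stays free to run other coroutines. A minimal standalone sketch of that pattern — `blocking_report` and the items are illustrative stand-ins, not Ray APIs:

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def blocking_report(item):
    """Stand-in for report_streaming_generator_output: blocking work
    (serialization, an RPC) that must not run on the loop thread."""
    time.sleep(0.01)
    return item == "last"  # True means the stream is finished.


async def consume(agen, pool):
    loop = asyncio.get_running_loop()
    async for item in agen:
        # Offload the blocking call; other coroutines keep running
        # while the pool thread does the work.
        done = await loop.run_in_executor(pool, blocking_report, item)
        if done:
            break


async def main():
    async def agen():
        for item in ("a", "b", "last"):
            yield item

    with ThreadPoolExecutor(max_workers=1) as pool:
        await consume(agen(), pool)


asyncio.run(main())
```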
@@ -1665,13 +1672,31 @@ cdef void execute_task(
            context.initialize(outputs)

            if is_async_gen:
-               # Note that the report RPCs are called inside an
-               # event loop thread.
-               core_worker.run_async_func_or_coro_in_event_loop(
-                   execute_streaming_generator_async(context),
-                   function_descriptor,
-                   name_of_concurrency_group_to_execute,
-                   task_id)
+               coroutine_complete_event = threading.Event()
+
+               try:
+                   # Note that the report RPCs are called inside an
+                   # event loop thread.
+                   core_worker.run_async_func_or_coro_in_event_loop(
+                       execute_streaming_generator_async(
+                           context, coroutine_complete_event),
+                       function_descriptor,
+                       name_of_concurrency_group_to_execute,
+                       task_id)
+               except TaskCancelledError:
+                   # Due to a Python limitation,
+                   # execute_streaming_generator_async can return and
+                   # raise an exception while the coroutine is still
+                   # running. Wait for coroutine_complete_event to
+                   # avoid that.
+                   # TODO(sang): Currently, this is a blocking call.
+                   # Just in case, we limit the wait to 1 second.
+                   # However, this shouldn't block the thread for long,
+                   # because it only happens upon cancel, and when
+                   # TaskCancelledError is raised, the task is already
+                   # cancelled.
+                   coroutine_complete_event.wait(1)
+                   raise
            else:
                execute_streaming_generator_sync(context)
@@ -2235,7 +2260,7 @@ cdef void cancel_async_task(
        function_descriptor, name_of_concurrency_group_to_execute)
    future = worker.core_worker.get_queued_future(task_id)
    if future is not None:
-       future.cancel()
+       eventloop.call_soon_threadsafe(future.cancel)
Contributor: why's this change needed btw?

Contributor (author): This is to reduce the impact of the blocking call (so that by the time things are cancelled, the coroutine has already been context-switched). I think with the new fix, it is not necessary.

# else, the task is already finished. If the task
# wasn't finished (task is queued on a client or server side),
# this method shouldn't have been called.
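The answer above is terse; the general rule it leans on is that asyncio objects are not thread-safe, so a thread that does not own the loop should touch a task or future only by scheduling the call onto the loop with `call_soon_threadsafe`. A minimal sketch under that assumption (names illustrative):

```python
import asyncio
import threading
import time

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

tasks = {}


def make_task():
    # Runs on the loop thread, so creating the task here is safe.
    tasks["t"] = loop.create_task(asyncio.sleep(10))


loop.call_soon_threadsafe(make_task)
time.sleep(0.1)

# Not thread-safe from this thread:  tasks["t"].cancel()
# Thread-safe: hop onto the loop thread first, then cancel.
loop.call_soon_threadsafe(tasks["t"].cancel)
time.sleep(0.1)

loop.call_soon_threadsafe(loop.stop)
```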
@@ -4276,7 +4301,7 @@ cdef class CoreWorker:
        # transport with max_concurrency flag.
        increase_recursion_limit()

-       eventloop, async_thread = self.get_event_loop(
+       eventloop, _ = self.get_event_loop(
            function_descriptor, specified_cgname)

        async def async_func():
async def async_func():
Expand All @@ -4299,6 +4324,7 @@ cdef class CoreWorker:
        with nogil:
            (CCoreWorkerProcess.GetCoreWorker()
             .YieldCurrentFiber(event))
+
        try:
            result = future.result()
        except concurrent.futures.CancelledError:
94 changes: 94 additions & 0 deletions python/ray/tests/test_streaming_generator_3.py
@@ -5,9 +5,16 @@
import time

from collections import Counter
+from starlette.responses import StreamingResponse
+from starlette.requests import Request
+from fastapi import FastAPI
+from ray import serve
+
+from pydantic import BaseModel
import ray
from ray._raylet import StreamingObjectRefGenerator
from ray._private.test_utils import run_string_as_driver_nonblocking
+from ray.util.state import list_actors


def test_threaded_actor_generator(shutdown_only):
@@ -231,6 +238,93 @@ def g(sleep_time):
    assert result[10] == 4


def test_streaming_generator_load(shutdown_only):
    app = FastAPI()

    @serve.deployment(max_concurrent_queries=1000)
    @serve.ingress(app)
    class Router:
        def __init__(self, handle) -> None:
            self._h = handle.options(stream=True)
            self.total_received = 0

        @app.get("/")
        def stream_hi(self, request: Request) -> StreamingResponse:
            async def consume_obj_ref_gen():
                obj_ref_gen = await self._h.hi_gen.remote()
                start = time.time()
                num_received = 0
                async for chunk in obj_ref_gen:
                    chunk = await chunk
                    num_received += 1
                    yield str(chunk.json())
                delta = time.time() - start
                print(f"**request throughput: {num_received / delta}")

            return StreamingResponse(consume_obj_ref_gen(), media_type="text/plain")

    @serve.deployment(max_concurrent_queries=1000)
    class SimpleGenerator:
        async def hi_gen(self):
            start = time.time()
            for i in range(100):
                # await asyncio.sleep(0.001)
                time.sleep(0.001)  # If changed to an async sleep, the crash does not reproduce.

                class Model(BaseModel):
                    msg = "a" * 56

                yield Model()
            delta = time.time() - start
            print(f"**model throughput: {100 / delta}")

    serve.run(Router.bind(SimpleGenerator.bind()))

    client_script = """
import requests
import time
import io

def send_serve_requests():
    request_meta = {
        "request_type": "InvokeEndpoint",
        "name": "Streamtest",
        "start_time": time.time(),
        "response_length": 0,
        "response": None,
        "context": {},
        "exception": None,
    }
    start_perf_counter = time.perf_counter()
    # r = self.client.get("/", stream=True)
    r = requests.get("http://localhost:8000", stream=True)
    if r.status_code != 200:
        print(r)
    else:
        for i, chunk in enumerate(
            r.iter_content(chunk_size=None, decode_unicode=True)
        ):
            pass
    request_meta["response_time"] = (
        time.perf_counter() - start_perf_counter
    ) * 1000
    # events.request.fire(**request_meta)

from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
    while True:
        futs = [executor.submit(send_serve_requests) for _ in range(100)]
        for f in futs:
            f.result()
"""
    for _ in range(5):
        print("submitting a new client!")
        proc = run_string_as_driver_nonblocking(client_script)
        # Wait a sufficient amount of time.
        time.sleep(5)
        proc.terminate()
    for actor in list_actors():
        assert actor.state != "DEAD"

if __name__ == "__main__":
    import os
