Fix possible premature generator completion; add tests
njhill committed Aug 21, 2024
1 parent 10a88ec commit db8aebc
Showing 2 changed files with 101 additions and 23 deletions.
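
The bug in question: AsyncStream.generator() looped on "while not self._finished", so a consumer that resumed after finish() had already been called could exit before draining outputs still sitting in the stream's queue. Below is a minimal standalone sketch of that race using a plain asyncio.Queue; BuggyStream, FixedStream, and the STOP sentinel are illustrative stand-ins for this demo, not vLLM classes.

import asyncio

STOP = object()  # illustrative end-of-stream sentinel


class BuggyStream:
    """Mirrors the pre-fix behavior: the consumer loop is gated on the
    finished flag, so anything still queued when finish() lands is lost."""

    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item) -> None:
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(self) -> None:
        self._finished = True
        self._queue.put_nowait(STOP)

    async def generator(self):
        # BUG: if finish() ran before the consumer was scheduled, the
        # condition is already False and queued outputs are never yielded.
        while not self._finished:
            item = await self._queue.get()
            if item is STOP:
                return
            yield item


class FixedStream(BuggyStream):

    async def generator(self):
        # FIX: always drain the queue; only the sentinel ends the stream.
        while True:
            item = await self._queue.get()
            if item is STOP:
                return
            yield item


async def main() -> None:
    for cls in (BuggyStream, FixedStream):
        stream = cls()
        # Producer finishes before the consumer ever runs: the race.
        stream.put("token-0")
        stream.put("token-1")
        stream.finish()
        print(cls.__name__, [item async for item in stream.generator()])


asyncio.run(main())
# BuggyStream []                      <- final outputs silently dropped
# FixedStream ['token-0', 'token-1']

The only difference between the two generators is the loop condition, which is exactly the "while not self._finished:" to "while True:" change made in vllm/engine/async_llm_engine.py below.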
tests/async_engine/test_async_llm_engine.py (88 additions, 13 deletions)
@@ -1,14 +1,19 @@
 import asyncio
 import os
+from asyncio import CancelledError
 from dataclasses import dataclass
+from typing import Optional
 
 import pytest
+import pytest_asyncio
 import torch
 
 from vllm import SamplingParams
 from vllm.config import ParallelConfig
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
+from vllm.outputs import RequestOutput as RealRequestOutput
 
+from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear


@@ -118,33 +123,103 @@ async def test_new_requests_event():
     os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")
 
 
-def test_asyncio_run():
+def start_engine():
     wait_for_gpu_memory_to_clear(
         devices=list(range(torch.cuda.device_count())),
         threshold_bytes=2 * 2**30,
         timeout_s=60,
     )
 
-    engine = AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m"))
+    return AsyncLLMEngine.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+
+
+@pytest_asyncio.fixture(scope="module")
+async def async_engine():
+    engine = await asyncio.get_event_loop().run_in_executor(executor=None,
+                                                            func=start_engine)
+    try:
+        yield engine
+    finally:
+        engine.shutdown_background_loop()
+        del engine
+        await asyncio.sleep(0.1)
+        cleanup()
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    # So we can share the async engine fixture between these tests
+    return False
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_asyncio_run(async_engine):
 
     async def run(prompt: str):
         sampling_params = SamplingParams(
             temperature=0,
             max_tokens=32,
         )
 
-        async for output in engine.generate(prompt,
-                                            sampling_params,
-                                            request_id=prompt):
+        async for output in async_engine.generate(prompt,
+                                                  sampling_params,
+                                                  request_id=prompt):
             final_output = output
         return final_output
 
-    async def generate():
-        return await asyncio.gather(
-            run("test0"),
-            run("test1"),
-        )
-
-    results = asyncio.run(generate())
+    results = await asyncio.gather(
+        run("test0"),
+        run("test1"),
+    )
     assert len(results) == 2
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_cancellation(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    i = 0
+    with pytest.raises(CancelledError):
+        async for output in async_engine.generate("test2",
+                                                  sampling_params,
+                                                  request_id="test2"):
+            assert not output.finished
+            i += 1
+            if i == 5:
+                await async_engine.abort("test2")
+
+    assert i == 5
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_delayed_generator(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    stream = async_engine.generate("test3",
+                                   sampling_params,
+                                   request_id="test3")
+    i = 0
+    final_output: Optional[RealRequestOutput] = None
+    async for output in stream:
+        final_output = output
+        if i == 0:
+            # wait for generation to complete before consuming
+            # the remaining messages
+            await asyncio.sleep(1)
+        if i < 9:
+            assert not output.finished
+        i += 1
+
+    assert i == 10
+    assert final_output is not None
+    assert len(final_output.outputs[0].token_ids) == 10
+    assert final_output.finished
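
The new tests share one engine through a module-scoped fixture, since building an AsyncLLMEngine per test would dominate runtime; should_do_global_cleanup_after_test is overridden so the shared engine survives between tests. A minimal sketch of the same pattern, assuming pytest-asyncio 0.23+ (where pytest_asyncio.fixture and the asyncio marker both accept a scope argument); FakeEngine and start_engine are illustrative stand-ins, not vLLM code:

import asyncio

import pytest
import pytest_asyncio


class FakeEngine:
    """Hypothetical stand-in for an expensive-to-construct engine."""

    def shutdown_background_loop(self) -> None:
        pass


def start_engine() -> FakeEngine:
    # Blocking startup work (model load, CUDA init, ...) belongs here,
    # off the event loop.
    return FakeEngine()


@pytest_asyncio.fixture(scope="module")
async def engine():
    # Build the engine in a worker thread so the event loop stays free,
    # then hand the single instance to every test in this module.
    loop = asyncio.get_running_loop()
    engine = await loop.run_in_executor(None, start_engine)
    try:
        yield engine
    finally:
        engine.shutdown_background_loop()


@pytest.mark.asyncio(scope="module")
async def test_a(engine):
    assert isinstance(engine, FakeEngine)


@pytest.mark.asyncio(scope="module")
async def test_b(engine):
    # Runs on the same module-scoped loop, with the same engine as test_a.
    assert isinstance(engine, FakeEngine)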
vllm/engine/async_llm_engine.py (13 additions, 10 deletions)
@@ -2,8 +2,8 @@
 import time
 from dataclasses import dataclass
 from functools import partial
-from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
-                    Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
+                    Mapping, Optional, Set, Tuple, Type, Union)
 
 import torch
 from transformers import PreTrainedTokenizer
@@ -85,9 +85,8 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
 
     def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                               Exception]) -> None:
-        if self._finished:
-            return
-        self._queue.put_nowait(item)
+        if not self._finished:
+            self._queue.put_nowait(item)
 
     def finish(
         self,
@@ -96,7 +95,7 @@ def finish(
         if not self._finished:
             self._finished = True
             self._queue.put_nowait(
-                exception if exception is not None else STOP_ITERATION)
+                exception if self._is_raisable(exception) else STOP_ITERATION)
 
     @property
     def finished(self) -> bool:
@@ -106,11 +105,9 @@ async def generator(
         self
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         try:
-            while not self._finished:
+            while True:
                 result = await self._queue.get()
-                if isinstance(result, BaseException) or \
-                        (isinstance(result, type) and \
-                         issubclass(result, BaseException)):
+                if self._is_raisable(result):
                     if result == STOP_ITERATION:
                         return
                     raise result
@@ -119,6 +116,12 @@
             self._cancel(self.request_id)
             raise asyncio.CancelledError from None
 
+    @staticmethod
+    def _is_raisable(value: Any):
+        return isinstance(value, BaseException) or \
+            (isinstance(value, type) and \
+             issubclass(value, BaseException))
+
 
 class RequestTracker:
     """Synchronous abstraction for tracking requests."""
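The new _is_raisable helper gives finish() and generator() one shared test for "should this queued value end the stream or be raised, rather than yielded": it accepts both exception instances and bare exception classes, so the STOP_ITERATION sentinel can travel through the same queue as ordinary outputs, and finish(exception=None) now falls through to the sentinel instead of enqueueing None. A standalone illustration of the predicate's behavior:

from typing import Any


def is_raisable(value: Any) -> bool:
    # Same check as the staticmethod above, reproduced for demonstration.
    return isinstance(value, BaseException) or (
        isinstance(value, type) and issubclass(value, BaseException))


assert is_raisable(ValueError("boom"))   # exception instance
assert is_raisable(StopAsyncIteration)   # bare exception class, sentinel-style
assert not is_raisable("a normal output")
assert not is_raisable(None)             # finish(exception=None) -> sentinel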
