[ci] add vllm_test_utils #10659

Merged: 3 commits, Nov 26, 2024
4 changes: 4 additions & 0 deletions Dockerfile
@@ -191,6 +191,10 @@ ADD . /vllm-workspace/
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt

# install the vllm_test_utils helper package (for testing)
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -e tests/vllm_test_utils

# enable fast downloads from hf (for testing)
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install hf_transfer
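
The Dockerfile change above installs the new tests/vllm_test_utils package in editable mode so its helpers are importable inside the test image. As a hypothetical smoke check (not part of this PR), something like the following could be run in the built container; the expected output follows from the package's `__all__` defined further down in this diff.

```python
# Hypothetical smoke check (not part of this PR): run inside the built
# test image to confirm the editable install of tests/vllm_test_utils.
import vllm_test_utils

# __all__ is defined in vllm_test_utils/__init__.py as ["blame", "BlameResult"]
print(vllm_test_utils.__all__)
```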
23 changes: 16 additions & 7 deletions tests/entrypoints/llm/test_lazy_outlines.py
@@ -1,12 +1,12 @@
import sys

from vllm_test_utils import blame

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory


def test_lazy_outlines(sample_regex):
"""If users don't use guided decoding, outlines should not be imported.
"""
def run_normal():
prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -25,13 +25,12 @@ def test_lazy_outlines(sample_regex):
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# make sure outlines is not imported
assert 'outlines' not in sys.modules

# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()


def run_lmfe(sample_regex):
# Create an LLM with guided decoding enabled.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
@@ -51,5 +50,15 @@ def test_lazy_outlines(sample_regex):
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


def test_lazy_outlines(sample_regex):
"""If users don't use guided decoding, outlines should not be imported.
"""
# make sure outlines is not imported
assert 'outlines' not in sys.modules
module_name = "outlines"
with blame(lambda: module_name in sys.modules) as result:
run_normal()
run_lmfe(sample_regex)
assert not result.found, (
f"Module {module_name} is already imported, the"
f" first import location is:\n{result.trace_stack}")
54 changes: 1 addition & 53 deletions tests/test_lazy_torch_compile.py
@@ -1,61 +1,9 @@
# Description: Test the lazy import module.
# The utility function cannot be placed in `vllm.utils`
# because this needs to be a standalone script.

import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator


@dataclasses.dataclass
class BlameResult:
found: bool = False
trace_stack: str = ""


@contextlib.contextmanager
def blame(func: Callable) -> Generator[BlameResult, None, None]:
"""
Trace the function calls to find the first function that satisfies the
condition. The trace stack will be stored in the result.

Usage:

```python
with blame(lambda: some_condition()) as result:
# do something

if result.found:
print(result.trace_stack)
"""
result = BlameResult()

def _trace_calls(frame, event, arg=None):
nonlocal result
if event in ['call', 'return']:
# for every function call or return
try:
# Temporarily disable the trace function
sys.settrace(None)
# check condition here
if not result.found and func():
result.found = True
result.trace_stack = "".join(traceback.format_stack())
# Re-enable the trace function
sys.settrace(_trace_calls)
except NameError:
# modules are deleted during shutdown
pass
return _trace_calls

sys.settrace(_trace_calls)

yield result

sys.settrace(None)

from vllm_test_utils import blame

module_name = "torch._inductor.async_compile"

7 changes: 7 additions & 0 deletions tests/vllm_test_utils/setup.py
@@ -0,0 +1,7 @@
from setuptools import setup

setup(
name='vllm_test_utils',
version='0.1',
packages=['vllm_test_utils'],
)
8 changes: 8 additions & 0 deletions tests/vllm_test_utils/vllm_test_utils/__init__.py
@@ -0,0 +1,8 @@
"""
vllm_test_utils is a package for vLLM testing utilities.
It does not import any vLLM modules.
"""

from .blame import BlameResult, blame

__all__ = ["blame", "BlameResult"]
53 changes: 53 additions & 0 deletions tests/vllm_test_utils/vllm_test_utils/blame.py
@@ -0,0 +1,53 @@
import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator


@dataclasses.dataclass
class BlameResult:
found: bool = False
trace_stack: str = ""


@contextlib.contextmanager
def blame(func: Callable) -> Generator[BlameResult, None, None]:
"""
Trace the function calls to find the first function that satisfies the
condition. The trace stack will be stored in the result.

Usage:

```python
with blame(lambda: some_condition()) as result:
# do something

if result.found:
print(result.trace_stack)
"""
result = BlameResult()

def _trace_calls(frame, event, arg=None):
nonlocal result
if event in ['call', 'return']:
# for every function call or return
try:
# Temporarily disable the trace function
sys.settrace(None)
# check condition here
if not result.found and func():
result.found = True
result.trace_stack = "".join(traceback.format_stack())
# Re-enable the trace function
sys.settrace(_trace_calls)
except NameError:
# modules are deleted during shutdown
pass
return _trace_calls

sys.settrace(_trace_calls)

yield result

sys.settrace(None)
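
For illustration, here is a minimal standalone sketch of how `blame` can be used outside the vLLM test suite, assuming `vllm_test_utils` has been installed as in the Dockerfile change above; `workload`, `main`, and the `colorsys` import trigger are hypothetical, with `colorsys` chosen only because it is normally not imported at interpreter startup.

```python
import sys

from vllm_test_utils import blame


def workload():
    # Stand-in for real work; the lazy import below is what blame should catch.
    import colorsys  # noqa: F401


def main():
    module_name = "colorsys"
    with blame(lambda: module_name in sys.modules) as result:
        workload()
    if result.found:
        # trace_stack holds the formatted call stack captured the first
        # time the condition evaluated to True.
        print(result.trace_stack)


if __name__ == "__main__":
    main()
```

The tests in this PR use the same pattern but assert `not result.found`, so the failure message surfaces the offending import location from `result.trace_stack`.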