[CI/Build] Refactor image test assets #5821

Merged: 5 commits, Jun 26, 2024
111 changes: 70 additions & 41 deletions tests/conftest.py
@@ -1,7 +1,12 @@
import contextlib
import gc
import os
from typing import Any, Dict, List, Optional, Tuple, TypeVar
from collections import UserList
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
TypeVar)

import pytest
import torch
@@ -28,21 +33,8 @@
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi modal related
# You can use `.buildkite/download-images.sh` to download the assets
PIXEL_VALUES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
]
IMAGE_FEATURES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
]
IMAGE_FILES = [
os.path.join(_TEST_DIR, "images", filename)
for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
]
assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)
_IMAGE_DIR = Path(_TEST_DIR) / "images"
"""You can use `.buildkite/download-images.sh` to download the assets."""


def _read_prompts(filename: str) -> List[str]:
@@ -51,6 +43,63 @@ def _read_prompts(filename: str) -> List[str]:
return prompts


@dataclass(frozen=True)
class ImageAsset:
name: Literal["stop_sign", "cherry_blossom"]

@cached_property
def pixel_values(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")

@cached_property
def image_features(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")

@cached_property
def pil_image(self) -> Image.Image:
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")

def for_hf(self) -> Image.Image:
return self.pil_image

def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
image_input_type = vision_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType

if image_input_type == ImageInputType.IMAGE_FEATURES:
return ImageFeatureData(self.image_features)
if image_input_type == ImageInputType.PIXEL_VALUES:
return ImagePixelData(self.pil_image)

raise NotImplementedError


class _ImageAssetPrompts(TypedDict):
stop_sign: str
cherry_blossom: str


class _ImageAssets(UserList[ImageAsset]):

def __init__(self) -> None:
super().__init__(
[ImageAsset("stop_sign"),
ImageAsset("cherry_blossom")])

def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
[Inline review comment from xwjiang2010 (Contributor), Jun 26, 2024: "nit: set_prompts to be clear." Edit: "scratch that, I think the current name is ok."]
"""
Convenience method to define the prompt for each test image.
The order of the returned prompts matches the order of the
assets when iterating through this object.
"""
return [prompts["stop_sign"], prompts["cherry_blossom"]]


IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""


def cleanup():
destroy_model_parallel()
destroy_distributed_environment()
@@ -81,31 +130,6 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup()


@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
return [Image.open(filename) for filename in IMAGE_FILES]


@pytest.fixture()
def vllm_images(request) -> List[MultiModalData]:
vision_language_config = request.getfixturevalue("model_and_config")[1]
if vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
return [
ImageFeatureData(torch.load(filename))
for filename in IMAGE_FEATURES_FILES
]
else:
return [
ImagePixelData(Image.open(filename)) for filename in IMAGE_FILES
]


@pytest.fixture()
def vllm_image_tensors(request) -> List[torch.Tensor]:
return [torch.load(filename) for filename in PIXEL_VALUES_FILES]


@pytest.fixture
def example_prompts() -> List[str]:
prompts = []
@@ -122,6 +146,11 @@ def example_long_prompts() -> List[str]:
return prompts


@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
return IMAGE_ASSETS


_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
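
For quick reference, a minimal sketch (not part of this PR) of how a test module is expected to consume the refactored assets; `some_vlm_config` is a hypothetical stand-in for the `VisionLanguageConfig` that the real tests pull from their model fixtures:

# Sketch only: exercises the asset API added in tests/conftest.py above.
from ..conftest import IMAGE_ASSETS

# Prompts are keyed by asset name and returned in asset order
# (stop_sign first, then cherry_blossom).
PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign": "USER: <image>\nWhat's in the image?\nASSISTANT:",
    "cherry_blossom": "USER: <image>\nWhat is the season?\nASSISTANT:",
})


def test_example(image_assets, some_vlm_config):
    hf_images = [asset.for_hf() for asset in image_assets]  # PIL images
    vllm_images = [asset.for_vllm(some_vlm_config) for asset in image_assets]  # MultiModalData
    assert len(hf_images) == len(vllm_images) == len(PROMPTS)
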
26 changes: 14 additions & 12 deletions tests/models/test_llava.py
@@ -5,17 +5,17 @@

from vllm.config import VisionLanguageConfig

from ..conftest import IMAGE_FILES
from ..conftest import IMAGE_ASSETS

pytestmark = pytest.mark.vlm

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
"cherry_blossom":
"<image>\nUSER: What is the season?\nASSISTANT:",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
})


def iter_llava_configs(model_name: str):
@@ -49,28 +49,28 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
input_ids, output_str = vllm_output
output_ids, output_str = vllm_output
image_token_id = vlm_config.image_token_id

tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)

hf_input_ids = [
input_id for idx, input_id in enumerate(input_ids)
if input_id != image_token_id or input_ids[idx - 1] != image_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, "")

return hf_input_ids, hf_output_str
return hf_output_ids, hf_output_str


# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
model_and_config, dtype: str, max_tokens: int) -> None:
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
"""Inference result should be the same between hf and vllm.

All the image fixtures for the test is under tests/images.
@@ -81,6 +81,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]

with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
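
To make the sanitization above concrete, here is a small self-contained worked example of the image-token de-duplication step; the helper name is ours and 32000 is the example id from the docstring, not values taken from this PR:

def dedup_image_tokens(output_ids, image_token_id=32000):
    # Keep a token unless it is an image token that merely repeats the
    # previous image token: 1, 32000, 32000, ..., x1, x2 -> 1, 32000, x1, x2.
    return [
        tok for idx, tok in enumerate(output_ids)
        if tok != image_token_id or output_ids[idx - 1] != image_token_id
    ]


assert dedup_image_tokens([1, 32000, 32000, 32000, 5, 6]) == [1, 32000, 5, 6]
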
30 changes: 16 additions & 14 deletions tests/models/test_llava_next.py
@@ -5,7 +5,7 @@

from vllm.config import VisionLanguageConfig

from ..conftest import IMAGE_FILES
from ..conftest import IMAGE_ASSETS

pytestmark = pytest.mark.vlm

@@ -15,12 +15,12 @@
"questions.")

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
f"{_PREFACE} <image>\nUSER: What's the content of the image? ASSISTANT:",
f"{_PREFACE} <image>\nUSER: What is the season? ASSISTANT:",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
f"{_PREFACE} <image>\nUSER: What's the content of the image?\nASSISTANT:",
"cherry_blossom":
f"{_PREFACE} <image>\nUSER: What is the season?\nASSISTANT:",
})


def iter_llava_next_configs(model_name: str):
@@ -56,20 +56,20 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
input_ids, output_str = vllm_output
output_ids, output_str = vllm_output
image_token_id = vlm_config.image_token_id

tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)

hf_input_ids = [
input_id for idx, input_id in enumerate(input_ids)
if input_id != image_token_id or input_ids[idx - 1] != image_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, " ")

return hf_input_ids, hf_output_str
return hf_output_ids, hf_output_str


@pytest.mark.xfail(
@@ -78,8 +78,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
model_and_config, dtype: str, max_tokens: int) -> None:
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
"""Inference result should be the same between hf and vllm.

All the image fixtures for the test is under tests/images.
@@ -90,6 +90,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]

with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
28 changes: 15 additions & 13 deletions tests/models/test_phi3v.py
@@ -6,17 +6,17 @@
from vllm.config import VisionLanguageConfig
from vllm.utils import is_cpu

from ..conftest import IMAGE_FILES
from ..conftest import IMAGE_ASSETS

pytestmark = pytest.mark.vlm

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
"cherry_blossom":
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", # noqa: E501
})


def iter_phi3v_configs(model_name: str):
@@ -50,22 +50,22 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
input_ids, output_str = vllm_output
output_ids, output_str = vllm_output
image_token_id = vlm_config.image_token_id

tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)

hf_input_ids = [
input_id if input_id != image_token_id else 0
for idx, input_id in enumerate(input_ids)
hf_output_ids = [
token_id if token_id != image_token_id else 0
for idx, token_id in enumerate(output_ids)
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, "") \
.replace("<s>", " ").replace("<|user|>", "") \
.replace("<|end|>\n<|assistant|>", " ")

return hf_input_ids, hf_output_str
return hf_output_ids, hf_output_str


target_dtype = "half"
@@ -82,8 +82,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
model_and_config, dtype: str, max_tokens: int) -> None:
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
"""Inference result should be the same between hf and vllm.

All the image fixtures for the test is under tests/images.
@@ -94,6 +94,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]

# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
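
The Phi-3-Vision variant above sanitizes differently: image tokens are zeroed out rather than collapsed, and chat-template markers are stripped from the string. A worked example of the string cleanup, using illustrative placeholder values rather than anything taken from this PR:

image_token_str = "<|image|>"  # illustrative; the real value is decoded from vlm_config.image_token_id
feature_size = 3               # illustrative; the real value is vlm_config.image_feature_size

raw = ("<s><|user|>" + image_token_str * feature_size +
       "What is the season?<|end|>\n<|assistant|>Spring.")
clean = (raw.replace(image_token_str * feature_size, "")
            .replace("<s>", " ").replace("<|user|>", "")
            .replace("<|end|>\n<|assistant|>", " "))
assert clean == " What is the season? Spring."
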