[VLM] Fully dynamic prompt replacement in merged input processor #11199

Merged (14 commits) on Dec 14, 2024
examples/offline_inference_vision_language.py (5 changes: 1 addition & 4 deletions)
@@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str):
     # max_model_len (128k) for this model may cause OOM.
     # You may lower either to run this example on lower-end GPUs.

-    # In this example, we override max_num_seqs to 5 while
-    # keeping the original context length of 128k.
-
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
     # to use 16 for single frame scenarios, and 4 for multi-frame.
@@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str):
     # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
     # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
     llm = LLM(
-        model="microsoft/Phi-3-vision-128k-instruct",
+        model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
         max_model_len=4096,
         max_num_seqs=2,
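As the retained comments note, `num_crops` is forwarded to the model's multimodal image processor. Below is a minimal sketch of passing that override, assuming vLLM's `mm_processor_kwargs` engine argument (the full call is not shown in this hunk):

```python
from vllm import LLM

# Sketch only: mm_processor_kwargs forwards overrides such as num_crops
# to the HF image processor; 16 is the single-frame value recommended
# in the comments above.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    mm_processor_kwargs={"num_crops": 16},
)
```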
@@ -16,8 +16,8 @@
 # Wrap lazy imports to avoid initializing CUDA during test collection
 @pytest.fixture()
 def processor_for_phi3v():
-    from vllm.model_executor.models.phi3v import Phi3VProcessor
-    return Phi3VProcessor
+    from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
+    return Phi3VMultiModalProcessor


 @pytest.fixture()
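The fixture defers the vllm import so that collecting tests does not initialize CUDA, and its value is the processor class itself. An illustrative consumer, assuming nothing beyond the fixture shown above:

```python
# Illustrative only: pytest injects the fixture's return value, so the
# argument below is the Phi3VMultiModalProcessor class object.
def test_fixture_returns_processor_class(processor_for_phi3v):
    assert processor_for_phi3v.__name__ == "Phi3VMultiModalProcessor"
```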
tests/multimodal/test_processing.py (105 changes: 54 additions & 51 deletions)
@@ -1,11 +1,11 @@
 from typing import cast

 import pytest
-from transformers import BatchFeature

-from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
-                                        find_text_matches, find_token_matches,
-                                        iter_placeholders, iter_token_matches,
+from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
+                                        _PlaceholderInfo, find_text_matches,
+                                        find_token_matches, iter_placeholders,
+                                        iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
Expand All @@ -16,7 +16,7 @@
 @pytest.mark.parametrize(
     ("token_ids", "match_ids", "expected"),
     [
-        ([], [], [{ "start_idx": 0, "end_idx": 0 }]),
+        ([], [], []),
         ([], [32000], []),
         (
             [32000, 32000, 32000],
@@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_2": [32000],
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_1": [],
"pattern_2": [],
}
),
@@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
     mock_tokenizer = cast(AnyTokenizer, object())

     prompt_repls = [
-        PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
+        PromptReplacement(key, target, []).bind(mock_tokenizer)
         for key, target in target_by_key.items()
     ]
     result = find_token_matches(prompt, prompt_repls)
@@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
     mock_tokenizer = cast(AnyTokenizer, object())

     prompt_repls = [
-        PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
+        PromptReplacement(key, target, []).bind(mock_tokenizer)
         for key, target in target_by_key.items()
     ]
     result = find_text_matches(prompt, prompt_repls)
@@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
"pattern_3": "!",
},
{
# Test whether target is confused with repl_unit
"pattern_1": ("<image><image>", 1),
# Test empty repl_unit
"pattern_2": ("", 1),
# Test multiple repl_count
"pattern_3": ("?", 2),
# Test whether target is confused with replacement
"pattern_1": "<image><image>",
# Test empty replacement
"pattern_2": "",
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
),
]
@@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>??"),
(2, "<image><image><image><image><image>??"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable
@@ -306,17 +306,16 @@ def test_find_replace_text(
     mock_tokenizer = cast(AnyTokenizer, object())

     prompt_repls = [
-        PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
+        PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
         for key, target in target_by_key.items()
     ]
     matches = find_text_matches(prompt, prompt_repls)

     result = replace_text_matches(
         prompt,
         matches,
-        {key: list(range(mm_count))
-         for key in repl_by_key},
-        BatchFeature(),
+        MultiModalDataItems({key: [None] * mm_count
+                             for key in repl_by_key}),
     )

     # Only displayed on error
@@ -343,12 +342,12 @@ def test_find_replace_text(
"pattern_3": [918],
},
{
# Test whether target is confused with repl_unit
"pattern_1": ([32000, 32000], 1),
# Test empty repl_unit
"pattern_2": ([], 1),
# Test multiple repl_count
"pattern_3": ([1550], 2),
# Test whether target is confused with replacement
"pattern_1": [32000, 32000],
# Test empty replacement
"pattern_2": [],
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
),
]
@@ -357,8 +356,8 @@
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable
@@ -373,17 +372,16 @@ def test_find_replace_tokens(
     mock_tokenizer = cast(AnyTokenizer, object())

     prompt_repls = [
-        PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
+        PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
         for key, target in target_by_key.items()
     ]
     matches = find_token_matches(prompt, prompt_repls)

     result = replace_token_matches(
         prompt,
         matches,
-        {key: list(range(mm_count))
-         for key in repl_by_key},
-        BatchFeature(),
+        MultiModalDataItems({key: [None] * mm_count
+                             for key in repl_by_key}),
     )

     # Only displayed on error
@@ -399,9 +397,9 @@ def test_find_replace_tokens(
"repl_by_key",
[
{
"pattern_1": ([32000, 32000], 1),
"pattern_2": ([], 1),
"pattern_3": ([1550], 2),
"pattern_1": [32000, 32000],
"pattern_2": [],
"pattern_3": [1550, 918, 1550],
},
],
)
@@ -414,48 +412,47 @@
             _PlaceholderInfo(
                 modality="pattern_1",
                 start_idx=6,
-                unit=[32000, 32000],
-                unit_count=1,
+                replacement=[32000, 32000],
             ),
         ],
     ),
     (
-        [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
+        [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
         [
             _PlaceholderInfo(
                 modality="pattern_1",
                 start_idx=1,
-                unit=[32000, 32000],
-                unit_count=1,
+                replacement=[32000, 32000],
             ),
             _PlaceholderInfo(
                 modality="pattern_1",
                 start_idx=5,
-                unit=[32000, 32000],
-                unit_count=1,
+                replacement=[32000, 32000],
             ),
             _PlaceholderInfo(
                 modality="pattern_3",
                 start_idx=7,
-                unit=[1550],
-                unit_count=2,
+                replacement=[1550, 918, 1550],
             ),
         ],
     ),
     (
-        [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
+        [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
         [
             _PlaceholderInfo(
                 modality="pattern_1",
                 start_idx=1,
-                unit=[32000, 32000],
-                unit_count=2,
+                replacement=[32000, 32000],
+            ),
+            _PlaceholderInfo(
+                modality="pattern_1",
+                start_idx=3,
+                replacement=[32000, 32000],
             ),
             _PlaceholderInfo(
                 modality="pattern_3",
                 start_idx=6,
-                unit=[1550],
-                unit_count=2,
+                replacement=[1550, 918, 1550],
             ),
         ],
     ),
@@ -470,11 +467,17 @@ def test_iter_placeholders(
     mock_tokenizer = cast(AnyTokenizer, object())

     prompt_repls = [
-        PromptReplacement([], *repl).bind(key, mock_tokenizer)
+        PromptReplacement(key, [], repl).bind(mock_tokenizer)
         for key, repl in repl_by_key.items()
     ]

-    result = list(iter_placeholders(prompt_repls, prompt))
+    result = list(
+        iter_placeholders(
+            prompt_repls,
+            prompt,
+            # Effectively match all occurrences in the prompt
+            MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
+        ))

     # Only displayed on error
     print("result:", result)
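Taken together, these test updates show the new calling convention: the modality key moves into the PromptReplacement constructor, bind() takes only the tokenizer, and a replacement is an arbitrary text or token sequence applied once per multimodal item tracked by MultiModalDataItems. A minimal sketch of the flow, mirroring the mock-tokenizer pattern in the tests (the prompt string and "image" items are illustrative, not part of this diff):

```python
from typing import cast

from vllm.multimodal.processing import (MultiModalDataItems,
                                        PromptReplacement,
                                        find_text_matches,
                                        replace_text_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer

# Text-only matching never calls the tokenizer, so a bare object suffices.
mock_tokenizer = cast(AnyTokenizer, object())

# New signature: PromptReplacement(modality, target, replacement);
# bind() now takes only the tokenizer.
prompt_repls = [
    PromptReplacement("image", "<image>", "<image><image>").bind(mock_tokenizer),
]

prompt = "Image:<image>"
matches = find_text_matches(prompt, prompt_repls)
result = replace_text_matches(
    prompt,
    matches,
    # One entry per multimodal item; replaces the old dict + BatchFeature args.
    MultiModalDataItems({"image": [None]}),
)
print(result)  # Expected per the test semantics above: "Image:<image><image>"
```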
@@ -3,14 +3,14 @@
 import torch

 from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
-                                              LlavaProcessor,
+                                              LlavaMultiModalProcessor,
                                               get_max_llava_image_tokens)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY


 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
-@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
+@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
 class MyLlava(LlavaForConditionalGeneration):

     def compute_logits(
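This plugin model keeps working after the rename by pointing register_processor at LlavaMultiModalProcessor. For reference, a sketch of how such an out-of-tree class is typically exposed to vLLM, assuming the standard ModelRegistry hook (the module path "my_llava" and the register() entry point are illustrative):

```python
from vllm import ModelRegistry


def register():
    # Lazy import so loading the plugin does not initialize CUDA.
    from my_llava import MyLlava

    # Map the architecture name to the out-of-tree implementation.
    ModelRegistry.register_model("MyLlava", MyLlava)
```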