Skip to content

Commit

Permalink
[VLM] Fully dynamic prompt replacement in merged input processor (vll…
Browse files Browse the repository at this point in the history
…m-project#11199)

Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored and BKitor committed Dec 30, 2024
1 parent b767c7a commit fe97c18
Show file tree
Hide file tree
Showing 12 changed files with 569 additions and 510 deletions.
5 changes: 1 addition & 4 deletions examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str):
# max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.

# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.

# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
# to use 16 for single frame scenarios, and 4 for multi-frame.
Expand All @@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str):
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct",
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VProcessor
return Phi3VProcessor
from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
return Phi3VMultiModalProcessor


@pytest.fixture()
Expand Down
105 changes: 54 additions & 51 deletions tests/multimodal/test_processing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import cast

import pytest
from transformers import BatchFeature

from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
find_text_matches, find_token_matches,
iter_placeholders, iter_token_matches,
from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
_PlaceholderInfo, find_text_matches,
find_token_matches, iter_placeholders,
iter_token_matches,
replace_text_matches,
replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
Expand All @@ -16,7 +16,7 @@
@pytest.mark.parametrize(
("token_ids", "match_ids", "expected"),
[
([], [], [{ "start_idx": 0, "end_idx": 0 }]),
([], [], []),
([], [32000], []),
(
[32000, 32000, 32000],
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_2": [32000],
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_1": [],
"pattern_2": [],
}
),
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_repls)
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_repls)
Expand Down Expand Up @@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
"pattern_3": "!",
},
{
# Test whether target is confused with repl_unit
"pattern_1": ("<image><image>", 1),
# Test empty repl_unit
"pattern_2": ("", 1),
# Test multiple repl_count
"pattern_3": ("?", 2),
# Test whether target is confused with replacement
"pattern_1": "<image><image>",
# Test empty replacement
"pattern_2": "",
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
),
]
Expand All @@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>??"),
(2, "<image><image><image><image><image>??"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable
Expand All @@ -306,17 +306,16 @@ def test_find_replace_text(
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_text_matches(prompt, prompt_repls)

result = replace_text_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}),
)

# Only displayed on error
Expand All @@ -343,12 +342,12 @@ def test_find_replace_text(
"pattern_3": [918],
},
{
# Test whether target is confused with repl_unit
"pattern_1": ([32000, 32000], 1),
# Test empty repl_unit
"pattern_2": ([], 1),
# Test multiple repl_count
"pattern_3": ([1550], 2),
# Test whether target is confused with replacement
"pattern_1": [32000, 32000],
# Test empty replacement
"pattern_2": [],
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
),
]
Expand All @@ -357,8 +356,8 @@ def test_find_replace_text(
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable
Expand All @@ -373,17 +372,16 @@ def test_find_replace_tokens(
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_token_matches(prompt, prompt_repls)

result = replace_token_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}),
)

# Only displayed on error
Expand All @@ -399,9 +397,9 @@ def test_find_replace_tokens(
"repl_by_key",
[
{
"pattern_1": ([32000, 32000], 1),
"pattern_2": ([], 1),
"pattern_3": ([1550], 2),
"pattern_1": [32000, 32000],
"pattern_2": [],
"pattern_3": [1550, 918, 1550],
},
],
)
Expand All @@ -414,48 +412,47 @@ def test_find_replace_tokens(
_PlaceholderInfo(
modality="pattern_1",
start_idx=6,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
],
),
(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=5,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=7,
unit=[1550],
unit_count=2,
replacement=[1550, 918, 1550],
),
],
),
(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=2,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=3,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=6,
unit=[1550],
unit_count=2,
replacement=[1550, 918, 1550],
),
],
),
Expand All @@ -470,11 +467,17 @@ def test_iter_placeholders(
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement([], *repl).bind(key, mock_tokenizer)
PromptReplacement(key, [], repl).bind(mock_tokenizer)
for key, repl in repl_by_key.items()
]

result = list(iter_placeholders(prompt_repls, prompt))
result = list(
iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
))

# Only displayed on error
print("result:", result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import torch

from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaProcessor,
LlavaMultiModalProcessor,
get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration):

def compute_logits(
Expand Down
Loading

0 comments on commit fe97c18

Please sign in to comment.