From 93abf23a648051fe6dc053ba0b74499d119920bf Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 15 Dec 2024 01:52:18 +0800 Subject: [PATCH] [VLM] Fully dynamic prompt replacement in merged input processor (#11199) Signed-off-by: DarkLight1337 --- examples/offline_inference_vision_language.py | 5 +- .../mm_processor_kwargs/test_phi3v.py | 4 +- tests/multimodal/test_processing.py | 105 +-- .../vllm_add_dummy_model/my_llava.py | 4 +- vllm/inputs/registry.py | 71 +- vllm/model_executor/models/llava.py | 144 ++--- vllm/model_executor/models/phi3v.py | 118 ++-- vllm/model_executor/models/pixtral.py | 2 +- vllm/multimodal/base.py | 4 +- vllm/multimodal/processing.py | 606 +++++++++--------- vllm/multimodal/registry.py | 4 +- vllm/utils.py | 12 +- 12 files changed, 569 insertions(+), 510 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index c430f42fdc814..45539c665a922 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str): # max_model_len (128k) for this model may cause OOM. # You may lower either to run this example on lower-end GPUs. - # In this example, we override max_num_seqs to 5 while - # keeping the original context length of 128k. - # num_crops is an override kwarg to the multimodal image processor; # For some models, e.g., Phi-3.5-vision-instruct, it is recommended # to use 16 for single frame scenarios, and 4 for multi-frame. @@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str): # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( - model="microsoft/Phi-3-vision-128k-instruct", + model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, max_num_seqs=2, diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py index c16192a1e1438..ce8ac8d8e0ceb 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py @@ -16,8 +16,8 @@ # Wrap lazy imports to avoid initializing CUDA during test collection @pytest.fixture() def processor_for_phi3v(): - from vllm.model_executor.models.phi3v import Phi3VProcessor - return Phi3VProcessor + from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor + return Phi3VMultiModalProcessor @pytest.fixture() diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index ae668d1dd56c8..6aaa80ddc9fa5 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -1,11 +1,11 @@ from typing import cast import pytest -from transformers import BatchFeature -from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, - find_text_matches, find_token_matches, - iter_placeholders, iter_token_matches, +from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement, + _PlaceholderInfo, find_text_matches, + find_token_matches, iter_placeholders, + iter_token_matches, replace_text_matches, replace_token_matches) from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -16,7 +16,7 @@ @pytest.mark.parametrize( ("token_ids", "match_ids", "expected"), [ - ([], [], [{ "start_idx": 0, "end_idx": 0 
}]), + ([], [], []), ([], [32000], []), ( [32000, 32000, 32000], @@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected): "pattern_2": [32000], }, { - "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], + "pattern_1": [], "pattern_2": [], } ), @@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): mock_tokenizer = cast(AnyTokenizer, object()) prompt_repls = [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + PromptReplacement(key, target, []).bind(mock_tokenizer) for key, target in target_by_key.items() ] result = find_token_matches(prompt, prompt_repls) @@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): mock_tokenizer = cast(AnyTokenizer, object()) prompt_repls = [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + PromptReplacement(key, target, []).bind(mock_tokenizer) for key, target in target_by_key.items() ] result = find_text_matches(prompt, prompt_repls) @@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): "pattern_3": "!", }, { - # Test whether target is confused with repl_unit - "pattern_1": ("", 1), - # Test empty repl_unit - "pattern_2": ("", 1), - # Test multiple repl_count - "pattern_3": ("?", 2), + # Test whether target is confused with replacement + "pattern_1": "", + # Test empty replacement + "pattern_2": "", + # Test dynamic replacement (beyond the form of `unit * count`) + "pattern_3": "?!?", }, ), ] @@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ("mm_count", "expected"), [ (0, "Image:Image:!"), - (1, "Image:??"), - (2, "??"), + (1, "Image:?!?"), + (2, "?!?"), ] ) # yapf: enable @@ -306,7 +306,7 @@ def test_find_replace_text( mock_tokenizer = cast(AnyTokenizer, object()) prompt_repls = [ - PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer) for key, target in target_by_key.items() ] matches = find_text_matches(prompt, prompt_repls) @@ -314,9 +314,8 @@ def test_find_replace_text( result = replace_text_matches( prompt, matches, - {key: list(range(mm_count)) - for key in repl_by_key}, - BatchFeature(), + MultiModalDataItems({key: [None] * mm_count + for key in repl_by_key}), ) # Only displayed on error @@ -343,12 +342,12 @@ def test_find_replace_text( "pattern_3": [918], }, { - # Test whether target is confused with repl_unit - "pattern_1": ([32000, 32000], 1), - # Test empty repl_unit - "pattern_2": ([], 1), - # Test multiple repl_count - "pattern_3": ([1550], 2), + # Test whether target is confused with replacement + "pattern_1": [32000, 32000], + # Test empty replacement + "pattern_2": [], + # Test dynamic replacement (beyond the form of `unit * count`) + "pattern_3": [1550, 918, 1550], }, ), ] @@ -357,8 +356,8 @@ def test_find_replace_text( ("mm_count", "expected"), [ (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), - (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]), - (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]), + (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]), + (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]), ] ) # yapf: enable @@ -373,7 +372,7 @@ def test_find_replace_tokens( mock_tokenizer = cast(AnyTokenizer, object()) prompt_repls = [ - PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer) for key, target in 
target_by_key.items() ] matches = find_token_matches(prompt, prompt_repls) @@ -381,9 +380,8 @@ def test_find_replace_tokens( result = replace_token_matches( prompt, matches, - {key: list(range(mm_count)) - for key in repl_by_key}, - BatchFeature(), + MultiModalDataItems({key: [None] * mm_count + for key in repl_by_key}), ) # Only displayed on error @@ -399,9 +397,9 @@ def test_find_replace_tokens( "repl_by_key", [ { - "pattern_1": ([32000, 32000], 1), - "pattern_2": ([], 1), - "pattern_3": ([1550], 2), + "pattern_1": [32000, 32000], + "pattern_2": [], + "pattern_3": [1550, 918, 1550], }, ], ) @@ -414,48 +412,47 @@ def test_find_replace_tokens( _PlaceholderInfo( modality="pattern_1", start_idx=6, - unit=[32000, 32000], - unit_count=1, + replacement=[32000, 32000], ), ], ), ( - [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550], + [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], [ _PlaceholderInfo( modality="pattern_1", start_idx=1, - unit=[32000, 32000], - unit_count=1, + replacement=[32000, 32000], ), _PlaceholderInfo( modality="pattern_1", start_idx=5, - unit=[32000, 32000], - unit_count=1, + replacement=[32000, 32000], ), _PlaceholderInfo( modality="pattern_3", start_idx=7, - unit=[1550], - unit_count=2, + replacement=[1550, 918, 1550], ), ], ), ( - [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550], + [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], [ _PlaceholderInfo( modality="pattern_1", start_idx=1, - unit=[32000, 32000], - unit_count=2, + replacement=[32000, 32000], + ), + _PlaceholderInfo( + modality="pattern_1", + start_idx=3, + replacement=[32000, 32000], ), _PlaceholderInfo( modality="pattern_3", start_idx=6, - unit=[1550], - unit_count=2, + replacement=[1550, 918, 1550], ), ], ), @@ -470,11 +467,17 @@ def test_iter_placeholders( mock_tokenizer = cast(AnyTokenizer, object()) prompt_repls = [ - PromptReplacement([], *repl).bind(key, mock_tokenizer) + PromptReplacement(key, [], repl).bind(mock_tokenizer) for key, repl in repl_by_key.items() ] - result = list(iter_placeholders(prompt_repls, prompt)) + result = list( + iter_placeholders( + prompt_repls, + prompt, + # Effectively match all occurrences in the prompt + MultiModalDataItems({key: [None] * 3 for key in repl_by_key}), + )) # Only displayed on error print("result:", result) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 2f4194a63fc25..0d90635093ac7 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -3,14 +3,14 @@ import torch from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - LlavaProcessor, + LlavaMultiModalProcessor, get_max_llava_image_tokens) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) +@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 0dfed3b7e61bf..0b85484c48714 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -2,7 +2,7 @@ from collections import UserDict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, - Optional, 
Protocol, Type, cast) + Optional, Protocol, Type) from torch import nn from transformers import PretrainedConfig, ProcessorMixin @@ -47,7 +47,6 @@ def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C: Raises: TypeError: If the model is not of the specified type. """ - hf_config = self.model_config.hf_config if not isinstance(hf_config, hf_config_type): raise TypeError("Invalid type of HuggingFace config. " @@ -60,21 +59,70 @@ def get_hf_image_processor_config(self) -> Dict[str, Any]: """ Get the HuggingFace image processor configuration of the model. """ - return self.model_config.hf_image_processor_config + def get_mm_config(self): + """ + Get the multimodal config of the model. + + Raises: + RuntimeError: If the model is not a multimodal model. + """ + mm_config = self.model_config.multimodal_config + if mm_config is None: + raise RuntimeError("Not a multimodal model") + + return mm_config + + def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: + base_kwargs = self.model_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = {**base_kwargs, **kwargs} + + return cached_get_processor( + self.model_config.model, + trust_remote_code=self.model_config.trust_remote_code, + **merged_kwargs, + ) + @dataclass(frozen=True) class InputProcessingContext(InputContext): tokenizer: AnyTokenizer """The tokenizer used to tokenize the inputs.""" - def get_hf_processor(self, **kwargs) -> ProcessorMixin: + def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: + base_kwargs = self.model_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = {**base_kwargs, **kwargs} + return cached_get_processor( - self.model_config.tokenizer, + self.model_config.model, tokenizer=self.tokenizer, # Override the tokenizer with ours trust_remote_code=self.model_config.trust_remote_code, - **kwargs) + **merged_kwargs, + ) + + def resolve_hf_processor_call_kwargs( + self, + hf_processor: ProcessorMixin, + inference_kwargs: Mapping[str, object], + ) -> Mapping[str, object]: + assert callable(hf_processor) + + base_kwargs = self.model_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + return resolve_mm_processor_kwargs( + base_kwargs, + inference_kwargs, + hf_processor, + ) N = TypeVar("N", bound=Type[nn.Module]) @@ -171,7 +219,8 @@ def register_dummy_data(self, factory: DummyDataFactory): """ def wrapper(model_cls: N) -> N: - if model_cls in self._dummy_factories_by_model_type: + if self._dummy_factories_by_model_type.contains(model_cls, + strict=True): logger.warning( "Model class %s already has dummy data " "registered to %s. It is overwritten by the new one.", @@ -195,7 +244,8 @@ def register_dummy_encoder_data(self, factory: DummyDataFactory): """ def wrapper(model_cls: N) -> N: - if model_cls in self._dummy_encoder_factories_by_model_type: + if self._dummy_encoder_factories_by_model_type.contains( + model_cls, strict=True): logger.warning( "Model class %s already has dummy encoder data " "registered to %s. It is overwritten by the new one.", @@ -305,7 +355,8 @@ def register_input_processor(self, processor: InputProcessor): """ def wrapper(model_cls: N) -> N: - if model_cls in self._input_processors_by_model_type: + if self._input_processors_by_model_type.contains(model_cls, + strict=True): logger.warning( "Model class %s already has input processor " "registered to %s. 
It is overwritten by the new one.", @@ -357,7 +408,7 @@ def process_input(self, model_config: "ModelConfig", # If it's empty, it'll fall back to the default kwarg values mm_processor_kwargs = resolve_mm_processor_kwargs( model_config.mm_processor_kwargs, - cast(Dict[str, Any], inputs.get("mm_processor_kwargs")), + inputs.get("mm_processor_kwargs", {}), # type: ignore processor, ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 53eef72dd5f91..a2e404cf43238 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -5,10 +5,10 @@ import torch import torch.nn as nn -from PIL.Image import Image from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, PretrainedConfig, ProcessorMixin, SiglipVisionConfig) +from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor from vllm.attention import AttentionMetadata @@ -21,11 +21,9 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.processing import (BaseMultiModalProcessor, - InputProcessingContext, - ModalityProcessingMetadata, - MultiModalProcessingMetadata, + MultiModalDataItems, ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors @@ -33,7 +31,8 @@ get_max_clip_image_tokens) from .interfaces import SupportsMultiModal, SupportsPP from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, - get_max_pixtral_hf_image_tokens) + get_max_pixtral_hf_image_tokens, + get_pixtral_hf_image_feature_size) from .siglip import (SiglipVisionModel, dummy_image_for_siglip, get_max_siglip_image_tokens) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, @@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext): raise ValueError(f"Unexpected select feature strategy: {strategy}") -def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config - num_images = mm_counts["image"] - - if isinstance(vision_config, CLIPVisionConfig): - data = dummy_image_for_clip(vision_config, num_images) - elif isinstance(vision_config, SiglipVisionConfig): - data = dummy_image_for_siglip(vision_config, num_images) - elif isinstance(vision_config, PixtralVisionConfig): - data = dummy_image_for_pixtral_hf(vision_config, num_images) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - - hf_processor = ctx.get_hf_processor() - image_processor = hf_processor.image_processor # type: ignore - hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt") - is_pixtral = isinstance(hf_processor, PixtralProcessor) - - return MultiModalKwargs( - **hf_inputs, - is_pixtral=torch.tensor(is_pixtral), - ) - - -def create_metadata_for_llava( - ctx: InputProcessingContext) -> MultiModalProcessingMetadata: - hf_config = ctx.get_hf_config(LlavaConfig) - image_token_id = hf_config.image_token_index - - def get_repl_count( - mm_items: list[Image], - hf_inputs: BatchFeature, - item_idx: int, - ) -> int: - return get_max_llava_image_tokens(ctx) - - return { - "image": - ModalityProcessingMetadata(prompt_repls=[ - 
PromptReplacement(target=[image_token_id], - repl_unit=[image_token_id], - repl_count=get_repl_count), - ]), - } - - -class LlavaProcessor(BaseMultiModalProcessor): - - def __init__(self, ctx: InputProcessingContext) -> None: - super().__init__( - ctx=ctx, - metadata=create_metadata_for_llava(ctx), - ) +class LlavaMultiModalProcessor(BaseMultiModalProcessor): def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): if getattr(hf_processor, "__is_patched__", False): @@ -188,18 +132,72 @@ def preprocess(__self, *args, **kwargs): hf_processor.__is_patched__ = True # type: ignore - def _get_hf_processor(self) -> ProcessorMixin: + def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: hf_processor = self.ctx.get_hf_processor() + assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor)) if isinstance(hf_processor, PixtralProcessor): self._patch_pixtral_processor(hf_processor) return hf_processor - def _get_dummy_mm_kwargs( + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_inputs: BatchFeature, + mm_processor_kwargs: Mapping[str, object], + ) -> list[PromptReplacement]: + hf_config = self.ctx.get_hf_config(LlavaConfig) + image_token_id = hf_config.image_token_index + + processor = self._get_hf_processor() + if isinstance(processor, PixtralProcessor): + image_token = processor.image_token + image_break_token = processor.image_break_token + image_end_token = processor.image_end_token + + vision_config = hf_config.vision_config + assert isinstance(vision_config, PixtralVisionConfig) + + def get_replacement_pixtral(item_idx: int): + image_size = mm_items.get_image_size(item_idx) + ( + num_width_tokens, + num_height_tokens, + ) = get_pixtral_hf_image_feature_size( + vision_config, + image_width=image_size.width, + image_height=image_size.height, + ) + + tokens = ([image_token] * num_width_tokens + + [image_break_token]) * num_height_tokens + tokens[-1] = image_end_token + + return "".join(tokens) + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_pixtral, + ), + ] + + max_image_tokens = get_max_llava_image_tokens(self.ctx) + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=[image_token_id] * max_image_tokens, + ) + ] + + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], - ) -> MultiModalKwargs: + ) -> ProcessorInputs: hf_config = self.ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config num_images = mm_counts["image"] @@ -215,11 +213,13 @@ def _get_dummy_mm_kwargs( raise NotImplementedError(msg) hf_processor = self._get_hf_processor() - image_processor = hf_processor.image_processor # type: ignore - hf_inputs = image_processor.preprocess(data['image'], - return_tensors="pt") + image_token = hf_processor.image_token - return MultiModalKwargs(**hf_inputs) + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=data, + mm_processor_kwargs={}, + ) class LlavaLikeConfig(Protocol): @@ -303,7 +303,7 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) +@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -584,7 +584,7 @@ def load_weights(self, weights: Iterable[Tuple[str, return loader.load_weights(weights) 
-class MantisProcessor(LlavaProcessor): +class MantisMultiModalProcessor(LlavaMultiModalProcessor): def _get_hf_processor(self) -> ProcessorMixin: try: @@ -604,6 +604,6 @@ def _get_hf_processor(self) -> ProcessorMixin: # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor(MantisProcessor) +@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 3c7854ce388ab..7ab06768ae612 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -32,13 +32,10 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.processing import (BaseMultiModalProcessor, - InputProcessingContext, - ModalityProcessingMetadata, MultiModalDataDict, - MultiModalProcessingMetadata, + MultiModalDataItems, ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -305,64 +302,17 @@ def add_image_newline(self, image_features_hd): return image_features_hd_newline -def get_max_phi3v_image_tokens(ctx: InputContext, - *, - num_crops: Optional[int] = None): - mm_processor_kwargs = {} - if num_crops is not None: - mm_processor_kwargs["num_crops"] = num_crops +def get_max_phi3v_image_tokens(ctx: InputContext) -> int: + processor = ctx.get_hf_processor() + image_processor = processor.image_processor # type: ignore - model_config = ctx.model_config - image_processor = cached_get_image_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs, - ) - - num_tokens = image_processor.calc_num_image_tokens_from_image_size( + return image_processor.calc_num_image_tokens_from_image_size( width=MAX_IMAGE_FEATURE_SIZE_WIDTH, height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return num_tokens - - -def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext, - mm_counts: Mapping[str, int]): - num_images = mm_counts["image"] - - data = dummy_image_for_clip( - CLIP_VIT_LARGE_PATCH14_336_CONFIG, - num_images, - image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, - image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - ) - - hf_processor = ctx.get_hf_processor() - image_processor = hf_processor.image_processor # type: ignore - hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt") - - return MultiModalKwargs(**hf_inputs) - - -def create_metadata_for_phi3v( - ctx: InputProcessingContext) -> MultiModalProcessingMetadata: - return { - "image": - ModalityProcessingMetadata(prompt_repls=[ - PromptReplacement(target=[_IMAGE_TOKEN_ID], - repl_unit=[_IMAGE_TOKEN_ID], - repl_count=get_max_phi3v_image_tokens(ctx)), - ]), - } - -class Phi3VProcessor(BaseMultiModalProcessor): - def __init__(self, ctx: InputProcessingContext) -> None: - super().__init__( - ctx=ctx, - metadata=create_metadata_for_phi3v(ctx), - ) +class Phi3VMultiModalProcessor(BaseMultiModalProcessor): def _get_hf_processor( self, @@ -389,15 +339,61 @@ def _apply_hf_processor( 
processed_outputs['input_ids'] = token_ids return processed_outputs - def _get_dummy_mm_kwargs( + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_inputs: BatchFeature, + mm_processor_kwargs: Mapping[str, object], + ) -> list[PromptReplacement]: + hf_processor = self._get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + image_processor = hf_processor.image_processor # type: ignore + + mm_config = self.ctx.get_mm_config() + max_images = mm_config.limit_per_prompt.get("image", 1) + + def get_replacement_phi3v(item_idx: int): + image_size = mm_items.get_image_size(item_idx) + num_tokens = image_processor.calc_num_image_tokens_from_image_size( + width=image_size.width, + height=image_size.height, + ) + + return [_IMAGE_TOKEN_ID] * num_tokens + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_phi3v, + ) for image_token in image_tokens[:max_images] + ] + + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], - ) -> MultiModalKwargs: - return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts) + ) -> ProcessorInputs: + num_images = mm_counts["image"] + + data = dummy_image_for_clip( + CLIP_VIT_LARGE_PATCH14_336_CONFIG, + num_images, + image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + ) + + hf_processor = self._get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + + return ProcessorInputs( + prompt_text="".join(image_tokens[:num_images]), + mm_data=data, + mm_processor_kwargs={}, + ) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) -@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor) +@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 94a4ab882c1a9..161d6b41bfa5f 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -72,7 +72,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder image_token_id = mm_encoder.special_ids.img - mm_config = ctx.model_config.multimodal_config + mm_config = ctx.get_mm_config() num_images = mm_config.limit_per_prompt.get("image", 1) # dummy size diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 7dba94b885b6d..fe77a4635f7d8 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -99,7 +99,7 @@ def register_input_mapper( """ def wrapper(model_cls: N) -> N: - if model_cls in self._input_mappers: + if self._input_mappers.contains(model_cls, strict=True): logger.warning( "Model class %s already has an input mapper " "registered to %s. It is overwritten by the new one.", @@ -194,7 +194,7 @@ def register_max_multimodal_tokens( """ def wrapper(model_cls: N) -> N: - if model_cls in self._max_mm_tokens: + if self._max_mm_tokens.contains(model_cls, strict=True): logger.warning( "Model class %s already calculates maximum number of " "tokens in %s. 
It is overwritten by the new one.", diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 922c83b6fd8a9..de5a002d474c2 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,116 +1,59 @@ import re from abc import ABC, abstractmethod +from collections import UserDict from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import lru_cache -from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol, - TypeVar, Union, cast) +from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union +import numpy as np import torch +from PIL.Image import Image from transformers import BatchFeature, ProcessorMixin -from typing_extensions import TypeAlias, TypedDict +from typing_extensions import assert_never from vllm.inputs import DummyData, InputProcessingContext +from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of, - resolve_mm_processor_kwargs) +from vllm.utils import flatten_2d_lists, full_groupby, is_list_of from .inputs import (AudioItem, ImageItem, MultiModalDataDict, MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, VideoItem) +logger = init_logger(__name__) -def bind_prompt_sequence( - seq: Union[str, list[int]], - tokenizer: AnyTokenizer, -) -> "_BoundPromptSequence": - """ - Bind a text or token sequence to a tokenizer so that it can be - lazily converted into the other format on demand. - """ - return _BoundPromptSequence( - tokenizer=tokenizer, - _text=seq if isinstance(seq, str) else None, - _token_ids=seq if isinstance(seq, list) else None, - ) - - -_T = TypeVar("_T") _S = TypeVar("_S", str, list[int]) +_PromptSeq = Union[str, list[int]] @dataclass -class PromptReplacement(Generic[_S, _T]): - target: _S - """The text or token sequence to find and replace.""" +class PromptReplacement: + modality: str + """The modality for which the replacement is made""" - repl_unit: _S - """ - The unit making up the replacement text or token sequence. - - See :code:`repl_count` for more details. - """ + target: _PromptSeq + """The text or token sequence to find and replace.""" - repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int] + replacement: Union[Callable[[int], _PromptSeq], + _PromptSeq] = field(repr=False) """ - Given the original multi-modal items for this modality, HF-processed data, - and index of the processed item, output the number of repetitions of - :code:`repl_unit` to build up the replacement text or token sequence. + Given the index of the processed item within :attr:`modality`, output the + replacement text or token sequence. - For convenience, you can pass in an integer if the number of repetitions is - a constant. + For convenience, you can pass in the replacement instead of a function + if it does not depend on the input. 
""" - def __repr__(self) -> str: - return (f"{type(self).__name__}(target={self.target!r}, " - f"repl_unit={self.repl_unit!r})") - - def bind( - self, - modality: str, - tokenizer: AnyTokenizer, - ) -> "_BoundPromptReplacement[_T]": + def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement": return _BoundPromptReplacement( - modality=modality, - target=bind_prompt_sequence(self.target, tokenizer), - repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer), - repl_count=self.repl_count, + tokenizer=tokenizer, + modality=self.modality, + _target=self.target, + _replacement=self.replacement, ) -@dataclass -class ModalityProcessingMetadata(Generic[_T]): - prompt_repls: Sequence[Union[PromptReplacement[str, _T], - PromptReplacement[list[int], _T]]] - """ - Defines each text or token sequence to replace in the HF-processed prompt. - - This is skipped if the HF-processed prompt is found to already contain - the replacement prompts. - """ - - -class MultiModalProcessingMetadataBuiltins(TypedDict, total=False): - """Type annotations for modality types predefined by vLLM.""" - - image: ModalityProcessingMetadata[ImageItem] - video: ModalityProcessingMetadata[VideoItem] - audio: ModalityProcessingMetadata[AudioItem] - - -MultiModalProcessingMetadata: TypeAlias = \ - Mapping[str, ModalityProcessingMetadata[Any]] -""" -A dictionary containing an entry for each modality type to process. - -Note: - This dictionary also accepts modality keys defined outside - :class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin - is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. - Read more on that :ref:`here `. -""" - - def _encode( tokenizer: AnyTokenizer, text: str, @@ -185,7 +128,8 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: @dataclass class _BoundPromptSequence: - tokenizer: AnyTokenizer + tokenizer: AnyTokenizer = field(repr=False) + _text: Optional[str] _token_ids: Optional[list[int]] @@ -210,38 +154,92 @@ def token_ids(self) -> list[int]: return self._token_ids - def __repr__(self) -> str: - return (f"{type(self).__name__}(_text={self._text!r}, " - f"_token_ids={self._token_ids!r})") - @dataclass -class _BoundPromptReplacement(Generic[_T]): +class _BoundPromptReplacement: + tokenizer: AnyTokenizer = field(repr=False) modality: str - target: _BoundPromptSequence - repl_unit: _BoundPromptSequence - repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int] - def get_count( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> int: - repl_count = self.repl_count - if isinstance(repl_count, int): - return repl_count + _target: _PromptSeq + _replacement: Union[Callable[[int], _PromptSeq], + _PromptSeq] = field(repr=False) - return repl_count(mm_items, hf_inputs, item_idx) + def __post_init__(self) -> None: + self._replacement_cache = dict[int, _BoundPromptSequence]() + + @property + def target(self) -> _BoundPromptSequence: + target = self._target + return _BoundPromptSequence( + tokenizer=self.tokenizer, + _text=target if isinstance(target, str) else None, + _token_ids=target if isinstance(target, list) else None, + ) -def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]: + def get_replacement(self, item_idx: int) -> _BoundPromptSequence: + replacement = self._replacement + if callable(replacement): + cache_key = item_idx + if cache_key in self._replacement_cache: + return self._replacement_cache[cache_key] + + replacement = replacement(item_idx) + else: + cache_key = 
None + + bound_replacement = _BoundPromptSequence( + tokenizer=self.tokenizer, + _text=replacement if isinstance(replacement, str) else None, + _token_ids=replacement if isinstance(replacement, list) else None, + ) + + if cache_key is not None: + self._replacement_cache[cache_key] = bound_replacement + + return bound_replacement + + +class ImageSize(NamedTuple): + width: int + height: int + + +class MultiModalDataItems(UserDict[str, list[Any]]): """ - Convert a :class:`MultiModalDataDict` containing single data items - to a :class:`MultiModalMultiDataDict` containing multiple data items - per entry. + As :class:`MultiModalDataDict`, but normalized such that each entry + corresponds to a list. """ - multi_data = dict[str, list[Any]]() + + @property + def image(self) -> list[ImageItem]: + return self["image"] + + @property + def video(self) -> list[VideoItem]: + return self["video"] + + @property + def audio(self) -> list[AudioItem]: + return self["audio"] + + def get_image_size(self, item_idx: int) -> ImageSize: + image = self.image[item_idx] + + if isinstance(image, Image): + return ImageSize(*image.size) + if isinstance(image, (np.ndarray, torch.Tensor)): + _, h, w = image.shape + return ImageSize(w, h) + + assert_never(image) + + +def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems: + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. + """ + multi_data = MultiModalDataItems() for k, v in data.items(): # yapf: disable @@ -266,22 +264,33 @@ def iter_token_matches( token_ids: list[int], match_ids: list[int], ) -> Iterable[_TokenMatch]: - """Yield each occurrence of :code:`match_ids` in :code:`token_ids`.""" + """ + Yield each occurrence of :code:`match_ids` in :code:`token_ids`. + + Note that empty matches are ignored. 
+ """ + prompt_len = len(token_ids) match_len = len(match_ids) - last_end_idx = 0 - for start_idx in range(len(token_ids) - match_len + 1): - if start_idx < last_end_idx: - continue # Exclude overlapping matches + if match_len == 0: + return + start_idx = 0 + while start_idx < prompt_len - match_len + 1: end_idx = start_idx + match_len + if token_ids[start_idx:end_idx] == match_ids: yield _TokenMatch(start_idx=start_idx, end_idx=end_idx) - last_end_idx = end_idx + + # Exclude overlapping matches + start_idx = end_idx + else: + start_idx += 1 -class _PromptReplacementMatch(ABC, Generic[_T, _S]): - prompt_repl: _BoundPromptReplacement[_T] +@dataclass(repr=False) +class _PromptReplacementMatch(ABC): + prompt_repl: _BoundPromptReplacement @property def modality(self) -> str: @@ -297,19 +306,13 @@ def start_idx(self) -> int: def end_idx(self) -> int: raise NotImplementedError - @property - @abstractmethod - def repl_unit(self) -> _S: - raise NotImplementedError - def __repr__(self) -> str: return (f"{type(self).__name__}(modality={self.modality!r}, " f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") @dataclass(repr=False) -class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]): - prompt_repl: _BoundPromptReplacement[_T] +class _PromptReplacementTokenMatch(_PromptReplacementMatch): match: _TokenMatch @property @@ -320,14 +323,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end_idx - @property - def repl_unit(self) -> list[int]: - return self.prompt_repl.repl_unit.token_ids - @dataclass(repr=False) -class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]): - prompt_repl: _BoundPromptReplacement[_T] +class _PromptReplacementTextMatch(_PromptReplacementMatch): match: re.Match[str] @property @@ -338,20 +336,15 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end() - @property - def repl_unit(self) -> str: - return self.prompt_repl.repl_unit.text - class _PlaceholderInfo(NamedTuple): modality: str start_idx: int - unit: list[int] - unit_count: int + replacement: list[int] @property def length(self) -> int: - return len(self.unit) * self.unit_count + return len(self.replacement) def to_range(self) -> PlaceholderRange: return PlaceholderRange( @@ -362,8 +355,8 @@ def to_range(self) -> PlaceholderRange: def find_token_matches( prompt: list[int], - prompt_repls: Sequence[_BoundPromptReplacement[_T]], -) -> list[_PromptReplacementTokenMatch[_T]]: + prompt_repls: Sequence[_BoundPromptReplacement], +) -> list[_PromptReplacementTokenMatch]: """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" return [ _PromptReplacementTokenMatch(prompt_repl, match) @@ -374,8 +367,8 @@ def find_token_matches( def find_text_matches( prompt: str, - prompt_repls: Sequence[_BoundPromptReplacement[_T]], -) -> list[_PromptReplacementTextMatch[_T]]: + prompt_repls: Sequence[_BoundPromptReplacement], +) -> list[_PromptReplacementTextMatch]: """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" return [ _PromptReplacementTextMatch(prompt_repl, match) @@ -385,15 +378,15 @@ def find_text_matches( def _resolve_matches( - prompt: _S, - matches: Sequence[_PromptReplacementMatch[_T, _S]], -) -> list[_PromptReplacementMatch[_T, _S]]: + prompt: _PromptSeq, + matches: Sequence[_PromptReplacementMatch], +) -> list[_PromptReplacementMatch]: """ Resolve :code:`matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. 
""" - seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \ - = [None] * len(prompt) + seen_matches: list[Optional[_PromptReplacementMatch]] = [None + ] * len(prompt) for match in matches: for idx in range(match.start_idx, match.end_idx): @@ -409,30 +402,34 @@ def _resolve_matches( def _replace_matches( prompt: _S, - matches: Sequence[_PromptReplacementMatch[_T, _S]], - mm_items_by_modality: Mapping[str, list[_T]], - hf_inputs: BatchFeature, + matches: Sequence[_PromptReplacementMatch], + mm_items: MultiModalDataItems, ) -> list[_S]: out_seqs = list[_S]() prev_end_idx = 0 - next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality} + next_idx_by_modality = {modality: 0 for modality in mm_items} for match in _resolve_matches(prompt, matches): modality = match.modality - mm_items = mm_items_by_modality[modality] + modal_items = mm_items[modality] item_idx = next_idx_by_modality[modality] - if item_idx >= len(mm_items): + if item_idx >= len(modal_items): continue start_idx = match.start_idx end_idx = match.end_idx - repl_unit = match.repl_unit + repl_info = match.prompt_repl - repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx) + replacement = repl_info.get_replacement(item_idx) + + if isinstance(prompt, str): + repl_seq = replacement.text + out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) + else: + repl_seq = replacement.token_ids + out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) - out_seqs.append(prompt[prev_end_idx:start_idx] + - repl_unit * repl_count) prev_end_idx = end_idx next_idx_by_modality[modality] += 1 @@ -443,92 +440,104 @@ def _replace_matches( def replace_token_matches( prompt: list[int], - matches: Sequence[_PromptReplacementMatch[_T, list[int]]], - mm_items_by_modality: Mapping[str, list[_T]], - hf_inputs: BatchFeature, + matches: Sequence[_PromptReplacementTokenMatch], + mm_items: MultiModalDataItems, ) -> list[int]: """Apply :code:`prompt_repls` to :code:`prompt`.""" if not matches: return prompt - token_id_seqs = _replace_matches( - prompt, - matches, - mm_items_by_modality, - hf_inputs, - ) + token_id_seqs = _replace_matches(prompt, matches, mm_items) return flatten_2d_lists(token_id_seqs) def replace_text_matches( prompt: str, - matches: Sequence[_PromptReplacementMatch[_T, str]], - mm_items_by_modality: Mapping[str, list[_T]], - hf_inputs: BatchFeature, + matches: Sequence[_PromptReplacementTextMatch], + mm_items: MultiModalDataItems, ) -> str: """Apply :code:`prompt_repls` to :code:`prompt`.""" if not matches: return prompt - texts = _replace_matches( - prompt, - matches, - mm_items_by_modality, - hf_inputs, - ) + texts = _replace_matches(prompt, matches, mm_items) return "".join(texts) -def _merge_placeholder_matches( - matches: Iterable[_PromptReplacementTokenMatch], -) -> Iterable[_PromptReplacementTokenMatch]: - current_match = None +def _iter_modality_placeholders( + prompt: list[int], + modality: str, + modality_repls: Sequence[_BoundPromptReplacement], + modal_items: list[Any], +) -> Iterable[_PlaceholderInfo]: + if len(modal_items) == 0: + return - for match in sorted(matches, key=lambda x: x.start_idx): - if current_match is None: - current_match = match - elif (current_match.prompt_repl == match.prompt_repl - and current_match.end_idx == match.start_idx): - current_match = _PromptReplacementTokenMatch( - current_match.prompt_repl, - match=_TokenMatch(current_match.start_idx, match.end_idx), - ) - else: - yield current_match - current_match = match + prompt_len = len(prompt) + item_index = 0 + + 
start_idx = 0 + while start_idx < prompt_len: + found = False + + for repl_info in modality_repls: + replacement = repl_info.get_replacement(item_index) + repl_tokens = replacement.token_ids + repl_len = len(repl_tokens) + end_idx = start_idx + repl_len + + if repl_len == 0 or end_idx > prompt_len: + continue - if current_match is not None: - yield current_match + if prompt[start_idx:end_idx] == repl_tokens: + yield _PlaceholderInfo( + modality=modality, + start_idx=start_idx, + replacement=repl_tokens, + ) + + item_index += 1 + if item_index >= len(modal_items): + return + + # Exclude overlapping matches + start_idx = end_idx + found = True + break + + if not found: + start_idx += 1 def iter_placeholders( - prompt_repls: Sequence[_BoundPromptReplacement[Any]], + prompt_repls: Sequence[_BoundPromptReplacement], prompt: list[int], - *, - min_unit_count: int = 1, + mm_items: MultiModalDataItems, ) -> Iterable[_PlaceholderInfo]: - """Yield each set of placeholder tokens found in :code:`token_ids`.""" - if min_unit_count <= 0: - raise ValueError("`min_unit_count` must be a positive integer") - - matches = (_PromptReplacementTokenMatch(prompt_repl, match) - for prompt_repl in prompt_repls - if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0 - for match in iter_token_matches(prompt, repl_unit)) - - for match in _merge_placeholder_matches(matches): - unit = match.repl_unit - placeholder = _PlaceholderInfo( - modality=match.modality, - start_idx=match.start_idx, - unit=unit, - unit_count=(match.end_idx - match.start_idx) // len(unit), - ) + """ + Yield each set of placeholder tokens found in :code:`prompt`. + + Note that empty matches are ignored. + """ + repls_by_modality = dict(full_groupby_modality(prompt_repls)) + + for modality, modal_items in mm_items.items(): + if modality in repls_by_modality: + yield from _iter_modality_placeholders( + prompt, + modality, + repls_by_modality[modality], + modal_items, + ) + - if placeholder.unit_count >= min_unit_count: - yield placeholder +class ProcessorInputs(NamedTuple): + """Keyword arguments to :meth:`BaseMultiModalProcessor`""" + prompt_text: str + mm_data: MultiModalDataDict + mm_processor_kwargs: Mapping[str, object] class BaseMultiModalProcessor(ABC): @@ -536,52 +545,55 @@ class BaseMultiModalProcessor(ABC): Abstract base class to process multi-modal inputs to be used in vLLM. """ - def __init__( - self, - ctx: InputProcessingContext, - metadata: MultiModalProcessingMetadata, - ) -> None: + def __init__(self, ctx: InputProcessingContext) -> None: super().__init__() self.ctx = ctx - self.metadata = metadata - self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs - or {}) - def _get_hf_processor( + def __call__( self, - **mm_processor_kwargs: Mapping[str, object], - ) -> ProcessorMixin: - # by default, we won't pass any kwargs to the processor initialization + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + return self.apply(prompt, mm_data, mm_processor_kwargs) + + def _get_hf_processor(self) -> ProcessorMixin: + """ + Subclasses can add keyword arguments to this method to accept + additional kwargs from model config or user inputs. 
+ """ return self.ctx.get_hf_processor() def _get_tokenizer(self) -> AnyTokenizer: return self.ctx.tokenizer - def __call__( + @abstractmethod + def _get_prompt_replacements( self, - prompt: str, - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, + hf_inputs: BatchFeature, mm_processor_kwargs: Mapping[str, object], - ) -> MultiModalInputsV2: - return self.apply(prompt, mm_data, mm_processor_kwargs) + ) -> list[PromptReplacement]: + """ + Given the original multi-modal items for this modality + and HF-processed data, output the replacements to perform. + + Note: + Even when the HF processor already performs replacement for us, + we still use this replacement information to determine + the placeholder token positions for each multi-modal item. + """ + raise NotImplementedError def _find_placeholders( self, - all_prompt_repls: Sequence[_BoundPromptReplacement[Any]], + all_prompt_repls: Sequence[_BoundPromptReplacement], new_token_ids: list[int], - *, - # To avoid false positives from multi-input when detecting - # whether placeholder tokens have been inserted, in case - # the target sequence is a subset of the replacement tokens - min_unit_count: int = 16, + mm_items: MultiModalDataItems, ) -> list[_PlaceholderInfo]: return list( - iter_placeholders( - all_prompt_repls, - new_token_ids, - min_unit_count=min_unit_count, - )) + iter_placeholders(all_prompt_repls, new_token_ids, mm_items)) def _apply_hf_processor( self, @@ -589,13 +601,7 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - # some mm_processor_kwargs may be used in processor initialization - # instead of processor call - processor_init_kwargs = { - **self.init_mm_processor_kwargs, - **mm_processor_kwargs, - } - hf_processor = self._get_hf_processor(**processor_init_kwargs) + hf_processor = self._get_hf_processor(**mm_processor_kwargs) processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() @@ -615,11 +621,10 @@ def _apply_hf_processor( else: processor_data[k] = v - # filter mm_processor_kwargs used in processor call - mm_processor_kwargs = resolve_mm_processor_kwargs( - self.init_mm_processor_kwargs, - cast(Dict[str, Any], mm_processor_kwargs), + assert callable(hf_processor) + mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs( hf_processor, + mm_processor_kwargs, ) try: @@ -642,26 +647,21 @@ def _apply_hf_processor( def _bind_prompt_replacements( self, - mm_data: MultiModalDataDict, - ) -> list[_BoundPromptReplacement[Any]]: + prompt_repls: list[PromptReplacement], + ) -> list[_BoundPromptReplacement]: tokenizer = self._get_tokenizer() - return [ - prompt_repl.bind(modality, tokenizer) - for modality, metadata in self.metadata.items() - if modality in mm_data for prompt_repl in metadata.prompt_repls - ] + return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls] def _apply_prompt_replacements( self, - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_inputs: BatchFeature, token_ids: list[int], - prompt_repls: Sequence[_BoundPromptReplacement[Any]], + prompt_repls: Sequence[_BoundPromptReplacement], ) -> tuple[list[int], str, list[_PlaceholderInfo]]: tokenizer = self._get_tokenizer() - mm_items = to_multi_format(mm_data) token_matches = find_token_matches(token_ids, prompt_repls) # If the search text does not represent a special token, @@ -682,7 +682,6 @@ def _apply_prompt_replacements( token_ids, token_matches, mm_items, - hf_inputs, ) text = _decode(tokenizer, token_ids) @@ -695,13 +694,13 @@ 
def _apply_prompt_replacements( text, text_matches, mm_items, - hf_inputs, ) token_ids = _encode(tokenizer, text) matched_repls = [match.prompt_repl for match in text_matches] - placeholders = self._find_placeholders(matched_repls, token_ids) + placeholders = self._find_placeholders(matched_repls, token_ids, + mm_items) return token_ids, text, placeholders @@ -731,12 +730,16 @@ def apply( prompt_ids, = hf_inputs.pop("input_ids").tolist() mm_kwargs = MultiModalKwargs(hf_inputs) - all_prompt_repls = self._bind_prompt_replacements(mm_data) + mm_items = to_multi_format(mm_data) + prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs, + mm_processor_kwargs) + all_prompt_repls = self._bind_prompt_replacements(prompt_repls) # If HF processor already inserts placeholder tokens, # there is no need for us to insert them all_placeholders = self._find_placeholders(all_prompt_repls, - prompt_ids) + prompt_ids, mm_items) + if all_placeholders: prompt_text = _decode(tokenizer, prompt_ids) else: @@ -745,7 +748,7 @@ def apply( prompt_text, all_placeholders, ) = self._apply_prompt_replacements( - mm_data, + mm_items, hf_inputs, prompt_ids, all_prompt_repls, @@ -765,13 +768,13 @@ def apply( ) @abstractmethod - def _get_dummy_mm_kwargs( + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], - ) -> MultiModalKwargs: + ) -> ProcessorInputs: """ - Build the input that corresponds to `mm_max_tokens` in - :meth:`get_dummy_data`. + Build the multi-modal portion of the input which, after processing, + results in `mm_max_tokens` in :meth:`get_dummy_data`. """ raise NotImplementedError @@ -784,38 +787,41 @@ def get_dummy_data( # Avoid circular import from vllm.sequence import SequenceData - tokenizer = self._get_tokenizer() - - mm_placeholders = dict[str, _PlaceholderInfo]() - offset = 0 - - for modality, max_tokens in mm_max_tokens.items(): - if max_tokens == 0: - continue - - metadata = self.metadata[modality] - repl = metadata.prompt_repls[0].bind(modality, tokenizer) - repl_token_ids = repl.repl_unit.token_ids - - placeholders = _PlaceholderInfo( - modality=modality, - start_idx=offset, - unit=repl_token_ids, - unit_count=max_tokens // len(repl_token_ids), - ) - - mm_placeholders[modality] = placeholders - offset += placeholders.length + processor_inputs = self._get_dummy_mm_inputs(mm_counts) + mm_inputs = self.apply(*processor_inputs) + + prompt_token_ids = mm_inputs["prompt_token_ids"] + placeholders_by_modality = mm_inputs["mm_placeholders"] + + total_placeholders_by_modality = dict[str, int]() + for modality, placeholders in placeholders_by_modality.items(): + num_placeholders = sum(item["length"] for item in placeholders) + max_tokens = mm_max_tokens[modality] + + if num_placeholders != max_tokens: + logger.warning( + "The processed dummy data has a total of %d placeholder " + "tokens for the '%s' modality, which is not the expected " + "%d tokens.", num_placeholders, modality, max_tokens) + + total_placeholders_by_modality[modality] = num_placeholders + + total_len = len(prompt_token_ids) + if total_len > seq_len: + logger.warning( + "The context length (%d) of the model is too short " + "to hold the multi-modal embeddings in the worst case " + "(%d tokens in total, out of which %s are reserved for " + "multi-modal embeddings). This may cause certain multi-modal " + "inputs to fail during inference, even when the input text is " + "short. 
To avoid this, you should increase `max_model_len`, " + "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, + total_len, total_placeholders_by_modality) - prompt_token_ids = flatten_2d_lists( - [p.unit * p.unit_count for p in mm_placeholders.values()]) prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) return DummyData( seq_data=SequenceData.from_seqs(prompt_token_ids), - multi_modal_data=self._get_dummy_mm_kwargs(mm_counts), - multi_modal_placeholders={ - modality: [p.to_range()] - for modality, p in mm_placeholders.items() - }, + multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_placeholders=placeholders_by_modality, ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 6ab6c0fe2f12e..03f8814a95356 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -299,9 +299,9 @@ def register_processor( """ def wrapper(model_cls: N) -> N: - if model_cls in self._processor_factories: + if self._processor_factories.contains(model_cls, strict=True): logger.warning( - "Model class %s already has an input mapper " + "Model class %s already has a multi-modal processor " "registered to %s. It is overwritten by the new one.", model_cls, self) diff --git a/vllm/utils.py b/vllm/utils.py index fbc3ef7fa7f89..45e682ac15782 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1370,8 +1370,8 @@ def supports_kw( def resolve_mm_processor_kwargs( - init_kwargs: Optional[Dict[str, Any]], - inference_kwargs: Optional[Dict[str, Any]], + init_kwargs: Optional[Mapping[str, object]], + inference_kwargs: Optional[Mapping[str, object]], callable: Callable[..., object], allow_var_kwargs: bool = False, ) -> Dict[str, Any]: @@ -1405,7 +1405,7 @@ def resolve_mm_processor_kwargs( def get_allowed_kwarg_only_overrides( callable: Callable[..., object], - overrides: Optional[Dict[str, Any]], + overrides: Optional[Mapping[str, object]], allow_var_kwargs: bool = False, ) -> Dict[str, Any]: """ @@ -1524,9 +1524,15 @@ def __getitem__(self, key: Type[T]) -> _V: raise KeyError(key) def __contains__(self, key: object) -> bool: + return self.contains(key) + + def contains(self, key: object, *, strict: bool = False) -> bool: if not isinstance(key, type): return False + if strict: + return key in self.data + return any(cls in self.data for cls in key.mro())
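To illustrate how a model-specific processor consumes the API introduced in this patch, here is a minimal sketch of registering a dynamic replacement with the new PromptReplacement(modality, target, replacement) signature. The image token id and the per-item token count below are hypothetical illustration values; real processors in the patch (e.g. Phi3VMultiModalProcessor) derive them from the HF image processor instead.

from vllm.multimodal.processing import PromptReplacement

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder token id, for illustration only

def get_replacement(item_idx: int) -> list[int]:
    # A real processor would compute this from the item's size via the HF
    # image processor; here the count simply varies with the item index.
    num_tokens = 4 + 2 * item_idx
    return [IMAGE_TOKEN_ID] * num_tokens

prompt_repl = PromptReplacement(
    modality="image",
    target=[IMAGE_TOKEN_ID],      # the single placeholder token to find
    replacement=get_replacement,  # expanded per item at processing time
)
# _get_prompt_replacements() would return [prompt_repl]; the base class binds it
# to the tokenizer and applies it via replace_token_matches / replace_text_matches.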
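The dummy-data path changes in the same way: _get_dummy_mm_inputs now returns ProcessorInputs (prompt text plus raw multi-modal data) and get_dummy_data runs it through apply() rather than hand-building placeholder token ids. A hedged sketch, assuming a LLaVA-style "<image>" placeholder string and a 336x336 dummy image (both illustration values, not taken from the patch):

from PIL import Image
from vllm.multimodal.processing import ProcessorInputs

num_images = 2
dummy_inputs = ProcessorInputs(
    prompt_text="<image>" * num_images,  # assumed model-specific placeholder text
    mm_data={"image": [Image.new("RGB", (336, 336)) for _ in range(num_images)]},
    mm_processor_kwargs={},
)
# BaseMultiModalProcessor.get_dummy_data then calls self.apply(*dummy_inputs) and
# warns when the resulting placeholder count differs from mm_max_tokens.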
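Finally, a small runnable illustration of the revised iter_token_matches semantics, mirroring the updated cases in tests/multimodal/test_processing.py: empty targets yield no matches, and overlapping occurrences are skipped.

from vllm.multimodal.processing import iter_token_matches

print(list(iter_token_matches([32000, 32000, 32000], [32000, 32000])))
# One match spanning indices 0..2; the overlapping match starting at index 1 is skipped.
print(list(iter_token_matches([1, 2, 3], [])))
# [] -- empty match_ids are ignored.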