diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index db489af7ac475..b0a1104546186 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -2,7 +2,8 @@
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
-from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
+from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
+                    final)
 
 import numpy as np
 import torch
@@ -11,7 +12,7 @@
 from transformers import BatchFeature
 from typing_extensions import NotRequired, TypeAlias
 
-from vllm.utils import JSONTree, is_list_of, json_map_leaves
+from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
 
 _T = TypeVar("_T")
 
@@ -160,11 +161,8 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
 
 
 @dataclass(frozen=True)
-class MultiModalFieldItem:
-    """
-    Contains metadata and data in :class:`MultiModalKwargs`
-    corresponding to a data item in :class:`MultiModalDataItems`.
-    """
+class MultiModalFieldElem:
+    """Contains metadata and data of an item in :class:`MultiModalKwargs`."""
 
     field: "BaseMultiModalField"
     data: NestedTensors
@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         raise NotImplementedError
 
-    def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
-        return MultiModalFieldItem(self, data)
+    def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
+        return MultiModalFieldElem(self, data)
 
-    def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
-        """Merge multiple instances of :class:`MultiModalFieldItem` together."""
+    def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
+        """Merge multiple instances of :class:`MultiModalFieldElem` together."""
         fields = [item.field for item in batch]
         if len(set(fields)) > 1:
             raise ValueError(f"Cannot merge different {fields=}")
 
         data = self._reduce_data([item.data for item in batch])
 
-        return self._build_item(data)
+        return self._build_elem(data)
 
 
 @dataclass(frozen=True)
 class MultiModalBatchedField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    directly indexing into the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by indexing into the first dimension of the underlying data.
     """
 
-    def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
-        return [self._build_item(item) for item in batch]
+    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
+        return [self._build_elem(item) for item in batch]
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
             first_shape = batch[0].shape
-            if all(item.shape == first_shape for item in batch):
+            if all(elem.shape == first_shape for elem in batch):
                 return torch.stack(batch)
 
         return batch
@@ -222,24 +220,24 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
 @dataclass(frozen=True)
 class MultiModalFlatField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    slicing along the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by slicing along the first dimension of the underlying data.
     """
 
-    def build_items(
+    def build_elems(
         self,
         batch: NestedTensors,
         slices: Sequence[slice],
-    ) -> list[MultiModalFieldItem]:
-        return [self._build_item(batch[slice_]) for slice_ in slices]
+    ) -> list[MultiModalFieldElem]:
+        return [self._build_elem(batch[slice_]) for slice_ in slices]
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
             first_shape = batch[0].shape
-            if all(item.shape[1:] == first_shape[1:] for item in batch):
+            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                 return torch.concat(batch)
 
-        return [elem for item in batch for elem in item]
+        return [e for elem in batch for e in elem]
 
 
 class MultiModalFieldConfig:
@@ -267,115 +265,111 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self._field_cls = field_cls
-        self._modality = modality
-        self._field_config = field_config
+        self.field_cls = field_cls
+        self.modality = modality
+        self.field_config = field_config
 
-    def build_items(
+    def build_elems(
         self,
         key: str,
         batch: NestedTensors,
-    ) -> list[MultiModalFieldItem]:
-        field = self._field_cls(key=key, modality=self._modality)
-        return field.build_items(batch, **self._field_config)  # type: ignore
+    ) -> Sequence[MultiModalFieldElem]:
+        field = self.field_cls(key=key, modality=self.modality)
+        return field.build_elems(batch, **self.field_config)  # type: ignore
 
 
-class MultiModalKwargs(UserDict[str, NestedTensors]):
+class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
+    """
+    A collection of :class:`MultiModalFieldElem`
+    corresponding to a data item in :class:`MultiModalDataItems`.
     """
-    A dictionary that represents the keyword arguments to
-    :meth:`~torch.nn.Module.forward`.
 
-    The metadata :code:`items_by_key` defines how to split batched keyword
-    arguments corresponding to each data item in :class:`MultiModalDataItems`:
+    @staticmethod
+    def from_elems(elems: Sequence[MultiModalFieldElem]):
+        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
 
-    - For a keyword argument, we can access the :code:`i` th item in the batch
-      via :code:`items_by_key[key][i]`.
-    - We can gather the keyword arguments belonging to a modality by finding
-      the keys with items that belong to that modality, then accessing
-      the :code:`i` th item in the batch for each such key.
+    @property
+    def modality(self) -> str:
+        modalities = {elem.field.modality for elem in self.data.values()}
+        assert len(modalities) == 1, f"Found different modalities={modalities}"
+        return next(iter(modalities))
 
-    Example:
-
-        .. code-block:: python
-
-            # All items belong to the "image" modality
-            items_by_key={
-                "pixel_values": [a, b, c, d],  # "image" modality
-                "image_grid_thw": [e, f, g, h],  # "image" modality
-                "pixel_values_video": [h, i, j],  # "video" modality
-                "video_grid_thw": [k, l, m],  # "video" modality
-            }
+
+# NOTE: UserDict is for V0 compatibility.
+# V1 should access individual items via `get_item`.
+class MultiModalKwargs(UserDict[str, NestedTensors]):
+    """
+    A dictionary that represents the keyword arguments to
+    :meth:`~torch.nn.Module.forward`.
 
-        - The keyword arguments belonging to the first image are
-          :code:`{"pixel_values": a, "image_grid_thw": e}`.
-        - The keyword arguments belonging to the second video are
-          :code:`{"pixel_values_video": i, "video_grid_thw": l}`.
+    The metadata :code:`items` enables us to obtain the keyword arguments
+    corresponding to each data item in :class:`MultiModalDataItems`, via
+    :meth:`get_item` and :meth:`get_items`.
     """
 
     @staticmethod
     def from_hf_inputs(
         hf_inputs: BatchFeature,
         config_by_key: Mapping[str, MultiModalFieldConfig],
-        *,
-        enable_sanity_checks: bool = False,
     ):
         # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
         # We assume that those fields are not used in vLLM
-        items_by_key = {
-            key: config.build_items(key, batch)
-            for key, config in config_by_key.items()
-            if (batch := hf_inputs.get(key)) is not None
-        }
-
-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
+        keys_by_modality = defaultdict[str, set[str]](set)
+        for key, config in config_by_key.items():
+            batch = hf_inputs.get(key)
+            if batch is not None:
+                elems = config.build_elems(key, batch)
+                if len(elems) > 0:
+                    elems_by_key[key] = elems
+                    keys_by_modality[config.modality].add(key)
+
+        items = list[MultiModalKwargsItem]()
+        for modality, keys in keys_by_modality.items():
+            elems_in_modality = {k: elems_by_key[k] for k in keys}
+            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
+
+            if len(set(batch_sizes.values())) > 1:
+                raise ValueError(
+                    f"Cannot merge different batch sizes for {modality=}! "
+                    f"Found: {batch_sizes=}")
+
+            batch_size = next(iter(batch_sizes.values()))
+            for item_idx in range(batch_size):
+                elems = [v[item_idx] for v in elems_in_modality.values()]
+                items.append(MultiModalKwargsItem.from_elems(elems))
+
+        return MultiModalKwargs.from_items(items)
 
     @staticmethod
-    def from_items_by_key(
-        items_by_key: Mapping[str, list[MultiModalFieldItem]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
+    def from_items(items: Sequence[MultiModalKwargsItem]):
+        """Construct a new :class:`MultiModalKwargs` from multiple items."""
+        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
+        for item in items:
+            for key, elem in item.items():
+                elems_by_key[key].append(elem)
+
         data = {
-            key: items[0].field.reduce(items).data
-            for key, items in items_by_key.items() if len(items) > 0
+            key: elems[0].field.reduce(elems).data
+            for key, elems in elems_by_key.items() if len(elems) > 0
         }
 
-        return MultiModalKwargs(data,
-                                items_by_key=items_by_key,
-                                enable_sanity_checks=enable_sanity_checks)
+        return MultiModalKwargs(data, items=items)
 
     def __init__(
         self,
         data: Mapping[str, NestedTensors],
         *,
-        items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
-        enable_sanity_checks: bool = False,
+        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
         super().__init__(data)
 
-        # Shallow copy to avoid footgun in case a defaultdict is passed in
-        self._items_by_key = dict(items_by_key)
+        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
+        self._items_by_modality = dict(items_by_modality)
 
-        keys_by_modality = defaultdict[str, set[str]](set)
-        for key, items in items_by_key.items():
-            for item in items:
-                keys_by_modality[item.field.modality].add(key)
-
-        self._keys_by_modality = dict(keys_by_modality)
-
-        if enable_sanity_checks:
-            for modality, keys in keys_by_modality.items():
-                items_in_modality = {k: items_by_key[k] for k in keys}
-                batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
-                batch_size = next(iter(batch_sizes.values()), 0)
-                assert all(bs == batch_size
-                           for bs in batch_sizes.values()), dict(
-                               modality=modality,
-                               batch_sizes=batch_sizes,
-                               items_by_key=items_by_key)
+    @property
+    def modalities(self):
+        return self._items_by_modality.keys()
 
     @staticmethod
     def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
@@ -452,58 +446,44 @@ def as_kwargs(
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        if self._items_by_key != other._items_by_key:
+        if self._items_by_modality != other._items_by_modality:
             return False
 
         ks = self.keys()
         return (ks == other.keys()
                 and all(nested_tensors_equal(self[k], other[k]) for k in ks))
 
-    def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
-        return self._items_by_key[key][item_index]
+    def _validate_modality(self, method_name: str, modality: str) -> None:
+        if not self._items_by_modality:
+            raise RuntimeError(
+                f"`{method_name}` is not supported when "
+                "MultiModalKwargs is not initialized with `items`")
 
-    def get_items_by_modality(
-        self,
-        modality: str,
-        item_index: int,
-    ) -> Mapping[str, MultiModalFieldItem]:
-        """
-        Get the keyword arguments corresponding to an item identified by
-        its modality and index.
-        """
-        if modality not in self._keys_by_modality:
-            available_modalities = set(self._keys_by_modality.keys())
+        if modality not in self._items_by_modality:
+            available_modalities = set(self._items_by_modality.keys())
             raise KeyError(f"Modality {modality!r} not found. "
                            f"Available modalities: {available_modalities}")
 
-        keys_to_gather = self._keys_by_modality[modality]
+    def get_item_count(self, modality: str) -> int:
+        """Get the number of items belonging to a modality."""
+        self._validate_modality("get_item_count", modality)
+        return len(self._items_by_modality[modality])
 
-        return {
-            key: self.get_item(key, item_index)
-            for key in keys_to_gather if key in self
-        }
+    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
+        """
+        Get the keyword arguments corresponding to an item identified by
+        its modality and index.
+        """
+        self._validate_modality("get_item", modality)
+        return self._items_by_modality[modality][item_index]
 
-    @staticmethod
-    def from_items_by_modality(
-        items_by_modality: Mapping[str, list[Mapping[str,
-                                                     MultiModalFieldItem]]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
+    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
         """
-        Construct a new :class:`MultiModalKwargs` from multiple items returned
-        by :meth:`get_fields_by_modality`.
+        Get the keyword arguments corresponding to each item belonging to
+        a modality.
         """
-        items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
-        for fields in items_by_modality.values():
-            for field in fields:
-                for k, v in field.items():
-                    items_by_key[k].append(v)
-
-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+        self._validate_modality("get_items", modality)
+        return self._items_by_modality[modality]
 
 
 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 76475ddda81f4..64cdacfb4c574 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -20,8 +20,8 @@
 from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
 
 from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                     MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs,
-                     PlaceholderRange)
+                     MultiModalInputsV2, MultiModalKwargs,
+                     MultiModalKwargsItem, PlaceholderRange)
 from .parse import MultiModalDataItems, MultiModalDataParser
 
 logger = init_logger(__name__)
@@ -496,8 +496,7 @@ def __init__(self, capacity: int) -> None:
         # DEBUG: Set to None to disable
         self.debug_cache_hit_ratio_steps: Optional[int] = None
 
-        self._cache = LRUCache[str, Mapping[str,
-                                            MultiModalFieldItem]](capacity)
+        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)
 
     def _maybe_log_cache_stats(self) -> None:
         steps = self.debug_cache_hit_ratio_steps
@@ -565,7 +564,7 @@ def get(
         modality: str,
         input_item: object,
         input_kwargs: Mapping[str, object],
-    ) -> Optional[Mapping[str, MultiModalFieldItem]]:
+    ) -> Optional[MultiModalKwargsItem]:
         """
         Get a processed multi-modal item from the cache
         according to its dependencies, including:
@@ -588,7 +587,7 @@ def put(
         modality: str,
         input_item: object,
         input_kwargs: Mapping[str, object],
-        output_kwargs: Mapping[str, MultiModalFieldItem],
+        output_kwargs: MultiModalKwargsItem,
     ) -> None:
         """
         Put a processed multi-modal item into the cache
@@ -784,7 +783,6 @@ def _apply_hf_processor(
         mm_kwargs = MultiModalKwargs.from_hf_inputs(
             processed_data,
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
-            enable_sanity_checks=self.enable_sanity_checks,
         )
 
         return prompt_ids, mm_kwargs
@@ -846,7 +844,7 @@ def _cached_apply_hf_processor(
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
         )
 
-        mm_maybe_cached_field_items = {
+        mm_maybe_cached_kw_items = {
             modality: [
                 cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                 for item in items
@@ -855,8 +853,9 @@ def _cached_apply_hf_processor(
         }
 
         mm_missing_idxs = {
-            modality: [idx for idx, out in enumerate(fields) if out is None]
-            for modality, fields in mm_maybe_cached_field_items.items()
+            modality:
+            [idx for idx, item in enumerate(kw_items) if item is None]
+            for modality, kw_items in mm_maybe_cached_kw_items.items()
         }
         mm_missing_data = {
             modality: [mm_data_items[modality][idx] for idx in idxs]
@@ -875,14 +874,11 @@ def _cached_apply_hf_processor(
             for modality in mm_missing_data_items
         }
 
-        mm_merged_field_items = dict[str, list[Mapping[str,
-                                                       MultiModalFieldItem]]]()
-        for modality, modal_items_lst in mm_maybe_cached_field_items.items():
-            merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]()
-
-            for idx, modal_items in enumerate(modal_items_lst):
-                if modal_items is None:
-                    modal_items = mm_missing_kwargs.get_items_by_modality(
+        merged_kw_items = list[MultiModalKwargsItem]()
+        for modality, kw_items in mm_maybe_cached_kw_items.items():
+            for idx, kw_item in enumerate(kw_items):
+                if kw_item is None:
+                    kw_item = mm_missing_kwargs.get_item(
                         modality,
                         mm_missing_next_idx[modality],
                     )
@@ -892,14 +888,12 @@ def _cached_apply_hf_processor(
                         modality,
                         mm_data_items[modality][idx],
                         hf_processor_mm_kwargs,
-                        modal_items,
+                        kw_item,
                     )
 
                 mm_missing_next_idx[modality] += 1
 
-                merged_modal_items_lst.append(modal_items)
-
-            mm_merged_field_items[modality] = merged_modal_items_lst
+                merged_kw_items.append(kw_item)
 
         if self.enable_sanity_checks:
             mm_missing_counts = mm_missing_data_items.get_all_counts()
@@ -909,10 +903,7 @@ def _cached_apply_hf_processor(
                 mm_missing_next_idx=mm_missing_next_idx,
                 mm_missing_counts=mm_missing_counts)
 
-        mm_kwargs = MultiModalKwargs.from_items_by_modality(
-            mm_merged_field_items,
-            enable_sanity_checks=self.enable_sanity_checks,
-        )
+        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
 
         if self.enable_sanity_checks:
             mm_item_counts = mm_data_items.get_all_counts()
@@ -920,7 +911,7 @@ def _cached_apply_hf_processor(
         for modality, item_count in mm_item_counts.items():
             for item_idx in range(item_count):
                 try:
-                    mm_kwargs.get_items_by_modality(modality, item_idx)
+                    mm_kwargs.get_item(modality, item_idx)
                 except Exception as e:
                     # Make it easy to set a breakpoint in the debugger
                     raise e
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 5b5a5a61cea7d..905d3d1fc3e1c 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -113,15 +113,27 @@ def process_inputs(
 
         # For merged preprocessor, mm_data is already mm_inputs
         precomputed_mm_inputs = None
-        if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
-            precomputed_mm_inputs = [decoder_inputs.multi_modal_data]
+        decoder_mm_data = decoder_inputs.multi_modal_data
+        if isinstance(decoder_mm_data, MultiModalKwargs):
+            # The output of merged multi-modal processor (`decoder_mm_data`)
+            # contains the kwargs for all items from all modalities.
+            # This code separates them so that there is one set of kwargs
+            # per item per modality.
+            precomputed_mm_inputs = [
+                MultiModalKwargs.from_items([item])
+                for modality in decoder_mm_data.modalities
+                for item in decoder_mm_data.get_items(modality)
+            ]
 
         # Apply MM mapper
         mm_inputs = None
-        if len(decoder_inputs.multi_modal_data) > 0:
+        if len(decoder_mm_data) > 0:
             mm_inputs = self.mm_input_mapper_client.process_inputs(
-                decoder_inputs.multi_modal_data, mm_hashes,
-                decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
+                decoder_mm_data,
+                mm_hashes,
+                decoder_inputs.mm_processor_kwargs,
+                precomputed_mm_inputs,
+            )
 
         return EngineCoreRequest(
             request_id,
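
Usage sketch (reviewer note, not part of the patch): the snippet below walks through the refactored API, from `MultiModalKwargs.from_hf_inputs` to the new per-item accessors. The "pixel_values"/"image_grid_thw" keys, the tensor shapes, and the use of the `MultiModalFieldConfig.batched` helper are illustrative assumptions, not taken from this diff.

    import torch
    from transformers import BatchFeature

    from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs

    # Pretend HF processor output for a request containing 2 images
    # (hypothetical keys and shapes).
    hf_inputs = BatchFeature({
        "pixel_values": torch.rand(2, 3, 336, 336),
        "image_grid_thw": torch.ones(2, 3, dtype=torch.int64),
    })
    config_by_key = {
        "pixel_values": MultiModalFieldConfig.batched("image"),
        "image_grid_thw": MultiModalFieldConfig.batched("image"),
    }

    mm_kwargs = MultiModalKwargs.from_hf_inputs(hf_inputs, config_by_key)

    # V0-style access through the UserDict interface still works:
    # the per-item elements are re-stacked into batched tensors.
    assert mm_kwargs["pixel_values"].shape[0] == 2

    # V1-style access: one MultiModalKwargsItem per image.
    assert mm_kwargs.get_item_count("image") == 2
    first_image = mm_kwargs.get_item("image", 0)
    assert set(first_image.keys()) == {"pixel_values", "image_grid_thw"}

    # A single item can be re-batched on its own, which is what
    # vllm/v1/engine/processor.py above does to build one
    # MultiModalKwargs per item per modality.
    single_input = MultiModalKwargs.from_items([first_image])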