Commit 41c423a
Flechman committed Jan 26, 2025
1 parent 46c142f commit 41c423a
Showing 2 changed files with 215 additions and 152 deletions.
vllm/model_executor/models/llava.py (93 changes: 91 additions & 2 deletions)
@@ -11,6 +11,7 @@
                           SiglipVisionConfig)
 from transformers import __version__ as TRANSFORMERS_VERSION
 from transformers.models.llava import LlavaProcessor
+from transformers.models.pixtral import PixtralProcessor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
@@ -35,8 +36,8 @@
 
 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFMultiModalProcessor, PixtralHFProcessingInfo,
-                      PixtralHFVisionModel)
+from .pixtral import (PixtralHFVisionModel,
+                      get_pixtral_hf_image_feature_grid_size)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -262,6 +263,94 @@ def _get_mm_fields_config(
     )
 
 
+class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
+
+    def get_hf_processor(self):
+        return self.ctx.get_hf_processor(PixtralProcessor)
+
+
+class PixtralHFMultiModalProcessor(
+        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed_outputs = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+        pixel_values = processed_outputs.get("pixel_values")
+        if pixel_values is not None:
+            images = mm_data["images"]
+            assert isinstance(images, list)
+
+            # Original output: (1, num_images, C, H, W)
+            # New output: (num_images, C, H, W)
+            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
+            assert (isinstance(pixel_values[0], list)
+                    and len(pixel_values[0]) == len(images))
+
+            processed_outputs["pixel_values"] = pixel_values[0]
+
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        hf_config = self.info.get_hf_config()
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+
+        image_break_id = vocab[processor.image_break_token]
+        image_token_id = hf_config.image_token_index
+        image_end_id = vocab[processor.image_end_token]
+
+        vision_config = hf_config.vision_config
+        assert isinstance(vision_config, PixtralVisionConfig)
+
+        def get_replacement(item_idx: int):
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+
+            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
+                vision_config,
+                image_width=image_size.width,
+                image_height=image_size.height,
+            )
+
+            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
+            tokens[-1] = image_end_id
+
+            return tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=get_replacement,
+            ),
+        ]
+
+
 def _build_llava_or_pixtral_hf_info(
         ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
     hf_config = ctx.get_hf_config(LlavaConfig)
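
The token layout produced by get_replacement above can be seen with a small standalone sketch (the token IDs below are hypothetical stand-ins, not Mistral's real vocabulary, and the grid size is hard-coded where the real code calls get_pixtral_hf_image_feature_grid_size):

# Hypothetical IDs: 10 = [IMG], 12 = [IMG_BREAK], 13 = [IMG_END].
image_token_id, image_break_id, image_end_id = 10, 12, 13

# Suppose the feature grid for this image is 3 columns x 2 rows.
ncols, nrows = 3, 2

# Every row is ncols image tokens followed by a row-break token...
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
# ...except the last row, which is terminated by the end token instead.
tokens[-1] = image_end_id

print(tokens)  # [10, 10, 10, 12, 10, 10, 10, 13]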
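Similarly, the pixel_values unwrapping in _call_hf_processor amounts to dropping a singleton batch dimension; a sketch with plain lists standing in for the real (C, H, W) tensors:

images = ["img0", "img1"]  # two input images

# HF PixtralProcessor output layout: (1, num_images, C, H, W).
pixel_values = [["tensor_0", "tensor_1"]]

assert isinstance(pixel_values, list) and len(pixel_values) == 1
assert len(pixel_values[0]) == len(images)

# Keep one entry per image: (num_images, C, H, W).
pixel_values = pixel_values[0]
print(pixel_values)  # ['tensor_0', 'tensor_1']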
