[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision #11685

Merged: 55 commits, Jan 6, 2025
Commits (55)
022c6b4
initial
ywang96 Jan 1, 2025
43fdf45
fix llava ov
ywang96 Jan 1, 2025
e0fb002
iterate
ywang96 Jan 1, 2025
a9b9757
Merge branch 'vllm-project:main' into v1-llava-ov
ywang96 Jan 2, 2025
b45010b
revert padding tensor
ywang96 Jan 2, 2025
d83e25e
simplify
ywang96 Jan 2, 2025
d13b0f7
comment
ywang96 Jan 2, 2025
7d1f19a
Merge branch 'vllm-project:main' into v1-llava-ov
ywang96 Jan 2, 2025
6959ec0
simplify and doc
ywang96 Jan 2, 2025
ba071c6
refactor logic
ywang96 Jan 3, 2025
ba2f399
format
ywang96 Jan 3, 2025
ff4cdea
Merge branch 'vllm-project:main' into v1-llava-ov
ywang96 Jan 3, 2025
2eebfd9
switch order
ywang96 Jan 3, 2025
20dd84d
refactor
ywang96 Jan 3, 2025
34ec194
typing
ywang96 Jan 3, 2025
9f19629
hasher
ywang96 Jan 3, 2025
66484aa
consolidate mm hasher
ywang96 Jan 4, 2025
1423f5f
typing
ywang96 Jan 4, 2025
ba17100
Merge branch 'vllm-project:main' into v1-llava-ov
ywang96 Jan 4, 2025
b3c41ce
Merge branch 'main' into v1-llava-ov
ywang96 Jan 5, 2025
14481fd
fix length check
ywang96 Jan 5, 2025
6f435cf
update profiling
ywang96 Jan 5, 2025
16e5b04
update dummy data for llava-ov
ywang96 Jan 5, 2025
612880b
preserve modality order
ywang96 Jan 5, 2025
3022754
format
ywang96 Jan 5, 2025
20d6a67
simplify
ywang96 Jan 5, 2025
3dd2db2
typo
ywang96 Jan 5, 2025
5ce6f7a
clarify
ywang96 Jan 5, 2025
4113e51
add test
ywang96 Jan 5, 2025
3ca30fc
fix test
ywang96 Jan 5, 2025
ef8c6d1
add note
ywang96 Jan 5, 2025
87f4216
Merge branch 'v1-llava-ov' of https://github.com/ywang96/vllm into v1…
ywang96 Jan 5, 2025
bc1debd
comment
ywang96 Jan 5, 2025
56a7ef0
typo
ywang96 Jan 5, 2025
568a586
rename
ywang96 Jan 5, 2025
6ca99a3
remove redundant constants
ywang96 Jan 5, 2025
6c8ff3b
update interface with note
ywang96 Jan 5, 2025
293b3fe
update doc
ywang96 Jan 5, 2025
14482bf
address review comments
ywang96 Jan 6, 2025
eeee402
use namedtuple
ywang96 Jan 6, 2025
7f4815e
add comment
ywang96 Jan 6, 2025
1ba40e9
update
ywang96 Jan 6, 2025
2eb4cf1
format
ywang96 Jan 6, 2025
fe71431
format
ywang96 Jan 6, 2025
1a7b39c
remove unneeded check
ywang96 Jan 6, 2025
61991b6
Merge branch 'main' into v1-llava-ov
ywang96 Jan 6, 2025
ceec26e
remove unused import
ywang96 Jan 6, 2025
7879952
restrict mm_hash to V1
ywang96 Jan 6, 2025
72ae769
fix test and reorder code for readability
ywang96 Jan 6, 2025
48811b6
typo
ywang96 Jan 6, 2025
b31fd4f
format
ywang96 Jan 6, 2025
be54b2c
Fix dummy requests
DarkLight1337 Jan 6, 2025
b2cbc5a
Pass sanity check
DarkLight1337 Jan 6, 2025
3400d07
format
DarkLight1337 Jan 6, 2025
2461f0f
Merge branch 'main' into v1-llava-ov
DarkLight1337 Jan 6, 2025
Changes from 9 commits
33 changes: 13 additions & 20 deletions vllm/model_executor/models/llava_onevision.py
@@ -38,8 +38,8 @@
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# Ref: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md?plain=1#L14
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 2304

# For profile run
_MAX_FRAMES_PER_VIDEO = 16
@@ -366,9 +366,11 @@ def input_processor_for_llava_onevision(ctx: InputContext,
and "image" not in multi_modal_data):
return inputs
if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, inputs)
inputs = input_processor_when_multimodal_input_image(ctx, inputs)
if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, inputs)
else:
return inputs

msg = "Unsupported multi data type"
raise NotImplementedError(msg)
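
Note on this hunk: the image branch no longer returns early, so a request carrying both an image and a video now flows through both input processors. Below is a standalone sketch of that control flow (not part of the diff); the processor bodies are hypothetical stand-ins, not the actual vLLM functions.

def _process_image(inputs: dict) -> dict:
    # Hypothetical stand-in for input_processor_when_multimodal_input_image.
    return {**inputs, "image_processed": True}

def _process_video(inputs: dict) -> dict:
    # Hypothetical stand-in for input_processor_when_multimodal_input_video.
    return {**inputs, "video_processed": True}

def process(inputs: dict, multi_modal_data: dict) -> dict:
    # The image branch falls through instead of returning, so a mixed
    # image + video request passes through both processors.
    if "image" in multi_modal_data:
        inputs = _process_image(inputs)
    if "video" in multi_modal_data:
        return _process_video(inputs)
    return inputs

print(process({}, {"image": object(), "video": object()}))
# {'image_processed': True, 'video_processed': True}
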
@@ -832,21 +834,18 @@ def get_multimodal_embeddings(
if not modalities:
return None

# We make a tuple of each embedding with its modality string. This is a
# temporary workaround for models to handle mixed modalities when
# get_multimodal_embeddings and get_input_embeddings are called
# separately.
# TODO(ywang96): Add support for mixed-modality inference for v1.
multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()

if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings.append((vision_embeddings, "image"))
multimodal_embeddings += tuple(vision_embeddings)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
multimodal_embeddings.append((video_embeddings, "video"))
multimodal_embeddings += tuple(video_embeddings)

return multimodal_embeddings

@@ -858,15 +857,9 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
for embeddings, modality in multimodal_embeddings:
if modality == "image":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.image_token_index)
if modality == "video":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.video_token_index)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_index, self.config.video_token_index])
return inputs_embeds

def forward(
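
The refactored get_input_embeddings above merges all multimodal embeddings in a single call, passing both the image and video placeholder token IDs. The following is a minimal, self-contained sketch of that scatter step (not part of the diff); it is an illustrative re-implementation, not vLLM's merge_multimodal_embeddings, and the token IDs in the example are made up.

import torch

def merge_mm_embeddings_sketch(
        input_ids: torch.Tensor,        # (seq_len,)
        inputs_embeds: torch.Tensor,    # (seq_len, hidden_size)
        mm_embeddings: tuple,           # one (num_tokens_i, hidden_size) tensor per mm item
        placeholder_token_ids: list,    # e.g. [image_token_index, video_token_index]
) -> torch.Tensor:
    # Positions holding either an image or a video placeholder token.
    is_placeholder = torch.isin(input_ids, torch.tensor(placeholder_token_ids))
    flat = torch.cat(mm_embeddings, dim=0)
    # Expect one embedding row per placeholder token, in prompt order.
    assert int(is_placeholder.sum()) == flat.shape[0]
    out = inputs_embeds.clone()
    out[is_placeholder] = flat
    return out

# Example with made-up token IDs (32000 = image, 32001 = video).
hidden = 4
input_ids = torch.tensor([1, 32000, 32000, 2, 32001, 32001])
text_embeds = torch.zeros(len(input_ids), hidden)
mm_embeddings = (torch.ones(2, hidden), 2 * torch.ones(2, hidden))
merged = merge_mm_embeddings_sketch(input_ids, text_embeds, mm_embeddings,
                                    [32000, 32001])
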
50 changes: 49 additions & 1 deletion vllm/multimodal/utils.py
@@ -1,6 +1,6 @@
from functools import lru_cache
from pathlib import Path
from typing import Optional, TypeVar, Union
from typing import TYPE_CHECKING, Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse

import numpy as np
@@ -25,6 +25,9 @@

_M = TypeVar("_M")

if TYPE_CHECKING:
from ..multimodal import MultiModalPlaceholderDict


class MediaConnector:

@@ -437,3 +440,48 @@ def consecutive_placeholder_ranges(
PlaceholderRange(offset=initial_offset + i * item_size,
length=item_size) for i in range(num_items)
]


def merge_and_sort_placeholders_from_modalities(
mm_positions: "MultiModalPlaceholderDict"
) -> tuple[list[str], list[PlaceholderRange]]:
"""Given a MultiModalPlaceholderDict, merge all PlaceholderRange
objects from all available modalities into a single list of
PlaceholderRange, sorted by their offset (starting index in the input
sequence) in ascending order.

Raises:
ValueError: If the input prompt has interleaved placeholders from
different modalities (e.g., "<image><audio><image> Describe the
content.")

Returns:
list[str]: Sorted list of involved modalities.
list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
mm_positions.
"""

modalities = list(mm_positions.keys())

# For single modality, its placeholder ranges are already sorted.
if len(modalities) == 1:
return modalities, list(mm_positions[modalities[0]])

placeholder_lists_with_modality = [(modality, mm_positions[modality])
for modality in modalities
if modality in mm_positions]

sorted_lists_with_modality = sorted(placeholder_lists_with_modality,
key=lambda x: x[1][0]['offset'])

# Verify that the sorted order has no interleaving across modalities
merged: list[PlaceholderRange] = []
for modality, placeholder_list in sorted_lists_with_modality:
if merged and placeholder_list[0]['offset'] < merged[-1]['offset']:
raise ValueError(
"Interleaved mixed-modality inference is currently not "
"supported.")
merged.extend(placeholder_list)

# Return the order of modalities and the merged placeholder ranges
return [modality for modality, _ in sorted_lists_with_modality], merged
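
For reference, here is a compact standalone sketch of the same sort-and-validate logic (not part of the diff), using plain dicts in place of PlaceholderRange and made-up offsets, so the expected behavior can be checked outside vLLM.

def merge_and_sort_sketch(mm_positions):
    # mm_positions: dict mapping modality -> list of {"offset", "length"} dicts.
    modalities = list(mm_positions)
    if len(modalities) == 1:
        return modalities, list(mm_positions[modalities[0]])

    # Sort whole per-modality lists by the offset of their first placeholder.
    pairs = sorted(mm_positions.items(), key=lambda kv: kv[1][0]["offset"])

    merged = []
    for modality, ranges in pairs:
        if merged and ranges[0]["offset"] < merged[-1]["offset"]:
            raise ValueError("Interleaved mixed-modality inference is "
                             "currently not supported.")
        merged.extend(ranges)
    return [modality for modality, _ in pairs], merged

mm_positions = {
    "video": [{"offset": 600, "length": 1024}],
    "image": [{"offset": 5, "length": 576}],
}
print(merge_and_sort_sketch(mm_positions))
# (['image', 'video'], [{'offset': 5, 'length': 576}, {'offset': 600, 'length': 1024}])
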
27 changes: 25 additions & 2 deletions vllm/v1/request.py
@@ -4,6 +4,7 @@
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.utils import merge_and_sort_placeholders_from_modalities
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
@@ -51,15 +52,37 @@ def __init__(
# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
# FIXME(woosuk): Support other modalities.
self.mm_positions = mm_positions.get("image", [])
sorted_modalities, sorted_mm_positions = merge_and_sort_placeholders_from_modalities( # noqa: E501
mm_positions)
self.mm_positions = sorted_mm_positions
else:
sorted_modalities = []
self.mm_positions = []

# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []
if self.inputs.multi_modal_inputs:
# NOTE: We only need to sort multimodal kwargs when there
# are multiple modalities involved.
if len(sorted_modalities) > 1:
modality_order_dict = {
modality: order
for order, modality in enumerate(sorted_modalities)
}

# Sanity check to make sure each multimodal input
# has only one modality key.
for mm_input in self.inputs.multi_modal_inputs:
assert len(mm_input.modalities) == 1

# Sort MultiModalKwargs to match sorted_mm_positions
self.inputs.multi_modal_inputs.sort(
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])

self.mm_inputs = self.inputs.multi_modal_inputs

assert len(self.mm_inputs) == len(self.mm_positions)
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes

# Cache the computed kv block hashes of the request to avoid
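
A standalone sketch of the kwargs reordering introduced above (not part of the diff): each multimodal input carries exactly one modality, and the list is re-sorted to match the modality order produced by the placeholder sort. The single-key dicts stand in for MultiModalKwargs and the data is illustrative.

sorted_modalities = ["image", "video"]          # from the placeholder sort
multi_modal_inputs = [
    {"video": "video tensor"},                  # arrived video-first
    {"image": "image tensor"},
]

modality_order = {m: i for i, m in enumerate(sorted_modalities)}

# Each multimodal input is expected to carry exactly one modality key.
assert all(len(mm) == 1 for mm in multi_modal_inputs)

multi_modal_inputs.sort(key=lambda mm: modality_order[next(iter(mm))])
print(multi_modal_inputs)
# [{'image': 'image tensor'}, {'video': 'video tensor'}]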