[Model] Refactor Ultravox to use merged input processor #11198

Merged · 27 commits (ultravox-refactor → vllm-project:main) · Dec 16, 2024

Commits (27)
4c3e5d5
refactor ultravox process
Isotr0py Dec 10, 2024
d160560
fix processor inputs
Isotr0py Dec 10, 2024
91384bf
fix ultravox processor
Isotr0py Dec 13, 2024
6d31c3d
Merge branch 'vllm-project:main' into ultravox-refactor
Isotr0py Dec 13, 2024
782bd61
fix placeholder padding
Isotr0py Dec 13, 2024
5350918
Merge branch 'vllm-project:main' into ultravox-refactor
Isotr0py Dec 14, 2024
57c7ec9
add comments
Isotr0py Dec 14, 2024
c1a9cef
update example
Isotr0py Dec 14, 2024
89416a8
code format
Isotr0py Dec 14, 2024
9693691
remove unused code
Isotr0py Dec 14, 2024
8254384
Merge branch 'main' into ultravox-refactor
Isotr0py Dec 15, 2024
e0ef4bc
Merge branch 'vllm-project:main' into ultravox-refactor
Isotr0py Dec 15, 2024
08a3422
clean up
Isotr0py Dec 15, 2024
d72fe45
refactor
Isotr0py Dec 15, 2024
0b8aa47
code format
Isotr0py Dec 15, 2024
d5b7cf7
fix prompt replacement
Isotr0py Dec 15, 2024
980c731
code format
Isotr0py Dec 15, 2024
5cb6362
fix audio_token truncation
Isotr0py Dec 15, 2024
0854a67
fix mm_data
Isotr0py Dec 15, 2024
146fc63
fix audio_token_len and online inference
Isotr0py Dec 15, 2024
342048c
rename
Isotr0py Dec 15, 2024
daba237
Update vllm/model_executor/models/ultravox.py
Isotr0py Dec 15, 2024
7813d47
clean up
Isotr0py Dec 15, 2024
6e7b138
handle no audio data
Isotr0py Dec 15, 2024
ca58f8b
Update vllm/model_executor/models/ultravox.py
Isotr0py Dec 15, 2024
e1fdd36
cleanup replacement
Isotr0py Dec 15, 2024
8ef0b23
fix audio entrypoint and pp test
Isotr0py Dec 16, 2024
10 changes: 5 additions & 5 deletions examples/offline_inference_audio_language.py
@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):

     tokenizer = AutoTokenizer.from_pretrained(model_name)
     messages = [{
-        'role':
-        'user',
-        'content':
-        "<|reserved_special_token_0|>\n" * audio_count + question
+        'role': 'user',
+        'content': "<|audio|>\n" * audio_count + question
     }]
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)

-    llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
+    llm = LLM(model=model_name,
+              trust_remote_code=True,
+              limit_mm_per_prompt={"audio": audio_count})
     stop_token_ids = None
     return llm, prompt, stop_token_ids
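For reference, a minimal end-to-end sketch of the updated example flow, assuming the fixie-ai/ultravox-v0_3 checkpoint, a single audio clip, and an illustrative question (it mirrors the diff above rather than adding anything new):

from transformers import AutoTokenizer
from vllm import LLM

model_name = "fixie-ai/ultravox-v0_3"
audio_count = 1

tokenizer = AutoTokenizer.from_pretrained(model_name)
# One "<|audio|>" placeholder per audio clip, followed by the question.
messages = [{
    'role': 'user',
    'content': "<|audio|>\n" * audio_count + "What is in this recording?"
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

# trust_remote_code is now required so the Ultravox HF processor can be loaded.
llm = LLM(model=model_name,
          trust_remote_code=True,
          limit_mm_per_prompt={"audio": audio_count})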
2 changes: 1 addition & 1 deletion tests/distributed/test_pipeline_parallel.py
@@ -214,7 +214,7 @@ def iter_params(self, model_name: str):
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
     "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
-    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
+    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
     # [Encoder-decoder]
     # TODO: Implement PP
     # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_audio.py
@@ -25,6 +25,7 @@ def server():
         "--max-num-seqs",
         "5",
         "--enforce-eager",
+        "--trust-remote-code",
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
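For local reproduction, this flag list is exactly what the test fixture hands to the server; a sketch under the assumption that it runs inside the vLLM test suite (RemoteOpenAIServer and MODEL_NAME come from the surrounding test module, and get_async_client is assumed to be its usual client helper):

args = [
    "--max-num-seqs",
    "5",
    "--enforce-eager",
    "--trust-remote-code",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
    client = remote_server.get_async_client()
    # ... issue chat completion requests with audio inputs ...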
5 changes: 3 additions & 2 deletions tests/models/decoder_only/audio_language/test_ultravox.py
@@ -16,7 +16,7 @@

 AudioTuple = Tuple[np.ndarray, int]

-VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
+VLLM_PLACEHOLDER = "<|audio|>"
 HF_PLACEHOLDER = "<|audio|>"

 CHUNKED_PREFILL_KWARGS = {
@@ -46,7 +46,8 @@ def audio(request):
 def server(request, audio_assets):
     args = [
         "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
-        f"--limit-mm-per-prompt=audio={len(audio_assets)}"
+        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
+        "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()
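Since VLLM_PLACEHOLDER and HF_PLACEHOLDER are now the same string, the vLLM and HF prompts used for output comparison can be produced by one helper; a standalone sketch (the build_prompt helper is illustrative, not part of this diff):

VLLM_PLACEHOLDER = "<|audio|>"
HF_PLACEHOLDER = "<|audio|>"

def build_prompt(placeholder: str, audio_count: int, question: str) -> str:
    # One placeholder line per audio clip, followed by the question.
    return f"{placeholder}\n" * audio_count + question

assert (build_prompt(VLLM_PLACEHOLDER, 2, "Describe each clip.") ==
        build_prompt(HF_PLACEHOLDER, 2, "Describe each clip."))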
2 changes: 1 addition & 1 deletion vllm/entrypoints/chat_utils.py
@@ -418,7 +418,7 @@ def _placeholder_str(self, modality: ModalityStr,
             raise TypeError(f"Unknown {modality} model type: {model_type}")
         elif modality == "audio":
             if model_type == "ultravox":
-                return "<|reserved_special_token_0|>"
+                return "<|audio|>"
             if model_type == "qwen2_audio":
                 return (f"Audio {current_count}: "
                         f"<|audio_bos|><|AUDIO|><|audio_eos|>")
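Viewed on its own, the audio branch of _placeholder_str now behaves as below; this is a condensed, standalone sketch for illustration (the real method also handles the image and video modalities):

def audio_placeholder(model_type: str, current_count: int) -> str:
    # Ultravox now uses the literal "<|audio|>" string; the merged processor
    # in ultravox.py later expands it to the per-clip audio token count.
    if model_type == "ultravox":
        return "<|audio|>"
    if model_type == "qwen2_audio":
        return (f"Audio {current_count}: "
                f"<|audio_bos|><|AUDIO|><|audio_eos|>")
    raise TypeError(f"Unknown audio model type: {model_type}")

print(audio_placeholder("ultravox", 1))  # <|audio|>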
244 changes: 104 additions & 140 deletions vllm/model_executor/models/ultravox.py
@@ -3,41 +3,39 @@

 import math
 from functools import cached_property, lru_cache
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union, cast)
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
+                    Tuple, TypedDict, Union)

 import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
+from transformers import BatchFeature
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder

 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
-                             NestedTensors)
-from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalDataDict,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
+from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
-from vllm.utils import is_list_of

 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings_from_map)

-_AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25

@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
     return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)


-def dummy_seq_data_for_ultravox(
-    ctx: InputContext,
-    seq_len: int,
-    audio_count: int,
-):
-    audio_length = min(get_ultravox_max_audio_tokens(ctx),
-                       seq_len // audio_count)
+class UltravoxMultiModalProcessor(BaseMultiModalProcessor):

-    return SequenceData.from_prompt_token_counts(
-        (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
-        (0, seq_len - audio_length * audio_count)), {
-            "audio":
-            consecutive_placeholder_ranges(num_items=audio_count,
-                                           item_size=audio_length)
-        }
-
-
-def dummy_audio_for_ultravox(
-    ctx: InputContext,
-    audio_count: int,
-):
-    feature_extractor = whisper_feature_extractor(ctx)
-    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-    return {"audio": [audio_and_sr] * audio_count}
-
-
-def dummy_data_for_ultravox(
-    ctx: InputContext,
-    seq_len: int,
-    mm_counts: Mapping[str, int],
-):
-    audio_count = mm_counts["audio"]
-    seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
-    mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
-
-    return DummyData(seq_data, mm_dict, ranges)
-
-
-def input_mapper_for_ultravox(ctx: InputContext, data: object):
-    if not isinstance(data, list):
-        data = [data]
-
-    if len(data) == 0:
-        return MultiModalKwargs()
-
-    # If the audio inputs are embeddings, no need for preprocessing
-    if is_list_of(data, torch.Tensor, check="all"):
-        return MultiModalKwargs({"audio_embeds": data})
-
-    audio_features = []
-    for audio_input in data:
-        if not isinstance(audio_input, tuple):
-            raise NotImplementedError(
-                f"Unsupported data type: {type(audio_input)}")
-
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
-        feature_extractor = whisper_feature_extractor(ctx)
+    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
+        return self._get_hf_processor().audio_processor.feature_extractor
+
+    def _resample_audio(
+        self,
+        audio: np.ndarray,
+        sr: int,
+    ) -> Dict[str, Union[np.ndarray, int]]:
+        # resample audio to the model's sampling rate
+        feature_extractor = self._get_feature_extractor()
         if sr != feature_extractor.sampling_rate:
             try:
                 import librosa
@@ -140,78 +92,92 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
                                      orig_sr=sr,
                                      target_sr=feature_extractor.sampling_rate)
             sr = feature_extractor.sampling_rate
+        return {"audio": audio, "sampling_rate": sr}

-        minimum_audio_length = feature_extractor.n_fft // 2 + 1
-        if len(audio) < minimum_audio_length:
-            # Not enough audio; pad it.
-            audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
-
-        single_audio_features = feature_extractor(
-            audio, sampling_rate=sr, padding="longest",
-            return_tensors="pt")["input_features"]
-
-        # Remove the batch dimension because we're wrapping it in a list.
-        audio_features.append(single_audio_features.squeeze(0))
-
-    return MultiModalKwargs({"audio_features": audio_features})
-
-
-def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "audio" not in multi_modal_data:
-        return inputs
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data or not mm_data.get("audio", None):
+            return super()._apply_hf_processor(prompt, mm_data,
+                                               mm_processor_kwargs)
+
+        audio_data = mm_data["audio"]
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        # Ultravox processor doesn't support multiple inputs,
+        # therefore we need to input text and audio one by one
+        tokenizer = self._get_tokenizer()
+        audio_features, audio_token_len = [], []
+        processed_inputs = {}
+        for audio, sr in audio_data:
+            data = self._resample_audio(audio, sr)
+            processed_inputs = super()._apply_hf_processor(
+                prompt, data, mm_processor_kwargs)
+            prompt = tokenizer.decode(processed_inputs["input_ids"][0],
+                                      skip_special_tokens=False)
+            audio_features.append(
+                processed_inputs.pop("audio_values").squeeze(0))
+            audio_token_len.append(
+                processed_inputs.pop("audio_token_len").item())
+
+        return dict(
+            **processed_inputs,
+            audio_features=audio_features,
+            audio_token_len=audio_token_len,
+        )

-    if "multi_modal_placeholders" in inputs and "audio" in inputs[
-            "multi_modal_placeholders"]:
-        # The inputs already have placeholders.
-        return inputs
+    def _get_processor_data(
+        self,
+        mm_data: MultiModalDataDict,
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        # Ultravox uses "audio" instead of "audios" as calling keyword
+        processor_data, passthrough_data = super()._get_processor_data(mm_data)
+        if "audios" in processor_data:
+            processor_data["audio"] = processor_data.pop("audios")
+        return processor_data, passthrough_data
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_inputs: BatchFeature,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> list[PromptReplacement]:
+        hf_processor = self._get_hf_processor()
+        placeholder = hf_processor.audio_token_replacement
+
+        def get_replacement_ultravox(item_idx: int):
+            audio_token_len = hf_inputs["audio_token_len"][item_idx]
+            return placeholder * audio_token_len
+
+        return [
+            PromptReplacement(
+                modality="audio",
+                target="<|audio|>",
+                replacement=get_replacement_ultravox,
+            )
+        ]

-    feature_extractor = whisper_feature_extractor(ctx)
-    audios = multi_modal_data["audio"]
-    if not isinstance(audios, list):
-        audios = [audios]
-
-    audio_token_counts = []
-    for audio in audios:
-        if isinstance(audio, torch.Tensor):
-            audio_num_tokens = audio.shape[1]
-            audio_token_counts.append(audio_num_tokens)
-        else:
-            audio_data, sample_rate = audio
-            audio_length = audio_data.shape[0]
-            if sample_rate != feature_extractor.sampling_rate:
-                # Account for resampling.
-                adjustment = feature_extractor.sampling_rate / sample_rate
-                audio_length = math.ceil(adjustment * audio_length)
-
-            feature_extractor_output_length = math.ceil(
-                (audio_length - (feature_extractor.hop_length - 1)) /
-                feature_extractor.hop_length)
-
-            uv_config = ctx.get_hf_config(UltravoxConfig)
-            audio_num_tokens = min(
-                max(
-                    1,
-                    math.ceil(feature_extractor_output_length /
-                              (uv_config.stack_factor * 2))),
-                get_ultravox_max_audio_tokens(ctx))
-            audio_token_counts.append(audio_num_tokens)
-
-    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
-
-    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-        tokenizer,
-        inputs.get("prompt"),
-        inputs["prompt_token_ids"],
-        placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
-        repeat_count=audio_token_counts,
-    )
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"audio": ranges})
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        feature_extractor = self._get_feature_extractor()
+        sampling_rate = feature_extractor.sampling_rate
+        audio_len = feature_extractor.chunk_length * sampling_rate
+
+        audio_count = mm_counts["audio"]
+        audio = np.zeros(audio_len)
+        data = {"audio": [(audio, sampling_rate)] * audio_count}
+
+        return ProcessorInputs(
+            prompt_text="<|audio|>" * audio_count,
+            mm_data=data,
+            mm_processor_kwargs={},
+        )


 class StackAudioFrames(nn.Module):
@@ -332,11 +298,9 @@ def forward(
         return hidden_states


-@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "audio", get_ultravox_max_audio_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
+@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
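To make the new flow concrete, here is a self-contained sketch (illustrative only, not vLLM code) of what the merged processor's PromptReplacement accomplishes: each "<|audio|>" target in the prompt is swapped for audio_token_len[i] copies of the HF processor's audio replacement token, one entry per audio clip:

def expand_audio_placeholders(prompt: str, replacement_token: str,
                              audio_token_lens: list) -> str:
    # Segment i of the split prompt is preceded by the i-th clip's expanded
    # placeholder run, sized by the HF processor's audio_token_len output.
    parts = prompt.split("<|audio|>")
    assert len(parts) == len(audio_token_lens) + 1, "one target per clip"
    out = parts[0]
    for part, token_len in zip(parts[1:], audio_token_lens):
        out += replacement_token * token_len + part
    return out

# Two clips producing 5 and 3 audio tokens:
# expand_audio_placeholders("<|audio|>\n<|audio|>\nDescribe both.", "<A>", [5, 3])
# -> "<A><A><A><A><A>\n<A><A><A>\nDescribe both."

In the real implementation the registry applies this replacement to token IDs as well as text. Note also that _get_dummy_mm_inputs profiles memory with chunk_length seconds of silence per clip (chunk_length * sampling_rate zero samples), which should exercise the replacement near its maximum per-clip length during profiling.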