From 4c3e5d59a9cf1572bc78e64bf0ebc5cb0494c3f4 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 10 Dec 2024 15:19:38 +0800 Subject: [PATCH 01/23] refactor ultravox process Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 170 ++++++------------------- 1 file changed, 37 insertions(+), 133 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index ea1e5401d42c0..9681d467b4b92 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -4,7 +4,7 @@ import math from functools import cached_property, lru_cache from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union, cast) + TypedDict, Union) import numpy as np import torch @@ -16,8 +16,7 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -25,12 +24,13 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, NestedTensors) -from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) -from vllm.sequence import IntermediateTensors, SequenceData +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, + ModalityProcessingMetadata, + MultiModalProcessingMetadata, + PromptReplacement) +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig -from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, @@ -72,23 +72,6 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) -def dummy_seq_data_for_ultravox( - ctx: InputContext, - seq_len: int, - audio_count: int, -): - audio_length = min(get_ultravox_max_audio_tokens(ctx), - seq_len // audio_count) - - return SequenceData.from_prompt_token_counts( - (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count), - (0, seq_len - audio_length * audio_count)), { - "audio": - consecutive_placeholder_ranges(num_items=audio_count, - item_size=audio_length) - } - - def dummy_audio_for_ultravox( ctx: InputContext, audio_count: int, @@ -98,120 +81,43 @@ def dummy_audio_for_ultravox( return {"audio": [audio_and_sr] * audio_count} -def dummy_data_for_ultravox( - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], -): - audio_count = mm_counts["audio"] - seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) - mm_dict = dummy_audio_for_ultravox(ctx, audio_count) - - return DummyData(seq_data, mm_dict, ranges) - +def dummy_mm_kwargs_for_ultravox(ctx: InputProcessingContext, + mm_counts: Mapping[str, int]): -def input_mapper_for_ultravox(ctx: InputContext, data: object): - if not isinstance(data, list): - data = [data] + data = dummy_audio_for_ultravox(ctx=ctx, audio_count=mm_counts["audio"]) - if len(data) == 0: - return MultiModalKwargs() + hf_processor = ctx.get_hf_processor() + audio_processor = hf_processor.audio_processor # type: ignore + hf_inputs = audio_processor(audio=data['audio'], return_tensors="pt") - # If the audio inputs are embeddings, no need for preprocessing - if is_list_of(data, torch.Tensor, check="all"): - return MultiModalKwargs({"audio_embeds": data}) + return MultiModalKwargs(**hf_inputs) - audio_features = [] - for audio_input in data: - if not isinstance(audio_input, tuple): - raise NotImplementedError( - f"Unsupported data type: {type(audio_input)}") - (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input) - feature_extractor = whisper_feature_extractor(ctx) +def create_metadata_for_ultravox( + ctx: InputProcessingContext) -> MultiModalProcessingMetadata: + return { + "audio": + ModalityProcessingMetadata(prompt_repls=[ + PromptReplacement(target=[_AUDIO_PLACEHOLDER_TOKEN], + repl_unit=[_AUDIO_PLACEHOLDER_TOKEN], + repl_count=get_ultravox_max_audio_tokens(ctx)), + ]), + } - if sr != feature_extractor.sampling_rate: - try: - import librosa - except ImportError as exc: - raise ImportError( - "Please install vllm[audio] for audio support.") from exc - audio = librosa.resample(audio, - orig_sr=sr, - target_sr=feature_extractor.sampling_rate) - sr = feature_extractor.sampling_rate - minimum_audio_length = feature_extractor.n_fft // 2 + 1 - if len(audio) < minimum_audio_length: - # Not enough audio; pad it. - audio = np.pad(audio, (0, minimum_audio_length - len(audio))) +class UltravoxProcessor(BaseMultiModalProcessor): - single_audio_features = feature_extractor( - audio, sampling_rate=sr, padding="longest", - return_tensors="pt")["input_features"] - - # Remove the batch dimension because we're wrapping it in a list. - audio_features.append(single_audio_features.squeeze(0)) - - return MultiModalKwargs({"audio_features": audio_features}) - - -def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "audio" not in multi_modal_data: - return inputs - - if "multi_modal_placeholders" in inputs and "audio" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. - return inputs + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__( + ctx=ctx, + metadata=create_metadata_for_ultravox(ctx), + ) - feature_extractor = whisper_feature_extractor(ctx) - audios = multi_modal_data["audio"] - if not isinstance(audios, list): - audios = [audios] - - audio_token_counts = [] - for audio in audios: - if isinstance(audio, torch.Tensor): - audio_num_tokens = audio.shape[1] - audio_token_counts.append(audio_num_tokens) - else: - audio_data, sample_rate = audio - audio_length = audio_data.shape[0] - if sample_rate != feature_extractor.sampling_rate: - # Account for resampling. - adjustment = feature_extractor.sampling_rate / sample_rate - audio_length = math.ceil(adjustment * audio_length) - - feature_extractor_output_length = math.ceil( - (audio_length - (feature_extractor.hop_length - 1)) / - feature_extractor.hop_length) - - uv_config = ctx.get_hf_config(UltravoxConfig) - audio_num_tokens = min( - max( - 1, - math.ceil(feature_extractor_output_length / - (uv_config.stack_factor * 2))), - get_ultravox_max_audio_tokens(ctx)) - audio_token_counts.append(audio_num_tokens) - - tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) - - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, - repeat_count=audio_token_counts, - ) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"audio": ranges}) + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + return dummy_mm_kwargs_for_ultravox(self.ctx, mm_counts) class StackAudioFrames(nn.Module): @@ -332,11 +238,9 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "audio", get_ultravox_max_audio_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) -@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) +@MULTIMODAL_REGISTRY.register_processor(UltravoxProcessor) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): From d16056018133c4d493784ff088d5790d987aa0bd Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 10 Dec 2024 17:24:10 +0800 Subject: [PATCH 02/23] fix processor inputs Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 12 ++++++--- vllm/multimodal/processing.py | 35 ++++++++++++++++---------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 9681d467b4b92..cf2b95671c352 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -77,8 +77,8 @@ def dummy_audio_for_ultravox( audio_count: int, ): feature_extractor = whisper_feature_extractor(ctx) - audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) - return {"audio": [audio_and_sr] * audio_count} + audio = np.array([0.0] * feature_extractor.chunk_length) + return {"audio": [audio] * audio_count} def dummy_mm_kwargs_for_ultravox(ctx: InputProcessingContext, @@ -88,7 +88,7 @@ def dummy_mm_kwargs_for_ultravox(ctx: InputProcessingContext, hf_processor = ctx.get_hf_processor() audio_processor = hf_processor.audio_processor # type: ignore - hf_inputs = audio_processor(audio=data['audio'], return_tensors="pt") + hf_inputs = audio_processor(data['audio'], return_tensors="pt") return MultiModalKwargs(**hf_inputs) @@ -113,6 +113,12 @@ def __init__(self, ctx: InputProcessingContext) -> None: metadata=create_metadata_for_ultravox(ctx), ) + def _get_processor_data(self, mm_data): + processor_data, passthrough_data = super()._get_processor_data(mm_data) + if "audios" in processor_data: + processor_data["audio"] = processor_data.pop("audios") + return processor_data, passthrough_data + def _get_dummy_mm_kwargs( self, mm_counts: Mapping[str, int], diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 922c83b6fd8a9..31b2437afae45 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,7 +3,7 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass from functools import lru_cache -from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol, +from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol, Tuple, TypeVar, Union, cast) import torch @@ -583,20 +583,10 @@ def _find_placeholders( min_unit_count=min_unit_count, )) - def _apply_hf_processor( + def _get_processor_data( self, - prompt: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], - ) -> BatchFeature: - # some mm_processor_kwargs may be used in processor initialization - # instead of processor call - processor_init_kwargs = { - **self.init_mm_processor_kwargs, - **mm_processor_kwargs, - } - hf_processor = self._get_hf_processor(**processor_init_kwargs) - + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() for k, v in mm_data.items(): @@ -614,6 +604,25 @@ def _apply_hf_processor( processor_data[f"{k}s"] = v else: processor_data[k] = v + return processor_data, passthrough_data + + def _apply_hf_processor( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> BatchFeature: + # some mm_processor_kwargs may be used in processor initialization + # instead of processor call + processor_init_kwargs = { + **self.init_mm_processor_kwargs, + **mm_processor_kwargs, + } + hf_processor = self._get_hf_processor(**processor_init_kwargs) + + processor_data = dict[str, Any]() + passthrough_data = dict[str, Any]() + processor_data, passthrough_data = self._get_processor_data(mm_data) # filter mm_processor_kwargs used in processor call mm_processor_kwargs = resolve_mm_processor_kwargs( From 91384bf0a82a8ed93a0a8212defdec9367e59379 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 13 Dec 2024 23:26:33 +0800 Subject: [PATCH 03/23] fix ultravox processor Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index cf2b95671c352..140c8360fbcdc 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,6 +2,7 @@ """PyTorch Ultravox model.""" import math +from collections import defaultdict from functools import cached_property, lru_cache from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -113,6 +114,20 @@ def __init__(self, ctx: InputProcessingContext) -> None: metadata=create_metadata_for_ultravox(ctx), ) + def _apply_hf_processor(self, prompt, mm_data, mm_processor_kwargs): + tokenizer = self._get_tokenizer() + hf_inputs = defaultdict(list) + for audio, sr in mm_data["audio"]: + data = {"audio": audio, "sampling_rate": sr} + processed_inputs = super()._apply_hf_processor( + prompt, data, mm_processor_kwargs) + prompt = tokenizer.decode(processed_inputs["input_ids"][0], + skip_special_tokens=True) + hf_inputs["audio_features"].append( + processed_inputs["audio_values"].squeeze(0)) + hf_inputs["input_ids"] = processed_inputs["input_ids"] + return hf_inputs + def _get_processor_data(self, mm_data): processor_data, passthrough_data = super()._get_processor_data(mm_data) if "audios" in processor_data: From 782bd61a79da8e7f01092ba0511ebb10043d5908 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 14 Dec 2024 02:00:57 +0800 Subject: [PATCH 04/23] fix placeholder padding Signed-off-by: Isotr0py <2037008807@qq.com> --- .../audio_language/test_ultravox.py | 2 +- vllm/entrypoints/chat_utils.py | 2 +- vllm/model_executor/models/ultravox.py | 19 ++++++++++++++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index e100c6b9bb906..519af8efe3143 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -16,7 +16,7 @@ AudioTuple = Tuple[np.ndarray, int] -VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" +VLLM_PLACEHOLDER = "<|audio|>" HF_PLACEHOLDER = "<|audio|>" CHUNKED_PREFILL_KWARGS = { diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c2054dcbfce0e..aaa5cd759366a 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -418,7 +418,7 @@ def _placeholder_str(self, modality: ModalityStr, raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": if model_type == "ultravox": - return "<|reserved_special_token_0|>" + return "<|audio|>" if model_type == "qwen2_audio": return (f"Audio {current_count}: " f"<|audio_bos|><|AUDIO|><|audio_eos|>") diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 140c8360fbcdc..14d00f24739e5 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -39,6 +39,7 @@ merge_multimodal_embeddings_from_map) _AUDIO_PLACEHOLDER_TOKEN = 128002 +_AUDIO_PLACEHOLDER_STR = "<|reserved_special_token_0|>" _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -116,18 +117,34 @@ def __init__(self, ctx: InputProcessingContext) -> None: def _apply_hf_processor(self, prompt, mm_data, mm_processor_kwargs): tokenizer = self._get_tokenizer() + feature_extractor = whisper_feature_extractor(self.ctx) hf_inputs = defaultdict(list) for audio, sr in mm_data["audio"]: + if sr != feature_extractor.sampling_rate: + try: + import librosa + except ImportError as exc: + raise ImportError( + "Please install vllm[audio] for audio support.") from exc + audio = librosa.resample(audio, + orig_sr=sr, + target_sr=feature_extractor.sampling_rate) + sr = feature_extractor.sampling_rate data = {"audio": audio, "sampling_rate": sr} processed_inputs = super()._apply_hf_processor( prompt, data, mm_processor_kwargs) prompt = tokenizer.decode(processed_inputs["input_ids"][0], - skip_special_tokens=True) + skip_special_tokens=False) hf_inputs["audio_features"].append( processed_inputs["audio_values"].squeeze(0)) hf_inputs["input_ids"] = processed_inputs["input_ids"] return hf_inputs + def _get_hf_processor(self, **mm_processor_kwargs): + hf_processor = super()._get_hf_processor(**mm_processor_kwargs) + hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_STR + return hf_processor + def _get_processor_data(self, mm_data): processor_data, passthrough_data = super()._get_processor_data(mm_data) if "audios" in processor_data: From 57c7ec99a81c0b51e8f7beb9a512d377ff7b3271 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 14 Dec 2024 13:21:29 +0800 Subject: [PATCH 05/23] add comments Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 33 ++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 14d00f24739e5..d8999cd0fb724 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -115,22 +115,28 @@ def __init__(self, ctx: InputProcessingContext) -> None: metadata=create_metadata_for_ultravox(ctx), ) + def _resample_audio(self, audio, sr): + # resample audio to the model's sampling rate + feature_extractor = whisper_feature_extractor(self.ctx) + if sr != feature_extractor.sampling_rate: + try: + import librosa + except ImportError as exc: + raise ImportError( + "Please install vllm[audio] for audio support.") from exc + audio = librosa.resample(audio, + orig_sr=sr, + target_sr=feature_extractor.sampling_rate) + sr = feature_extractor.sampling_rate + return {"audio": audio, "sampling_rate": sr} + def _apply_hf_processor(self, prompt, mm_data, mm_processor_kwargs): + # Ultravox processor doesn't support multiple inputs, + # therefore we need to input text and audio one by one tokenizer = self._get_tokenizer() - feature_extractor = whisper_feature_extractor(self.ctx) hf_inputs = defaultdict(list) for audio, sr in mm_data["audio"]: - if sr != feature_extractor.sampling_rate: - try: - import librosa - except ImportError as exc: - raise ImportError( - "Please install vllm[audio] for audio support.") from exc - audio = librosa.resample(audio, - orig_sr=sr, - target_sr=feature_extractor.sampling_rate) - sr = feature_extractor.sampling_rate - data = {"audio": audio, "sampling_rate": sr} + data = self._resample_audio(audio, sr) processed_inputs = super()._apply_hf_processor( prompt, data, mm_processor_kwargs) prompt = tokenizer.decode(processed_inputs["input_ids"][0], @@ -141,11 +147,14 @@ def _apply_hf_processor(self, prompt, mm_data, mm_processor_kwargs): return hf_inputs def _get_hf_processor(self, **mm_processor_kwargs): + # Ultravox processor use eot_token_id as the audio placeholder token, + # we replace it with <|reserved_special_token_0|> for convenience. hf_processor = super()._get_hf_processor(**mm_processor_kwargs) hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_STR return hf_processor def _get_processor_data(self, mm_data): + # Ultravox uses "audio" instead of "audios" as calling keyword processor_data, passthrough_data = super()._get_processor_data(mm_data) if "audios" in processor_data: processor_data["audio"] = processor_data.pop("audios") From c1a9cefd5a2b91043ccf8829e6c0d20bfb039bb9 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 14 Dec 2024 14:29:01 +0800 Subject: [PATCH 06/23] update example Signed-off-by: Isotr0py <2037008807@qq.com> --- examples/offline_inference_audio_language.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 050b791b62adb..a47650d575fb7 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -25,10 +25,8 @@ def run_ultravox(question: str, audio_count: int): tokenizer = AutoTokenizer.from_pretrained(model_name) messages = [{ - 'role': - 'user', - 'content': - "<|reserved_special_token_0|>\n" * audio_count + question + 'role': 'user', + 'content': "<|audio|>\n" * audio_count + question }] prompt = tokenizer.apply_chat_template(messages, tokenize=False, From 89416a875712b0ae9604a11a15f9e5d95f8c28cf Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 14 Dec 2024 14:55:06 +0800 Subject: [PATCH 07/23] code format Signed-off-by: Isotr0py <2037008807@qq.com> --- examples/offline_inference_audio_language.py | 4 +- vllm/model_executor/models/ultravox.py | 44 +++++++++++++++----- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index a47650d575fb7..68b786961b14a 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -32,7 +32,9 @@ def run_ultravox(question: str, audio_count: int): tokenize=False, add_generation_prompt=True) - llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count}) + llm = LLM(model=model_name, + trust_remote_code=True, + limit_mm_per_prompt={"audio": audio_count}) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index d8999cd0fb724..ec63fad9bbd0a 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -4,14 +4,15 @@ import math from collections import defaultdict from functools import cached_property, lru_cache -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, + Tuple, TypedDict, Union) import numpy as np import torch import torch.utils.checkpoint from torch import nn from torch.nn import functional as F +from transformers import BatchFeature, ProcessorMixin from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper.modeling_whisper import WhisperEncoder @@ -23,8 +24,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, - NestedTensors) +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalData, + MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, InputProcessingContext, ModalityProcessingMetadata, @@ -78,8 +79,10 @@ def dummy_audio_for_ultravox( ctx: InputContext, audio_count: int, ): + # dummy audio is set to default 16kHz sampling rate feature_extractor = whisper_feature_extractor(ctx) - audio = np.array([0.0] * feature_extractor.chunk_length) + audio_len = feature_extractor.chunk_length * feature_extractor.sampling_rate + audio = np.array([0.0] * audio_len) return {"audio": [audio] * audio_count} @@ -89,8 +92,12 @@ def dummy_mm_kwargs_for_ultravox(ctx: InputProcessingContext, data = dummy_audio_for_ultravox(ctx=ctx, audio_count=mm_counts["audio"]) hf_processor = ctx.get_hf_processor() - audio_processor = hf_processor.audio_processor # type: ignore - hf_inputs = audio_processor(data['audio'], return_tensors="pt") + audio_processor = hf_processor.audio_processor + hf_inputs = audio_processor( + audio=data['audio'], + sampling_rate=16000, + return_tensors="pt", + ) return MultiModalKwargs(**hf_inputs) @@ -115,7 +122,11 @@ def __init__(self, ctx: InputProcessingContext) -> None: metadata=create_metadata_for_ultravox(ctx), ) - def _resample_audio(self, audio, sr): + def _resample_audio( + self, + audio: np.ndarray, + sr: int, + ) -> Dict[str, Union[np.ndarray, int]]: # resample audio to the model's sampling rate feature_extractor = whisper_feature_extractor(self.ctx) if sr != feature_extractor.sampling_rate: @@ -130,7 +141,12 @@ def _resample_audio(self, audio, sr): sr = feature_extractor.sampling_rate return {"audio": audio, "sampling_rate": sr} - def _apply_hf_processor(self, prompt, mm_data, mm_processor_kwargs): + def _apply_hf_processor( + self, + prompt: str, + mm_data: MultiModalData, + mm_processor_kwargs: Mapping[str, object], + ) -> BatchFeature: # Ultravox processor doesn't support multiple inputs, # therefore we need to input text and audio one by one tokenizer = self._get_tokenizer() @@ -146,14 +162,20 @@ def _apply_hf_processor(self, prompt, mm_data, mm_processor_kwargs): hf_inputs["input_ids"] = processed_inputs["input_ids"] return hf_inputs - def _get_hf_processor(self, **mm_processor_kwargs): + def _get_hf_processor( + self, + **mm_processor_kwargs: Mapping[str, object], + ) -> ProcessorMixin: # Ultravox processor use eot_token_id as the audio placeholder token, # we replace it with <|reserved_special_token_0|> for convenience. hf_processor = super()._get_hf_processor(**mm_processor_kwargs) hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_STR return hf_processor - def _get_processor_data(self, mm_data): + def _get_processor_data( + self, + mm_data: MultiModalData, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: # Ultravox uses "audio" instead of "audios" as calling keyword processor_data, passthrough_data = super()._get_processor_data(mm_data) if "audios" in processor_data: From 9693691b0d42f81a04d6f62a78c9eba4acc5f21f Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 14 Dec 2024 14:59:42 +0800 Subject: [PATCH 08/23] remove unused code Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/multimodal/processing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 31b2437afae45..5486ae91c8121 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -620,8 +620,6 @@ def _apply_hf_processor( } hf_processor = self._get_hf_processor(**processor_init_kwargs) - processor_data = dict[str, Any]() - passthrough_data = dict[str, Any]() processor_data, passthrough_data = self._get_processor_data(mm_data) # filter mm_processor_kwargs used in processor call From 08a34226eb1d4bbe5682ff94d9ac5c63d72b7669 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 14:52:48 +0800 Subject: [PATCH 09/23] clean up Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/multimodal/processing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index f2d00ec3f0cfb..b2cdc8f210ac0 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -597,10 +597,7 @@ def _find_placeholders( def _get_processor_data( self, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - hf_processor = self._get_hf_processor(**mm_processor_kwargs) - processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() for k, v in mm_data.items(): From d72fe45f63b519d1c9c00d0f99a6b777a5ae5b19 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 15:23:50 +0800 Subject: [PATCH 10/23] refactor Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 81 ++++++++++---------------- vllm/multimodal/processing.py | 6 +- 2 files changed, 32 insertions(+), 55 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index ec63fad9bbd0a..a60bf993294e8 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -28,8 +28,7 @@ MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, InputProcessingContext, - ModalityProcessingMetadata, - MultiModalProcessingMetadata, + MultiModalDataItems, ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -75,53 +74,8 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) -def dummy_audio_for_ultravox( - ctx: InputContext, - audio_count: int, -): - # dummy audio is set to default 16kHz sampling rate - feature_extractor = whisper_feature_extractor(ctx) - audio_len = feature_extractor.chunk_length * feature_extractor.sampling_rate - audio = np.array([0.0] * audio_len) - return {"audio": [audio] * audio_count} - - -def dummy_mm_kwargs_for_ultravox(ctx: InputProcessingContext, - mm_counts: Mapping[str, int]): - - data = dummy_audio_for_ultravox(ctx=ctx, audio_count=mm_counts["audio"]) - - hf_processor = ctx.get_hf_processor() - audio_processor = hf_processor.audio_processor - hf_inputs = audio_processor( - audio=data['audio'], - sampling_rate=16000, - return_tensors="pt", - ) - - return MultiModalKwargs(**hf_inputs) - - -def create_metadata_for_ultravox( - ctx: InputProcessingContext) -> MultiModalProcessingMetadata: - return { - "audio": - ModalityProcessingMetadata(prompt_repls=[ - PromptReplacement(target=[_AUDIO_PLACEHOLDER_TOKEN], - repl_unit=[_AUDIO_PLACEHOLDER_TOKEN], - repl_count=get_ultravox_max_audio_tokens(ctx)), - ]), - } - - class UltravoxProcessor(BaseMultiModalProcessor): - def __init__(self, ctx: InputProcessingContext) -> None: - super().__init__( - ctx=ctx, - metadata=create_metadata_for_ultravox(ctx), - ) - def _resample_audio( self, audio: np.ndarray, @@ -182,11 +136,38 @@ def _get_processor_data( processor_data["audio"] = processor_data.pop("audios") return processor_data, passthrough_data - def _get_dummy_mm_kwargs( + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_inputs: BatchFeature, + mm_processor_kwargs: Mapping[str, object], + ) -> list[PromptReplacement]: + max_audio_tokens = get_ultravox_max_audio_tokens(self.ctx) + return [ + PromptReplacement( + modality="audio", + target="<|audio|>", + replacement=[_AUDIO_PLACEHOLDER_TOKEN] * max_audio_tokens, + ) + ] + + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], - ) -> MultiModalKwargs: - return dummy_mm_kwargs_for_ultravox(self.ctx, mm_counts) + ) -> ProcessorInputs: + feature_extractor = whisper_feature_extractor(self.ctx) + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + + audio_count = mm_counts["audio"] + audio = np.array([0.0] * audio_len) + data = {"audio": [(audio, sampling_rate)] * audio_count} + + return ProcessorInputs( + prompt_text="<|audio|>" * audio_count, + mm_data=data, + mm_processor_kwargs={}, + ) class StackAudioFrames(nn.Module): diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index b2cdc8f210ac0..339e193eefe20 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -625,11 +625,7 @@ def _apply_hf_processor( ) -> BatchFeature: # some mm_processor_kwargs may be used in processor initialization # instead of processor call - processor_init_kwargs = { - **self.init_mm_processor_kwargs, - **mm_processor_kwargs, - } - hf_processor = self._get_hf_processor(**processor_init_kwargs) + hf_processor = self._get_hf_processor(**mm_processor_kwargs) processor_data, passthrough_data = self._get_processor_data(mm_data) From 0b8aa47ac4da129e183188d1156c406363135184 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 15:27:35 +0800 Subject: [PATCH 11/23] code format Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index a60bf993294e8..adb5ab1917029 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -24,10 +24,9 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalData, - MultiModalKwargs, NestedTensors) +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.processing import (BaseMultiModalProcessor, - InputProcessingContext, + MultiModalDataDict, MultiModalDataItems, ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors @@ -98,7 +97,7 @@ def _resample_audio( def _apply_hf_processor( self, prompt: str, - mm_data: MultiModalData, + mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: # Ultravox processor doesn't support multiple inputs, @@ -128,7 +127,7 @@ def _get_hf_processor( def _get_processor_data( self, - mm_data: MultiModalData, + mm_data: MultiModalDataDict, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: # Ultravox uses "audio" instead of "audios" as calling keyword processor_data, passthrough_data = super()._get_processor_data(mm_data) From d5b7cf79d68866f3ae7910be756a67bb5686f57a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 16:25:51 +0800 Subject: [PATCH 12/23] fix prompt replacement Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index adb5ab1917029..9a72de1e6de10 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -141,12 +141,22 @@ def _get_prompt_replacements( hf_inputs: BatchFeature, mm_processor_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: - max_audio_tokens = get_ultravox_max_audio_tokens(self.ctx) + hf_processor = self._get_hf_processor() + stack_factor = hf_processor.stack_factor + encoder_ds_factor = hf_processor.encoder_ds_factor + + def get_replacement_ultravox(item_idx: int): + audio_data, _ = mm_items.audio[item_idx] + audio_len = audio_data.shape[-1] + nb_encoder_frames = int(round(audio_len / encoder_ds_factor + 1e-4)) + audio_token_len = int(np.ceil(nb_encoder_frames / stack_factor)) + return [_AUDIO_PLACEHOLDER_TOKEN] * audio_token_len + return [ PromptReplacement( modality="audio", target="<|audio|>", - replacement=[_AUDIO_PLACEHOLDER_TOKEN] * max_audio_tokens, + replacement=get_replacement_ultravox, ) ] From 980c73193266bcc55f3a8ff2e3e81f87d69bbf30 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 16:26:28 +0800 Subject: [PATCH 13/23] code format Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 9a72de1e6de10..b9f96ffe0960b 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -148,7 +148,8 @@ def _get_prompt_replacements( def get_replacement_ultravox(item_idx: int): audio_data, _ = mm_items.audio[item_idx] audio_len = audio_data.shape[-1] - nb_encoder_frames = int(round(audio_len / encoder_ds_factor + 1e-4)) + nb_encoder_frames = int(round(audio_len / encoder_ds_factor + + 1e-4)) audio_token_len = int(np.ceil(nb_encoder_frames / stack_factor)) return [_AUDIO_PLACEHOLDER_TOKEN] * audio_token_len From 5cb63629738f9f5e83f9e88a625c9061f06d8b52 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 16:36:28 +0800 Subject: [PATCH 14/23] fix audio_token truncation Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b9f96ffe0960b..2344546d88e10 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -146,11 +146,14 @@ def _get_prompt_replacements( encoder_ds_factor = hf_processor.encoder_ds_factor def get_replacement_ultravox(item_idx: int): - audio_data, _ = mm_items.audio[item_idx] - audio_len = audio_data.shape[-1] + audio_data, sr = mm_items.audio[item_idx] + audio_data = self._resample_audio(audio_data, sr)["audio"] + audio_len = audio_data.shape[0] nb_encoder_frames = int(round(audio_len / encoder_ds_factor + 1e-4)) audio_token_len = int(np.ceil(nb_encoder_frames / stack_factor)) + max_audio_token_len = get_ultravox_max_audio_tokens(self.ctx) + audio_token_len = min(audio_token_len, max_audio_token_len) return [_AUDIO_PLACEHOLDER_TOKEN] * audio_token_len return [ From 0854a67117a3e96c2967da09b7cbe8b9fc388987 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 16:55:56 +0800 Subject: [PATCH 15/23] fix mm_data Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 2344546d88e10..43deb2059736a 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -100,11 +100,15 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: + audio_data = mm_data["audio"] + if not isinstance(audio_data, list): + audio_data = [audio_data] + # Ultravox processor doesn't support multiple inputs, # therefore we need to input text and audio one by one tokenizer = self._get_tokenizer() hf_inputs = defaultdict(list) - for audio, sr in mm_data["audio"]: + for audio, sr in audio_data: data = self._resample_audio(audio, sr) processed_inputs = super()._apply_hf_processor( prompt, data, mm_processor_kwargs) From 146fc638348f9ae76872d8923f56fcaf9954acc8 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 18:50:01 +0800 Subject: [PATCH 16/23] fix audio_token_len and online inference Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/decoder_only/audio_language/test_ultravox.py | 3 ++- vllm/model_executor/models/ultravox.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 519af8efe3143..c548cfdf53414 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -46,7 +46,8 @@ def audio(request): def server(request, audio_assets): args = [ "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", - f"--limit-mm-per-prompt=audio={len(audio_assets)}" + f"--limit-mm-per-prompt=audio={len(audio_assets)}", + "--trust-remote-code" ] + [ f"--{key.replace('_','-')}={value}" for key, value in request.param.items() diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 43deb2059736a..1e97e01e57bad 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -157,7 +157,7 @@ def get_replacement_ultravox(item_idx: int): 1e-4)) audio_token_len = int(np.ceil(nb_encoder_frames / stack_factor)) max_audio_token_len = get_ultravox_max_audio_tokens(self.ctx) - audio_token_len = min(audio_token_len, max_audio_token_len) + audio_token_len = min(max(1, audio_token_len), max_audio_token_len) return [_AUDIO_PLACEHOLDER_TOKEN] * audio_token_len return [ From 342048cae552d94deb8862e19d6a20afb6db45b2 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 22:38:11 +0800 Subject: [PATCH 17/23] rename Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 1e97e01e57bad..549fd7f2ef213 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -73,7 +73,10 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) -class UltravoxProcessor(BaseMultiModalProcessor): +class UltravoxMultiModalProcessor(BaseMultiModalProcessor): + + def _get_feature_extractor(self) -> WhisperFeatureExtractor: + return self._get_hf_processor().audio_processor.feature_extractor def _resample_audio( self, @@ -81,7 +84,7 @@ def _resample_audio( sr: int, ) -> Dict[str, Union[np.ndarray, int]]: # resample audio to the model's sampling rate - feature_extractor = whisper_feature_extractor(self.ctx) + feature_extractor = self._get_feature_extractor() if sr != feature_extractor.sampling_rate: try: import librosa @@ -148,6 +151,7 @@ def _get_prompt_replacements( hf_processor = self._get_hf_processor() stack_factor = hf_processor.stack_factor encoder_ds_factor = hf_processor.encoder_ds_factor + max_audio_token_len = get_ultravox_max_audio_tokens(self.ctx) def get_replacement_ultravox(item_idx: int): audio_data, sr = mm_items.audio[item_idx] @@ -156,7 +160,6 @@ def get_replacement_ultravox(item_idx: int): nb_encoder_frames = int(round(audio_len / encoder_ds_factor + 1e-4)) audio_token_len = int(np.ceil(nb_encoder_frames / stack_factor)) - max_audio_token_len = get_ultravox_max_audio_tokens(self.ctx) audio_token_len = min(max(1, audio_token_len), max_audio_token_len) return [_AUDIO_PLACEHOLDER_TOKEN] * audio_token_len @@ -172,12 +175,12 @@ def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - feature_extractor = whisper_feature_extractor(self.ctx) + feature_extractor = self._get_feature_extractor() sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate audio_count = mm_counts["audio"] - audio = np.array([0.0] * audio_len) + audio = np.zeros(audio_len) data = {"audio": [(audio, sampling_rate)] * audio_count} return ProcessorInputs( @@ -307,7 +310,7 @@ def forward( @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "audio", get_ultravox_max_audio_tokens) -@MULTIMODAL_REGISTRY.register_processor(UltravoxProcessor) +@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): From daba237762c0da280ee932c301b4557e7debae15 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sun, 15 Dec 2024 22:38:41 +0800 Subject: [PATCH 18/23] Update vllm/model_executor/models/ultravox.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/ultravox.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 549fd7f2ef213..53dd1f36bec4b 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -110,16 +110,20 @@ def _apply_hf_processor( # Ultravox processor doesn't support multiple inputs, # therefore we need to input text and audio one by one tokenizer = self._get_tokenizer() - hf_inputs = defaultdict(list) + audio_features = [] for audio, sr in audio_data: data = self._resample_audio(audio, sr) processed_inputs = super()._apply_hf_processor( prompt, data, mm_processor_kwargs) prompt = tokenizer.decode(processed_inputs["input_ids"][0], skip_special_tokens=False) - hf_inputs["audio_features"].append( + audio_features.append( processed_inputs["audio_values"].squeeze(0)) - hf_inputs["input_ids"] = processed_inputs["input_ids"] + + hf_inputs = dict( + **processed_inputs, + audio_features=audio_features, + ) return hf_inputs def _get_hf_processor( From 7813d47e59eb83861ec6229919617f3a2c5d0d6d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 23:04:33 +0800 Subject: [PATCH 19/23] clean up Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 45 ++++++++++---------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 53dd1f36bec4b..190cbc860d3be 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,7 +2,6 @@ """PyTorch Ultravox model.""" import math -from collections import defaultdict from functools import cached_property, lru_cache from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -12,7 +11,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import functional as F -from transformers import BatchFeature, ProcessorMixin +from transformers import BatchFeature from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper.modeling_whisper import WhisperEncoder @@ -37,8 +36,6 @@ init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings_from_map) -_AUDIO_PLACEHOLDER_TOKEN = 128002 -_AUDIO_PLACEHOLDER_STR = "<|reserved_special_token_0|>" _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -110,7 +107,8 @@ def _apply_hf_processor( # Ultravox processor doesn't support multiple inputs, # therefore we need to input text and audio one by one tokenizer = self._get_tokenizer() - audio_features = [] + audio_features, audio_token_len = [], [] + processed_inputs = {} for audio, sr in audio_data: data = self._resample_audio(audio, sr) processed_inputs = super()._apply_hf_processor( @@ -118,24 +116,17 @@ def _apply_hf_processor( prompt = tokenizer.decode(processed_inputs["input_ids"][0], skip_special_tokens=False) audio_features.append( - processed_inputs["audio_values"].squeeze(0)) - + processed_inputs.pop("audio_values").squeeze(0)) + audio_token_len.append( + processed_inputs.pop("audio_token_len").item()) + hf_inputs = dict( **processed_inputs, audio_features=audio_features, + audio_token_len=audio_token_len, ) return hf_inputs - def _get_hf_processor( - self, - **mm_processor_kwargs: Mapping[str, object], - ) -> ProcessorMixin: - # Ultravox processor use eot_token_id as the audio placeholder token, - # we replace it with <|reserved_special_token_0|> for convenience. - hf_processor = super()._get_hf_processor(**mm_processor_kwargs) - hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_STR - return hf_processor - def _get_processor_data( self, mm_data: MultiModalDataDict, @@ -153,19 +144,17 @@ def _get_prompt_replacements( mm_processor_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() - stack_factor = hf_processor.stack_factor - encoder_ds_factor = hf_processor.encoder_ds_factor - max_audio_token_len = get_ultravox_max_audio_tokens(self.ctx) + tokenizer = hf_processor.tokenizer + placeholder = hf_processor.audio_token_replacement + + placeholder_token = tokenizer.encode(placeholder, + add_special_tokens=False) + assert len(placeholder_token) == 1 + placeholder_token = placeholder_token[0] def get_replacement_ultravox(item_idx: int): - audio_data, sr = mm_items.audio[item_idx] - audio_data = self._resample_audio(audio_data, sr)["audio"] - audio_len = audio_data.shape[0] - nb_encoder_frames = int(round(audio_len / encoder_ds_factor + - 1e-4)) - audio_token_len = int(np.ceil(nb_encoder_frames / stack_factor)) - audio_token_len = min(max(1, audio_token_len), max_audio_token_len) - return [_AUDIO_PLACEHOLDER_TOKEN] * audio_token_len + audio_token_len = hf_inputs["audio_token_len"][item_idx] + return [placeholder_token] * audio_token_len return [ PromptReplacement( From 6e7b1386cf6e4c0aeeb4170897c475a10f406fbb Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 15 Dec 2024 23:08:38 +0800 Subject: [PATCH 20/23] handle no audio data Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 190cbc860d3be..faa5f5775cb18 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -100,6 +100,10 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: + if not mm_data: + return super()._apply_hf_processor(prompt, mm_data, + mm_processor_kwargs) + audio_data = mm_data["audio"] if not isinstance(audio_data, list): audio_data = [audio_data] From ca58f8b56bdabc349d4c6809e9d28a17884a6ed8 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sun, 15 Dec 2024 23:39:49 +0800 Subject: [PATCH 21/23] Update vllm/model_executor/models/ultravox.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/ultravox.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index faa5f5775cb18..11d41c9500a0b 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -124,12 +124,11 @@ def _apply_hf_processor( audio_token_len.append( processed_inputs.pop("audio_token_len").item()) - hf_inputs = dict( + return dict( **processed_inputs, audio_features=audio_features, audio_token_len=audio_token_len, ) - return hf_inputs def _get_processor_data( self, From e1fdd36d072bed1dcc5471a8214d10f1764dba3d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 16 Dec 2024 00:11:52 +0800 Subject: [PATCH 22/23] cleanup replacement Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/ultravox.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 11d41c9500a0b..ebaa8a4c4f38a 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -100,7 +100,7 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - if not mm_data: + if not mm_data or not mm_data.get("audio", None): return super()._apply_hf_processor(prompt, mm_data, mm_processor_kwargs) @@ -147,17 +147,11 @@ def _get_prompt_replacements( mm_processor_kwargs: Mapping[str, object], ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() - tokenizer = hf_processor.tokenizer placeholder = hf_processor.audio_token_replacement - placeholder_token = tokenizer.encode(placeholder, - add_special_tokens=False) - assert len(placeholder_token) == 1 - placeholder_token = placeholder_token[0] - def get_replacement_ultravox(item_idx: int): audio_token_len = hf_inputs["audio_token_len"][item_idx] - return [placeholder_token] * audio_token_len + return placeholder * audio_token_len return [ PromptReplacement( From 8ef0b23c116d722b66a9cdba8863158a0bb34c08 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 16 Dec 2024 14:25:16 +0800 Subject: [PATCH 23/23] fix audio entrypoint and pp test Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/distributed/test_pipeline_parallel.py | 2 +- tests/entrypoints/openai/test_audio.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 85d408efafe96..ddbf40f089407 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -214,7 +214,7 @@ def iter_params(self, model_name: str): "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), - "fixie-ai/ultravox-v0_3": PPTestSettings.fast(), + "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True), # [Encoder-decoder] # TODO: Implement PP # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index a74109e2f5120..b579dcbb5c402 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -25,6 +25,7 @@ def server(): "--max-num-seqs", "5", "--enforce-eager", + "--trust-remote-code", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: