diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b903253eb492..f9a484269afc 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -506,6 +506,8 @@ title: QDQBert - local: model_doc/qwen2 title: Qwen2 + - local: model_doc/qwen2_audio + title: Qwen2Audio - local: model_doc/qwen2_moe title: Qwen2MoE - local: model_doc/rag diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 716dc6511dd3..c873fcddd28a 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -256,6 +256,7 @@ Flax), PyTorch, and/or TensorFlow. | [PVTv2](model_doc/pvt_v2) | ✅ | ❌ | ❌ | | [QDQBert](model_doc/qdqbert) | ✅ | ❌ | ❌ | | [Qwen2](model_doc/qwen2) | ✅ | ❌ | ❌ | +| [Qwen2Audio](model_doc/qwen2_audio) | ✅ | ❌ | ❌ | | [Qwen2MoE](model_doc/qwen2_moe) | ✅ | ❌ | ❌ | | [RAG](model_doc/rag) | ✅ | ✅ | ❌ | | [REALM](model_doc/realm) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/qwen2_audio.md b/docs/source/en/model_doc/qwen2_audio.md new file mode 100644 index 000000000000..f399a7e7320c --- /dev/null +++ b/docs/source/en/model_doc/qwen2_audio.md @@ -0,0 +1,198 @@ + + +# Qwen2Audio + +## Overview + +The Qwen2-Audio is the new model series of large audio-language models from the Qwen team. Qwen2-Audio is capable of accepting various audio signal inputs and performing audio analysis or direct textual responses with regard to speech instructions. We introduce two distinct audio interaction modes: + +* voice chat: users can freely engage in voice interactions with Qwen2-Audio without text input +* audio analysis: users could provide audio and text instructions for analysis during the interaction + +It was proposed in [Qwen2-Audio Technical Report](https://arxiv.org/abs/2407.10759) by Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, Jingren Zhou. + +The abstract from the paper is the following: + +*We introduce the latest progress of Qwen-Audio, a large-scale audio-language model called Qwen2-Audio, which is capable of accepting various audio signal inputs and performing audio analysis or direct textual responses with regard to speech instructions. In contrast to complex hierarchical tags, we have simplified the pre-training process by utilizing natural language prompts for different data and tasks, and have further expanded the data volume. We have boosted the instruction-following capability of Qwen2-Audio and implemented two distinct audio interaction modes for voice chat and audio analysis. In the voice chat mode, users can freely engage in voice interactions with Qwen2-Audio without text input. In the audio analysis mode, users could provide audio and text instructions for analysis during the interaction. Note that we do not use any system prompts to switch between voice chat and audio analysis modes. Qwen2-Audio is capable of intelligently comprehending the content within audio and following voice commands to respond appropriately. For instance, in an audio segment that simultaneously contains sounds, multi-speaker conversations, and a voice command, Qwen2-Audio can directly understand the command and provide an interpretation and response to the audio. Additionally, DPO has optimized the model's performance in terms of factuality and adherence to desired behavior. According to the evaluation results from AIR-Bench, Qwen2-Audio outperformed previous SOTAs, such as Gemini-1.5-pro, in tests focused on audio-centric instruction-following capabilities. Qwen2-Audio is open-sourced with the aim of fostering the advancement of the multi-modal language community. * + + +## Usage tips + +`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen) + +In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for the inference, supporting both voice chat and audio analysis modes. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose. + +### Voice Chat Inference +In the voice chat mode, users can freely engage in voice interactions with Qwen2-Audio without text input: +```python +from io import BytesIO +from urllib.request import urlopen +import librosa +from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor + +processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct") +model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto") + +conversation = [ + {"role": "user", "content": [ + {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav"}, + ]}, + {"role": "assistant", "content": "Yes, the speaker is female and in her twenties."}, + {"role": "user", "content": [ + {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav"}, + ]}, +] +text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) +audios = [] +for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ele["type"] == "audio": + audios.append(librosa.load( + BytesIO(urlopen(ele['audio_url']).read()), + sr=processor.feature_extractor.sampling_rate)[0] + ) + +inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True) +inputs.input_ids = inputs.input_ids.to("cuda") + +generate_ids = model.generate(**inputs, max_length=256) +generate_ids = generate_ids[:, inputs.input_ids.size(1):] + +response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] +``` + +### Audio Analysis Inference +In the audio analysis, users could provide both audio and text instructions for analysis: +```python +from io import BytesIO +from urllib.request import urlopen +import librosa +from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor + +processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct") +model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto") + +conversation = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {"role": "user", "content": [ + {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"}, + {"type": "text", "text": "What's that sound?"}, + ]}, + {"role": "assistant", "content": "It is the sound of glass shattering."}, + {"role": "user", "content": [ + {"type": "text", "text": "What can you do when you hear that?"}, + ]}, + {"role": "assistant", "content": "Stay alert and cautious, and check if anyone is hurt or if there is any damage to property."}, + {"role": "user", "content": [ + {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/1272-128104-0000.flac"}, + {"type": "text", "text": "What does the person say?"}, + ]}, +] +text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) +audios = [] +for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ele["type"] == "audio": + audios.append( + librosa.load( + BytesIO(urlopen(ele['audio_url']).read()), + sr=processor.feature_extractor.sampling_rate)[0] + ) + +inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True) +inputs.input_ids = inputs.input_ids.to("cuda") + +generate_ids = model.generate(**inputs, max_length=256) +generate_ids = generate_ids[:, inputs.input_ids.size(1):] + +response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] +``` + +### Batch Inference +We also support batch inference: +```python +from io import BytesIO +from urllib.request import urlopen +import librosa +from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor + +processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct") +model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto") + +conversation1 = [ + {"role": "user", "content": [ + {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"}, + {"type": "text", "text": "What's that sound?"}, + ]}, + {"role": "assistant", "content": "It is the sound of glass shattering."}, + {"role": "user", "content": [ + {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"}, + {"type": "text", "text": "What can you hear?"}, + ]} +] + +conversation2 = [ + {"role": "user", "content": [ + {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/1272-128104-0000.flac"}, + {"type": "text", "text": "What does the person say?"}, + ]}, +] + +conversations = [conversation1, conversation2] + +text = [processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) for conversation in conversations] + +audios = [] +for conversation in conversations: + for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ele["type"] == "audio": + audios.append( + librosa.load( + BytesIO(urlopen(ele['audio_url']).read()), + sr=processor.feature_extractor.sampling_rate)[0] + ) + +inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True) +inputs['input_ids'] = inputs['input_ids'].to("cuda") +inputs.input_ids = inputs.input_ids.to("cuda") + +generate_ids = model.generate(**inputs, max_length=256) +generate_ids = generate_ids[:, inputs.input_ids.size(1):] + +response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) +``` + +## Qwen2AudioConfig + +[[autodoc]] Qwen2AudioConfig + +## Qwen2AudioConfig + +[[autodoc]] Qwen2AudioEncoderConfig + +## Qwen2AudioProcessor + +[[autodoc]] Qwen2AudioProcessor + +## Qwen2AudioForConditionalGeneration + +[[autodoc]] Qwen2AudioForConditionalGeneration + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 149c25d17a21..df1e64e36877 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -77,6 +77,7 @@ FlashAttention-2 is currently supported for the following architectures: * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) +* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder) * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) * [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) @@ -227,6 +228,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) +* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder) * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 72a42380b371..b291ee828933 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -655,6 +655,11 @@ "Qwen2Config", "Qwen2Tokenizer", ], + "models.qwen2_audio": [ + "Qwen2AudioConfig", + "Qwen2AudioEncoderConfig", + "Qwen2AudioProcessor", + ], "models.qwen2_moe": ["Qwen2MoeConfig"], "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], "models.recurrent_gemma": ["RecurrentGemmaConfig"], @@ -2980,6 +2985,13 @@ "Qwen2PreTrainedModel", ] ) + _import_structure["models.qwen2_audio"].extend( + [ + "Qwen2AudioEncoder", + "Qwen2AudioForConditionalGeneration", + "Qwen2AudioPreTrainedModel", + ] + ) _import_structure["models.qwen2_moe"].extend( [ "Qwen2MoeForCausalLM", @@ -5378,6 +5390,11 @@ from .models.pvt import PvtConfig from .models.pvt_v2 import PvtV2Config from .models.qwen2 import Qwen2Config, Qwen2Tokenizer + from .models.qwen2_audio import ( + Qwen2AudioConfig, + Qwen2AudioEncoderConfig, + Qwen2AudioProcessor, + ) from .models.qwen2_moe import Qwen2MoeConfig from .models.rag import RagConfig, RagRetriever, RagTokenizer from .models.recurrent_gemma import RecurrentGemmaConfig @@ -7390,6 +7407,11 @@ Qwen2Model, Qwen2PreTrainedModel, ) + from .models.qwen2_audio import ( + Qwen2AudioEncoder, + Qwen2AudioForConditionalGeneration, + Qwen2AudioPreTrainedModel, + ) from .models.qwen2_moe import ( Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index ad6b60ca1f7e..8e917af7c681 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -189,6 +189,7 @@ pvt, pvt_v2, qwen2, + qwen2_audio, qwen2_moe, rag, recurrent_gemma, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py old mode 100755 new mode 100644 index 129cd4e2d8b8..d06c99e18f36 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -208,6 +208,8 @@ ("pvt_v2", "PvtV2Config"), ("qdqbert", "QDQBertConfig"), ("qwen2", "Qwen2Config"), + ("qwen2_audio", "Qwen2AudioConfig"), + ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"), ("qwen2_moe", "Qwen2MoeConfig"), ("rag", "RagConfig"), ("realm", "RealmConfig"), @@ -504,6 +506,8 @@ ("pvt_v2", "PVTv2"), ("qdqbert", "QDQBert"), ("qwen2", "Qwen2"), + ("qwen2_audio", "Qwen2Audio"), + ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoE"), ("rag", "RAG"), ("realm", "REALM"), @@ -642,6 +646,7 @@ ("maskformer-swin", "maskformer"), ("xclip", "x_clip"), ("clip_vision_model", "clip"), + ("qwen2_audio_encoder", "qwen2_audio"), ("siglip_vision_model", "siglip"), ("chinese_clip_vision_model", "chinese_clip"), ("rt_detr_resnet", "rt_detr"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py old mode 100755 new mode 100644 index baa87e164b46..0cf0752e1060 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -196,6 +196,7 @@ ("pvt_v2", "PvtV2Model"), ("qdqbert", "QDQBertModel"), ("qwen2", "Qwen2Model"), + ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("recurrent_gemma", "RecurrentGemmaModel"), ("reformer", "ReformerModel"), @@ -323,6 +324,7 @@ ("nllb-moe", "NllbMoeForConditionalGeneration"), ("openai-gpt", "OpenAIGPTLMHeadModel"), ("paligemma", "PaliGemmaForConditionalGeneration"), + ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), ("retribert", "RetriBertModel"), ("roberta", "RobertaForMaskedLM"), ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), @@ -829,6 +831,7 @@ ("pegasus_x", "PegasusXForConditionalGeneration"), ("plbart", "PLBartForConditionalGeneration"), ("prophetnet", "ProphetNetForConditionalGeneration"), + ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), ("seamless_m4t", "SeamlessM4TForTextToText"), ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 1ab136a1e74c..7877343d5318 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -82,6 +82,7 @@ ("paligemma", "PaliGemmaProcessor"), ("pix2struct", "Pix2StructProcessor"), ("pop2piano", "Pop2PianoProcessor"), + ("qwen2_audio", "Qwen2AudioProcessor"), ("sam", "SamProcessor"), ("seamless_m4t", "SeamlessM4TProcessor"), ("sew", "Wav2Vec2Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 3d30a005e7d3..5df108a0faf3 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -392,6 +392,7 @@ "Qwen2TokenizerFast" if is_tokenizers_available() else None, ), ), + ("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), ( "qwen2_moe", ( diff --git a/src/transformers/models/qwen2_audio/__init__.py b/src/transformers/models/qwen2_audio/__init__.py new file mode 100644 index 000000000000..456378e2a53c --- /dev/null +++ b/src/transformers/models/qwen2_audio/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_qwen2_audio": ["Qwen2AudioConfig", "Qwen2AudioEncoderConfig"], + "processing_qwen2_audio": ["Qwen2AudioProcessor"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_qwen2_audio"] = [ + "Qwen2AudioForConditionalGeneration", + "Qwen2AudioPreTrainedModel", + "Qwen2AudioEncoder", + ] + + +if TYPE_CHECKING: + from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig + from .processing_qwen2_audio import Qwen2AudioProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_qwen2_audio import ( + Qwen2AudioEncoder, + Qwen2AudioForConditionalGeneration, + Qwen2AudioPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py new file mode 100644 index 000000000000..deb276f33472 --- /dev/null +++ b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2Audio model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class Qwen2AudioEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2AudioEncoder`]. It is used to instantiate a + Qwen2-Audio audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio + architecture. + + e.g. [Qwen/Qwen2-Audio-7B](https://huggingface.co/Qwen/Qwen2-Audio-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_mel_bins (`int`, *optional*, defaults to 128): + Number of mel features used per input features. Should correspond to the value used in the + `Qwen2AudioProcessor` class. + encoder_layers (`int`, *optional*, defaults to 32): + Number of encoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 20): + Number of attention heads for each attention layer in the Transformer encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 5120): + Dimensionality of the "intermediate" (often named feed-forward) layer in encoder. + encoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + d_model (`int`, *optional*, defaults to 1280): + Dimensionality of the layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + scale_embedding (`bool`, *optional*, defaults to `False`): + Scale embeddings by diving by sqrt(d_model). + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + max_source_positions (`int`, *optional*, defaults to 1500): + The maximum sequence length of log-mel filter-bank features that this model might ever be used with. + + Example: + + ```python + >>> from transformers import Qwen2AudioEncoderConfig, Qwen2AudioEncoder + + >>> # Initializing a Qwen2AudioEncoderConfig + >>> configuration = Qwen2AudioEncoderConfig() + + >>> # Initializing a Qwen2AudioEncoder (with random weights) + >>> model = Qwen2AudioEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_audio_encoder" + + def __init__( + self, + num_mel_bins=128, + encoder_layers=32, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + encoder_layerdrop=0.0, + d_model=1280, + dropout=0.0, + attention_dropout=0.0, + activation_function="gelu", + activation_dropout=0.0, + scale_embedding=False, + init_std=0.02, + max_source_positions=1500, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_mel_bins = num_mel_bins + self.d_model = d_model + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_function = activation_function + self.activation_dropout = activation_dropout + self.encoder_layerdrop = encoder_layerdrop + self.num_hidden_layers = encoder_layers + self.init_std = init_std + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.max_source_positions = max_source_positions + + +class Qwen2AudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2AudioForConditionalGeneration`]. It is used to instantiate an + Qwen2-Audio model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2-Audio. + + e.g. [Qwen/Qwen2-Audio-7B](https://huggingface.co/Qwen/Qwen2-Audio-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + audio_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the audio backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + audio_token_index (`int`, *optional*, defaults to 151646): + The image token index to encode the image prompt. + + Example: + + ```python + >>> from transformers import Qwen2AudioForConditionalGeneration, Qwen2AudioConfig, Qwen2AudioEncoderConfig, Qwen2Config + + >>> # Initializing a Qwen2AudioEncoder config + >>> audio_config = Qwen2AudioEncoderConfig() + + >>> # Initializing a Qwen2 config + >>> text_config = Qwen2Config() + + >>> # Initializing a Qwen2Audio configuration + >>> configuration = Qwen2AudioConfig(audio_config, text_config) + + >>> # Initializing a model from the qwen2-audio style configuration + >>> model = Qwen2AudioForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_audio" + is_composition = False + + def __init__( + self, + audio_config=None, + text_config=None, + audio_token_index=151646, + **kwargs, + ): + self.audio_token_index = audio_token_index + + if isinstance(audio_config, dict): + audio_config["model_type"] = ( + audio_config["model_type"] if "model_type" in audio_config else "qwen2_audio_encoder" + ) + audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config) + elif audio_config is None: + audio_config = CONFIG_MAPPING["qwen2_audio_encoder"]( + d_model=1280, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + encoder_layerdrop=0.0, + encoder_layers=32, + num_mel_bins=128, + max_source_positions=1500, + scale_embedding=False, + activation_function="gelu", + ) + + self.audio_config = audio_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["qwen2"]() + + self.text_config = text_config + + super().__init__(**kwargs) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py new file mode 100644 index 000000000000..855c16fb8c18 --- /dev/null +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -0,0 +1,1378 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2Audio model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ... import PreTrainedModel +from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache, StaticCache +from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2AudioConfig" + + +@dataclass +class Qwen2AudioCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2Audio causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + attention_mask (`torch.FloatTensor`, *optional*): + Attentions mask, used to update attention mask and position_ids. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + attention_mask: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.whisper.modeling_whisper.WhisperAttention with Whisper->Qwen2Audio +class Qwen2AudioAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + layer_idx: Optional[int] = None, + config: Optional[Qwen2AudioConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + if layer_idx is None and is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2 with Whisper->Qwen2Audio +class Qwen2AudioFlashAttention2(Qwen2AudioAttention): + """ + Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. " + "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers" + ) + # Qwen2AudioFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim)) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] + # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.whisper.modeling_whisper.WhisperSdpaAttention with Whisper->Qwen2Audio +class Qwen2AudioSdpaAttention(Qwen2AudioAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2AudioModel is using Qwen2AudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2AUDIO_ATTENTION_CLASSES = { + "eager": Qwen2AudioAttention, + "flash_attention_2": Qwen2AudioFlashAttention2, + "sdpa": Qwen2AudioSdpaAttention, +} + + +# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO +class Qwen2AudioEncoderLayer(nn.Module): + def __init__(self, config: Qwen2AudioConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = QWEN2AUDIO_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +QWEN2AUDIO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2AudioConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2Audio Model outputting raw hidden-states without any specific head on top.", + QWEN2AUDIO_START_DOCSTRING, +) +class Qwen2AudioPreTrainedModel(PreTrainedModel): + config_class = Qwen2AudioConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2AudioAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + # important: this ported version of Qwen2Audio isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_config.init_std + + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA or not. + """ + return self.language_model._supports_sdpa + + +QWEN2AUDIOENCODER_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2AudioEncoderConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The audio model from Qwen2Audio without any head or projection on top.""", + QWEN2AUDIOENCODER_START_DOCSTRING, +) +# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder with Whisper->Qwen2Audio +class Qwen2AudioEncoder(Qwen2AudioPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`Qwen2AudioEncoderLayer`]. + + Args: + config: Qwen2AudioEncoderConfig + """ + + # Ignore copy + config_class = Qwen2AudioEncoderConfig + main_input_name = "input_features" + _no_split_modules = ["Qwen2AudioEncoderLayer"] + + def __init__(self, config: Qwen2AudioEncoderConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + self.embed_positions.requires_grad_(False) + + self.layers = nn.ModuleList([Qwen2AudioEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + # Ignore copy + self.avg_pooler = nn.AvgPool1d(2, stride=2) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + Qwen2Audio does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + + expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0] + if input_features.shape[-1] != expected_seq_length: + raise ValueError( + f"Qwen2Audio expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}." + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Ignore copy + input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + # Ignore copy + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Ignore copy + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.avg_pooler(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +class Qwen2AudioMultiModalProjector(nn.Module): + def __init__(self, config: Qwen2AudioConfig): + super().__init__() + self.linear = nn.Linear(config.audio_config.d_model, config.text_config.hidden_size, bias=True) + + def forward(self, audio_features): + hidden_states = self.linear(audio_features) + return hidden_states + + +QWEN2AUDIO_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`): + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """The QWEN2AUDIO model which consists of a audio backbone and a language model.""", + QWEN2AUDIO_START_DOCSTRING, +) +class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel): + def __init__(self, config: Qwen2AudioConfig): + super().__init__(config) + self.audio_tower = AutoModel.from_config(config.audio_config, attn_implementation=config._attn_implementation) + + self.multi_modal_projector = Qwen2AudioMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config( + config.text_config, attn_implementation=config._attn_implementation + ) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides + self.post_init() + + @property + def padding_side(self): + return self._padding_side + + @padding_side.setter + def padding_side(self, padding_side: str): + if padding_side not in ["left", "right"]: + raise ValueError(f"{padding_side} is not `left` or `right`.") + self._padding_side = padding_side + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder + def get_decoder(self): + return self.language_model.get_decoder() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights + def tie_weights(self): + return self.language_model.tie_weights() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _merge_input_ids_with_audio_features( + self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask, labels + ): + """ + Merge input_ids with with audio features into final embeddings + + Args: + audio_features (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`): + All audio vectors of all audios in the batch + num_audio_tokens (`torch.LongTensor` of shape `(num_audios)`): + The length of audio embeddings of each audio as stacked in `audio_features` + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): + Token embeddings before merging with audio embeddings + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Input_ids of tokens, possibly filled with audio token + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding token indices. + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) + labels need to be recalculated to support training (if provided) + Returns: + final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids + + Explanation: + each audio has variable length embeddings, with length specified by num_audio_tokens + audio_features is concatenation of all audio embed vectors + task: fill each <|AUDIO|> with the correct number of audio embeddings + Example: + X (5 tokens), Y (3 tokens), Z (8 tokens) + X, Y are in the same sequence (in-context learning) + if right padding + input_ids: [ + a b c d e f X g h i j k Y l m + o p q r Z s t u v _ _ _ _ _ _ + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ + ] + elif left padding + input_ids: [ + a b c d e f X g h i j k Y l m + _ _ _ _ _ _ o p q r Z s t u v + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v + ] + Edge cases: + * If tokens are same but audio token sizes are different, then cannot infer left or right padding + ```python + url1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" + audio1, _ = librosa.load(BytesIO(urlopen(url1).read()), sr=processor.feature_extractor.sampling_rate) + url2 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav" + audio2, _ = librosa.load(BytesIO(urlopen(url2).read()), sr=processor.feature_extractor.sampling_rate) + prompts = [ + "[INST] <|AUDIO|>\nWhat is that in this audio? [/INST]", + "[INST] <|AUDIO|>\nWhat is that in this audio? [/INST]", + ] + inputs = processor(text=prompts, audios=[audio1, audio2], return_tensors='pt', padding=True).to("cuda") + audio1 has 101 tokens, while audio2 has 72 tokens + ``` + + input_ids: [ + a b c d X g h + i j Y k l m n + ] + where X is 3 tokens while Y is 5, this mean after merge + if left-padding (batched generation) + input_ids should be: [ + _ _ a b c d X X X g h + i j Y Y Y Y Y k l m n + ] + elif (right padding) (training) + input_ids should be: [ + a b c d X X X g h _ _ + i j Y Y Y Y Y k l m n + ] + """ + num_audios, max_audio_tokens, embed_dim = audio_features.shape + audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to( + num_audio_tokens.device + ) < num_audio_tokens.unsqueeze(1) + masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim) + batch_size, sequence_length = input_ids.shape + _left_padding = torch.any(attention_mask[:, 0] == 0) + _right_padding = torch.any(attention_mask[:, -1] == 0) + + left_padding = True + if batch_size > 1: + if _left_padding and not _right_padding: + left_padding = True + elif not _left_padding and _right_padding: + left_padding = False + elif not _left_padding and not _right_padding: + # both side is 1, so cannot tell + left_padding = self.padding_side == "left" + else: + # invalid attention_mask + raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + + # 1. Create a mask to know where special audio tokens are + special_audio_token_mask = input_ids == self.config.audio_token_index + num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1) + + # In case the Audio model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + attention_mask = attention_mask.to(target_device) + input_ids = input_ids.to(target_device) + num_audio_tokens = num_audio_tokens.to(target_device) + batch_indices, non_audio_indices = torch.where( + (input_ids != self.config.audio_token_index) & (attention_mask == 1) + ) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged audio-text sequence. + # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens. + # `torch.cumsum` computes how each audio token shifts subsequent text token positions. + token_placeholder_num = torch.zeros_like(input_ids) + token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1 + token_placeholder_num = token_placeholder_num + 1 + new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1 + max_token_num = token_placeholder_num.sum(-1).max() + nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_audio_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_audio_indices] + batch_indices, non_audio_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_audio_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_token_num, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_token_num, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + final_input_ids = torch.full( + (batch_size, max_token_num), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device + ) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "