From 20dfbff2252d22aeb52f4cfc749a2ba47ad9f281 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Fri, 17 May 2024 14:26:29 +0100 Subject: [PATCH 01/62] add rope --- parler_tts/configuration_parler_tts.py | 8 ++ parler_tts/modeling_parler_tts.py | 159 ++++++++++++++++++++++--- 2 files changed, 150 insertions(+), 17 deletions(-) diff --git a/parler_tts/configuration_parler_tts.py b/parler_tts/configuration_parler_tts.py index 5d631c4..dba2b8b 100644 --- a/parler_tts/configuration_parler_tts.py +++ b/parler_tts/configuration_parler_tts.py @@ -74,6 +74,10 @@ class ParlerTTSDecoderConfig(PretrainedConfig): The number of parallel codebooks forwarded to the model. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether input and output word embeddings should be tied. + rope_embeddings (`bool`, *optional*, defaults to `False`): + Whether to use ROPE or absolute positional embeddings. + rope_theta (`float`, *optional*, defaults to 100000.0): + The base period of the RoPE embeddings. """ model_type = "parler_tts_decoder" @@ -100,6 +104,8 @@ def __init__( bos_token_id=2049, eos_token_id=2048, tie_word_embeddings=False, + rope_embeddings=False, + rope_theta=10_000.0, **kwargs, ): self.vocab_size = vocab_size @@ -117,6 +123,8 @@ def __init__( self.use_cache = use_cache self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.num_codebooks = num_codebooks + self.rope_embeddings = rope_embeddings + self.rope_theta = rope_theta super().__init__( pad_token_id=pad_token_id, diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 1437465..9465fd3 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -18,7 +18,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, List import torch import torch.nn as nn @@ -222,8 +222,77 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): self.make_weights(seq_len + self.offset, self.embedding_dim) return self.weights.index_select(0, position_ids.view(-1)).detach() +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->ParlerTTS +class ParlerTTSRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, 
None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, is_cross_attention, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + is_cross_attention (`bool`): + Whether this is a cross-attention layer. If so, we don't apply ROPE to the key-states, since we assume + the encoder hidden-states have been computed with suitable positional encoding. Adding ROPE on-top of + such embeddings is hypothesised to disrupt the original positional encodings. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) if not is_cross_attention else k + return q_embed, k_embed -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->ParlerTTS class ParlerTTSAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -235,7 +304,7 @@ def __init__( is_decoder: bool = False, bias: bool = True, is_causal: bool = False, - config: Optional[ParlerTTSConfig] = None, + config: Optional[ParlerTTSDecoderConfig] = None, ): super().__init__() self.embed_dim = embed_dim @@ -258,6 +327,13 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + if config.rope_embeddings: + self.rotary_emb = ParlerTTSRotaryEmbedding( + self.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -267,6 +343,7 @@ def forward( key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -317,8 +394,13 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) + query_states = self._shape(query_states, tgt_len, bsz) + if self.config.rope_embeddings: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, is_cross_attention) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -382,7 +464,6 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer with Musicgen->ParlerTTS class ParlerTTSDecoderLayer(nn.Module): def __init__(self, config: ParlerTTSDecoderConfig): super().__init__() @@ -394,6 +475,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): dropout=config.attention_dropout, is_decoder=True, bias=False, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -406,6 +488,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): dropout=config.attention_dropout, is_decoder=True, bias=False, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False) @@ -416,6 +499,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, @@ -429,6 +513,9 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 
attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size @@ -453,6 +540,7 @@ def forward( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + position_ids=position_ids, layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) @@ -472,6 +560,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + position_ids=position_ids, layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, @@ -751,7 +840,6 @@ def _init_weights(self, module): """ -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with Musicgen->ParlerTTS class ParlerTTSDecoder(ParlerTTSPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParlerTTSDecoderLayer`] @@ -772,10 +860,11 @@ def __init__(self, config: ParlerTTSDecoderConfig): [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)] ) - self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( - config.max_position_embeddings, - config.hidden_size, - ) + if not config.rope_embeddings: + self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.hidden_size, + ) self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layer_norm = nn.LayerNorm(config.hidden_size) @@ -803,6 +892,7 @@ def forward( cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -878,12 +968,21 @@ def forward( encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) - # embed positions - # TODO: As it is, the masked ids from the prompt will still count in the positions embeddings - # maybe should modify position embeddings - positions = self.embed_positions(inputs_embeds, past_key_values_length) - - hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + if not self.config.rope_embeddings: + # embed positions + # TODO: As it is, the masked ids from the prompt will still count in the positions embeddings + # maybe should modify position embeddings + positions = self.embed_positions(inputs_embeds, past_key_values_length) + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + else: + hidden_states = inputs_embeds + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, inputs_embeds.shape[1] + past_key_values_length, dtype=torch.long, + device=device + ) + position_ids = position_ids.unsqueeze(0) hidden_states = nn.functional.dropout(hidden_states, 
p=self.dropout, training=self.training) @@ -923,6 +1022,7 @@ def forward( decoder_layer.forward, hidden_states, attention_mask, + position_ids, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -935,6 +1035,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, + position_ids=position_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1004,6 +1105,7 @@ def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None, @@ -1028,6 +1130,7 @@ def forward( decoder_outputs = self.decoder( input_ids=input_ids, attention_mask=attention_mask, + position_ids=position_ids, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states=encoder_hidden_states, prompt_hidden_states=prompt_hidden_states, @@ -1058,7 +1161,6 @@ def forward( "The Parler-TTS decoder model with a language modelling head on top.", MUSICGEN_START_DOCSTRING, ) -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with Musicgen->ParlerTTS class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): def __init__(self, config: ParlerTTSDecoderConfig): super().__init__(config) @@ -1097,6 +1199,7 @@ def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None, @@ -1124,6 +1227,7 @@ def forward( outputs = self.model( input_ids, attention_mask=attention_mask, + position_ids=position_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, prompt_hidden_states=prompt_hidden_states, @@ -1228,8 +1332,16 @@ def prepare_inputs_for_generation( [prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0 ) + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values is not None: input_ids = input_ids[:, -1:] + if position_ids is not None: + position_ids = position_ids[:, -input_ids.shape[1]:] # we only want to use prompt signal in the 1st generation step but keeping the attention mask prompt_hidden_states = None @@ -1237,6 +1349,7 @@ def prepare_inputs_for_generation( return { "input_ids": input_ids, "attention_mask": attention_mask, + "position_ids": position_ids, "encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask, "prompt_hidden_states": prompt_hidden_states, @@ -1931,6 +2044,7 @@ def forward( prompt_input_ids: Optional[torch.FloatTensor] = None, prompt_attention_mask: Optional[torch.LongTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -2041,6 +2155,7 @@ def forward( 
decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=attention_mask, prompt_hidden_states=prompt_hidden_states, @@ -2109,6 +2224,12 @@ def prepare_inputs_for_generation( if prompt_attention_mask is not None: prompt_attention_mask = prompt_attention_mask.repeat((2, 1)) + decoder_position_ids = kwargs.get("decoder_position_ids", None) + if decoder_attention_mask is not None and decoder_position_ids is None: + # create position_ids on the fly for batch generation + decoder_position_ids = decoder_attention_mask.long().cumsum(-1) - 1 + decoder_position_ids.masked_fill_(decoder_attention_mask == 0, 1) + if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -2121,6 +2242,9 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + if decoder_position_ids is not None: + decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + # we only want to use prompt signal in the 1st generation step but keeping the attention mask prompt_hidden_states = None @@ -2129,6 +2253,7 @@ def prepare_inputs_for_generation( "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, + "decoder_position_ids": decoder_position_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, From fd16411f6cf7d86d805c59c19d5b398c779be7b6 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Fri, 17 May 2024 15:20:00 +0100 Subject: [PATCH 02/62] don't include padding in rope --- parler_tts/modeling_parler_tts.py | 66 ++++++++++++++++--------------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 9465fd3..a3fcf55 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -357,6 +357,8 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) * self.scaling + query_states = self._shape(query_states, tgt_len, bsz) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -384,6 +386,10 @@ def forward( key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + if self.config.rope_embeddings: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, is_cross_attention) + if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
# Further calls to cross_attention layer can then reuse all cross-attention @@ -394,11 +400,6 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - query_states = self._shape(query_states, tgt_len, bsz) - if self.config.rope_embeddings: - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, is_cross_attention) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) @@ -957,16 +958,6 @@ def forward( ) input_shape = inputs_embeds.size()[:-1] - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) if not self.config.rope_embeddings: # embed positions @@ -976,16 +967,38 @@ def forward( hidden_states = inputs_embeds + positions.to(inputs_embeds.device) else: hidden_states = inputs_embeds + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, inputs_embeds.shape[1] + past_key_values_length, dtype=torch.long, - device=device - ) - position_ids = position_ids.unsqueeze(0) + if attention_mask is not None: + # masked ids will **not** count in the position embeddings + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = torch.arange( + past_key_values_length, input_shape[1] + past_key_values_length, + dtype=torch.long, + device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + + # Some generation methods already pass only the last input ID + if position_ids.shape[1] > input_shape[1]: + position_ids = position_ids[:, -input_shape[1]:] hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -1248,7 +1261,6 @@ def forward( loss = None if labels is not None: - loss = torch.zeros([], device=self.device) # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels logits = lm_logits[:, :, -labels.shape[1] :] @@ -2224,12 +2236,6 @@ def prepare_inputs_for_generation( if prompt_attention_mask is not None: prompt_attention_mask = prompt_attention_mask.repeat((2, 1)) - decoder_position_ids = kwargs.get("decoder_position_ids", None) - if decoder_attention_mask is not None and decoder_position_ids is None: - # create position_ids on the fly for batch generation - decoder_position_ids = decoder_attention_mask.long().cumsum(-1) - 1 - 
decoder_position_ids.masked_fill_(decoder_attention_mask == 0, 1) - if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -2242,9 +2248,6 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - if decoder_position_ids is not None: - decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] - # we only want to use prompt signal in the 1st generation step but keeping the attention mask prompt_hidden_states = None @@ -2253,7 +2256,6 @@ def prepare_inputs_for_generation( "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, - "decoder_position_ids": decoder_position_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, From 1344de2400e620e945d71dc0f35e66c11f05bd6b Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Fri, 17 May 2024 16:07:19 +0100 Subject: [PATCH 03/62] possibly use cross-attn for prompt --- parler_tts/configuration_parler_tts.py | 5 ++++- parler_tts/modeling_parler_tts.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/parler_tts/configuration_parler_tts.py b/parler_tts/configuration_parler_tts.py index dba2b8b..ad7d7b3 100644 --- a/parler_tts/configuration_parler_tts.py +++ b/parler_tts/configuration_parler_tts.py @@ -148,6 +148,8 @@ class ParlerTTSConfig(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 1024): Vocabulary size of the prompt token ids. Defines the number of different tokens that can be represented by the `prompt_inputs_ids`. + prompt_cross_attention (`bool`, *optional*, defaults to `False`): + Whether to use cross-attention conditioning for the prompt (as well as the description). kwargs (*optional*): Dictionary of keyword arguments. Notably: @@ -198,7 +200,7 @@ class ParlerTTSConfig(PretrainedConfig): model_type = "parler_tts" is_composition = True - def __init__(self, vocab_size=1024, **kwargs): + def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs): super().__init__(**kwargs) if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs: raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config") @@ -212,6 +214,7 @@ def __init__(self, vocab_size=1024, **kwargs): decoder_config = kwargs.pop("decoder") self.vocab_size = vocab_size + self.prompt_cross_attention = prompt_cross_attention self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config) self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) self.decoder = ParlerTTSDecoderConfig(**decoder_config) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index a3fcf55..635afc6 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -1758,6 +1758,12 @@ def __init__( # prompt embeddings self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size) + if config.prompt_cross_attention: + self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( + config.decoder.max_position_embeddings, + config.decoder.hidden_size, + ) + if self.text_encoder.get_output_embeddings() is not None: raise ValueError( f"The encoder {self.text_encoder} should not have a LM Head. 
Please use a model without and LM Head" @@ -2138,6 +2144,19 @@ def forward( if prompt_input_ids is not None: prompt_hidden_states = self.embed_prompts(prompt_input_ids) + if prompt_hidden_states is not None and self.config.prompt_cross_attention: + # add sinusoidal positional embedding + positions = self.embed_positions(prompt_hidden_states, 0) + prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) + + # concatenate text description states with prompt description states + encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) + if prompt_attention_mask is not None: + attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) + + prompt_hidden_states = None + prompt_attention_mask = None + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id @@ -2251,6 +2270,9 @@ def prepare_inputs_for_generation( # we only want to use prompt signal in the 1st generation step but keeping the attention mask prompt_hidden_states = None + if self.config.prompt_cross_attention: + prompt_attention_mask = None + return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, From e3fe8432c709294d46a3cb7d234712981df779f8 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Sat, 18 May 2024 16:18:43 +0100 Subject: [PATCH 04/62] fix rope --- parler_tts/modeling_parler_tts.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 635afc6..17d63fb 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -385,11 +385,6 @@ def forward( # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.config.rope_embeddings: - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, is_cross_attention) - if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
# Further calls to cross_attention layer can then reuse all cross-attention @@ -400,6 +395,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) + if self.config.rope_embeddings: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, is_cross_attention) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) From 4f27ff7517b4b96fd6e9977872c59bc7a5e6e0ca Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Sat, 18 May 2024 16:39:46 +0100 Subject: [PATCH 05/62] fix cross-attn --- parler_tts/modeling_parler_tts.py | 34 +++---------------------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 17d63fb..860f976 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -929,33 +929,6 @@ def forward( if prompt_hidden_states is not None: inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) - # As it is, the masked ids from the prompt will still count in the positions embeddings - if prompt_attention_mask is not None and attention_mask is not None: - attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1) - elif prompt_attention_mask is not None: - logger.warning_once( - "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." - ) - if past_key_values is None: - attention_mask = torch.cat( - [ - prompt_attention_mask, - torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype), - ], - dim=1, - ) - else: - generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1 - attention_mask = torch.cat( - [ - prompt_attention_mask, - torch.ones( - (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype - ), - ], - dim=1, - ) - input_shape = inputs_embeds.size()[:-1] if not self.config.rope_embeddings: @@ -2151,6 +2124,8 @@ def forward( # concatenate text description states with prompt description states encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) if prompt_attention_mask is not None: + if attention_mask is None: + attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype) attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) prompt_hidden_states = None @@ -2267,10 +2242,7 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] # we only want to use prompt signal in the 1st generation step but keeping the attention mask - prompt_hidden_states = None - - if self.config.prompt_cross_attention: - prompt_attention_mask = None + prompt_hidden_states = prompt_hidden_states if self.config.prompt_cross_attention else None return { "input_ids": None, # encoder_outputs is defined. 
input_ids not needed From c8887f20514b43de43f116f7b4bd1d72a6832013 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Sat, 18 May 2024 16:40:34 +0100 Subject: [PATCH 06/62] fix self-attn --- parler_tts/modeling_parler_tts.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 860f976..db0e0e7 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -929,6 +929,33 @@ def forward( if prompt_hidden_states is not None: inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) + # As it is, the masked ids from the prompt will still count in the positions embeddings + if prompt_attention_mask is not None and attention_mask is not None: + attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1) + elif prompt_attention_mask is not None: + logger.warning_once( + "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." + ) + if past_key_values is None: + attention_mask = torch.cat( + [ + prompt_attention_mask, + torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype), + ], + dim=1, + ) + else: + generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1 + attention_mask = torch.cat( + [ + prompt_attention_mask, + torch.ones( + (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype + ), + ], + dim=1, + ) + input_shape = inputs_embeds.size()[:-1] if not self.config.rope_embeddings: From ba0bfcab71095a8829a09fee081702b6065dd662 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Mon, 20 May 2024 11:35:08 +0100 Subject: [PATCH 07/62] fix dummy model --- helpers/model_init_scripts/init_dummy_model_with_encodec.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/helpers/model_init_scripts/init_dummy_model_with_encodec.py b/helpers/model_init_scripts/init_dummy_model_with_encodec.py index 32242b4..4e26089 100644 --- a/helpers/model_init_scripts/init_dummy_model_with_encodec.py +++ b/helpers/model_init_scripts/init_dummy_model_with_encodec.py @@ -58,4 +58,7 @@ model.generation_config.do_sample = True # True model.generation_config.guidance_scale = 1 # 3.0 + model.config.pad_token_id = encodec_vocab_size + model.config.decoder_start_token_id = encodec_vocab_size + 1 + model.save_pretrained(os.path.join(args.save_directory, "tiny-model")) From 387de15e70437abf976091ba70742aeceafd2426 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Tue, 21 May 2024 16:51:50 +0100 Subject: [PATCH 08/62] clean-up rope --- parler_tts/modeling_parler_tts.py | 40 ++++++++++++++----------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index db0e0e7..1f368d3 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -265,18 +265,13 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, is_cross_attention, unsqueeze_dim=1): +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. + x (`torch.Tensor`): The tensor over which to apply the rope embeddings cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. 
- is_cross_attention (`bool`): - Whether this is a cross-attention layer. If so, we don't apply ROPE to the key-states, since we assume - the encoder hidden-states have been computed with suitable positional encoding. Adding ROPE on-top of - such embeddings is hypothesised to disrupt the original positional encodings. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -289,9 +284,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, is_cross_attention, unsqueeze_dim=1): """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) if not is_cross_attention else k - return q_embed, k_embed + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed class ParlerTTSAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -327,6 +321,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.rope_embeddings = config.rope_embeddings if config.rope_embeddings: self.rotary_emb = ParlerTTSRotaryEmbedding( self.head_dim, @@ -359,6 +354,10 @@ def forward( query_states = self.q_proj(hidden_states) * self.scaling query_states = self._shape(query_states, tgt_len, bsz) + if self.rope_embeddings: + cos, sin = self.rotary_emb(query_states, position_ids) + query_states = apply_rotary_pos_emb(query_states, cos, sin) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -372,19 +371,22 @@ def forward( key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: - # cross_attentions + # cross_attentions - don't apply rope to the key states, since they already have positional embeddings applied key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + # cached key states already have rope applied - only apply to new state + key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
# Further calls to cross_attention layer can then reuse all cross-attention @@ -395,10 +397,6 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - if self.config.rope_embeddings: - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, is_cross_attention) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) @@ -433,10 +431,6 @@ def forward( attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: @@ -860,6 +854,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)] ) + self.rope_embeddings = config.rope_embeddings if not config.rope_embeddings: self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( config.max_position_embeddings, @@ -958,7 +953,7 @@ def forward( input_shape = inputs_embeds.size()[:-1] - if not self.config.rope_embeddings: + if not self.rope_embeddings: # embed positions # TODO: As it is, the masked ids from the prompt will still count in the positions embeddings # maybe should modify position embeddings @@ -1757,6 +1752,7 @@ def __init__( # prompt embeddings self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size) + self.prompt_cross_attention = config.prompt_cross_attention if config.prompt_cross_attention: self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( config.decoder.max_position_embeddings, @@ -2143,7 +2139,7 @@ def forward( if prompt_input_ids is not None: prompt_hidden_states = self.embed_prompts(prompt_input_ids) - if prompt_hidden_states is not None and self.config.prompt_cross_attention: + if prompt_hidden_states is not None and self.prompt_cross_attention: # add sinusoidal positional embedding positions = self.embed_positions(prompt_hidden_states, 0) prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) @@ -2269,7 +2265,7 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] # we only want to use prompt signal in the 1st generation step but keeping the attention mask - prompt_hidden_states = prompt_hidden_states if self.config.prompt_cross_attention else None + prompt_hidden_states = prompt_hidden_states if self.prompt_cross_attention else None return { "input_ids": None, # encoder_outputs is defined. 
input_ids not needed From e0500d0da19037062b4ce9c778433e7a5f78bbfc Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 23 May 2024 14:14:27 +0200 Subject: [PATCH 09/62] first gqa implementation --- parler_tts/configuration_parler_tts.py | 13 +++++++ parler_tts/modeling_parler_tts.py | 47 +++++++++++++++++++------- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/parler_tts/configuration_parler_tts.py b/parler_tts/configuration_parler_tts.py index 5d631c4..bb29195 100644 --- a/parler_tts/configuration_parler_tts.py +++ b/parler_tts/configuration_parler_tts.py @@ -47,6 +47,14 @@ class ParlerTTSDecoderConfig(PretrainedConfig): Number of decoder layers. num_attention_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer block. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. ffn_dim (`int`, *optional*, defaults to 4096): Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block. activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): @@ -86,6 +94,7 @@ def __init__( num_hidden_layers=24, ffn_dim=4096, num_attention_heads=16, + num_key_value_heads=None, layerdrop=0.0, use_cache=True, activation_function="gelu", @@ -108,6 +117,10 @@ def __init__( self.ffn_dim = ffn_dim self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads self.dropout = dropout self.attention_dropout = attention_dropout self.activation_dropout = activation_dropout diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 1437465..c9d7445 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -138,6 +138,18 @@ def build_delay_pattern_mask( input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) return input_ids, pattern_mask +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + @dataclass class ParlerTTSUnconditionalInput(ModelOutput): @@ -223,14 +235,14 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): return self.weights.index_select(0, position_ids.view(-1)).detach() -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->ParlerTTS class ParlerTTSAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" + """Multi-headed attention from 'Attention Is All You Need' paper. Modified to use GQA and MQA.""" def __init__( self, embed_dim: int, num_heads: int, + num_key_value_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, @@ -242,6 +254,8 @@ def __init__( self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.config = config if (self.head_dim * num_heads) != self.embed_dim: @@ -253,14 +267,17 @@ def __init__( self.is_decoder = is_decoder self.is_causal = is_causal - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + def _shape_query(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2).contiguous() + def forward( self, hidden_states: torch.Tensor, @@ -294,18 +311,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) + value_states = 
self._shape_key_value(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -316,9 +333,13 @@ def forward( # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = self._shape_query(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -391,6 +412,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): self.self_attn = ParlerTTSAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, dropout=config.attention_dropout, is_decoder=True, bias=False, @@ -403,6 +425,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): self.encoder_attn = ParlerTTSAttention( self.embed_dim, config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, dropout=config.attention_dropout, is_decoder=True, bias=False, From 3342e2e50c79c2dbe3a16a80295f6444ca0df6c2 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 23 May 2024 14:59:18 +0200 Subject: [PATCH 10/62] fix wer eval --- training/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training/eval.py b/training/eval.py index 57cb32f..67e01e3 100644 --- a/training/eval.py +++ b/training/eval.py @@ -47,12 +47,12 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s normalized_references = [] for pred, ref in zip(transcriptions, prompts): - normalizer = english_normalizer if hasattr(pred, "language") and pred["language"] == "english" else basic_normalizer + normalizer = english_normalizer if isinstance(pred.get("chunks", None), list) and pred["chunks"][0].get("language", None) == "english" else basic_normalizer norm_ref = normalizer(ref) if len(norm_ref) > 0: norm_pred = normalizer(pred["text"]) normalized_predictions.append(norm_pred) - normalized_references.append(norm_pred) + normalized_references.append(norm_ref) word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references) From b171d98b005d1fdb722c84d8a7c0aa37aee149cf Mon Sep 17 00:00:00 2001 From: sang-nguyen-ts Date: Wed, 29 May 2024 12:00:25 +0700 Subject: [PATCH 11/62] feat: add flash attention and spda --- parler_tts/configuration_parler_tts.py | 18 ++ parler_tts/modeling_parler_tts.py | 421 ++++++++++++++++++++++++- 2 files changed, 423 insertions(+), 16 deletions(-) diff --git a/parler_tts/configuration_parler_tts.py b/parler_tts/configuration_parler_tts.py index 5d631c4..c1a25df 100644 --- a/parler_tts/configuration_parler_tts.py +++ b/parler_tts/configuration_parler_tts.py @@ -236,3 +236,21 @@ def from_sub_models_config( # This is a property because you might want to change the codec model on the fly def sampling_rate(self): return self.audio_encoder.sampling_rate + + # Copy from musicgen + @property + def _attn_implementation(self): + # This property is made private for now (as it cannot be changed and a 
PreTrainedModel.use_attn_implementation method needs to be implemented.) + if hasattr(self, "_attn_implementation_internal"): + if self._attn_implementation_internal is None: + # `config.attn_implementation` should never be None, for backward compatibility. + return "eager" + else: + return self._attn_implementation_internal + else: + return "eager" + + @_attn_implementation.setter + def _attn_implementation(self, value): + self._attn_implementation_internal = value + self.decoder._attn_implementation = value diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 1437465..8bfd282 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -28,7 +28,7 @@ from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -43,6 +43,9 @@ logging, replace_return_docstrings, ) +import torch.nn.functional as F +from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available + from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig from .dac_wrapper import DACConfig, DACModel @@ -56,6 +59,13 @@ logger = logging.get_logger(__name__) + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +else: + logger.warn("Flash attention 2 is not installed") + _CONFIG_FOR_DOC = "ParlerTTSConfig" _CHECKPOINT_FOR_DOC = "facebook/parler_tts-small" @@ -139,6 +149,8 @@ def build_delay_pattern_mask( return input_ids, pattern_mask + + @dataclass class ParlerTTSUnconditionalInput(ModelOutput): """ @@ -381,6 +393,349 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + + +# Copied from transformers.models.musicgen.modeling_bart.MusicgenFlashAttention2 with Musicgen->ParlerTTS +class ParlerTTSFlashAttention2(ParlerTTSAttention): + """ + ParlerTTS flash attention module. This module inherits from `ParlerTTSAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MusicgenFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
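            # A sketch of the alignment difference behind the check below (illustrative shapes, not from
            # the patch): with q_len=2 and k_len=4, the bottom-right aligned causal mask needed when
            # decoding with a KV cache is
            #     [[1, 1, 1, 0],
            #      [1, 1, 1, 1]]
            # whereas flash_attn<2.1 builds the top-left aligned mask
            #     [[1, 0, 0, 0],
            #      [1, 1, 0, 0]]
            # When query_length == 1 no causal mask is needed at all (a single query may attend to every
            # cached key), which is why disabling `causal` in that case is a safe workaround.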
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + +# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen +class ParlerTTSSdpaAttention(ParlerTTSAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. 
Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz) + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
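                # (Keeping `attention_mask is None` in this condition is deliberate: once a 4D causal mask
                # has been prepared upstream, the causal structure is already encoded in `attn_mask`, and
                # torch's SDPA is not meant to receive an explicit `attn_mask` together with `is_causal=True`.)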
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + + +PARLERTTS_ATTENTION_CLASSES = { + "eager": ParlerTTSAttention, + "sdpa": ParlerTTSSdpaAttention, + "flash_attention_2": ParlerTTSFlashAttention2, +} # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer with Musicgen->ParlerTTS class ParlerTTSDecoderLayer(nn.Module): @@ -388,11 +743,12 @@ def __init__(self, config: ParlerTTSDecoderConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = ParlerTTSAttention( + self.self_attn = PARLERTTS_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, bias=False, ) self.dropout = config.dropout @@ -400,12 +756,13 @@ def __init__(self, config: ParlerTTSDecoderConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = ParlerTTSAttention( + self.encoder_attn = PARLERTTS_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.num_attention_heads, dropout=config.attention_dropout, is_decoder=True, bias=False, + config=config ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False) @@ -449,6 +806,8 @@ def forward( # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple + + hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, @@ -512,8 +871,11 @@ class ParlerTTSPreTrainedModel(PreTrainedModel): config_class = ParlerTTSDecoderConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True _no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"] + def _init_weights(self, module): std = self.config.initializer_factor if isinstance(module, (nn.Linear, nn.Conv1d)): @@ -779,7 +1141,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layer_norm = nn.LayerNorm(config.hidden_size) - + self.attn_implementation = config._attn_implementation self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -867,16 +1229,42 @@ def forward( ) input_shape = inputs_embeds.size()[:-1] - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if self.attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif 
self.attn_implementation == "sdpa" and head_mask is None and not output_attentions: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + + + else: + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + if self.attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions # TODO: As it is, the masked ids from the prompt will still count in the positions embeddings @@ -1554,6 +1942,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): base_model_prefix = "encoder_decoder" main_input_name = "input_ids" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True def __init__( self, @@ -1586,7 +1976,6 @@ def __init__( if text_encoder is None: from transformers.models.auto.modeling_auto import AutoModelForTextEncoding - text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder) if audio_encoder is None: @@ -1983,6 +2372,7 @@ def forward( argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + if encoder_outputs is None: encoder_outputs = self.text_encoder( input_ids=input_ids, @@ -2617,10 +3007,10 @@ def generate( else: output_ids = outputs - # apply the pattern mask to the final ids + # Apply the pattern mask to the final ids output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) - # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask + # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask _, mask = self.decoder.build_delay_pattern_mask( input_ids, bos_token_id=generation_config.bos_token_id, @@ -2659,13 +3049,12 @@ def generate( output_values.append(sample.transpose(0, 2)) else: output_values.append(torch.zeros((1, 1, 1)).to(self.device)) - # TODO: we should keep track of output length as well. Not really straightfoward tbh + # TODO: we should keep track of output length as well. 
Not really straightforward tbh output_values = ( torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0) .squeeze(-1) .squeeze(-1) ) - if generation_config.return_dict_in_generate: outputs.sequences = output_values return outputs From 9e915aa12557644f656363e7d22a8af0e874167b Mon Sep 17 00:00:00 2001 From: sang-nguyen-ts Date: Wed, 29 May 2024 12:00:40 +0700 Subject: [PATCH 12/62] chore: add README for flash attention --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index c0d5bb5..eaec09f 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,16 @@ if torch.xpu.is_available(): torch_dtype = torch.float16 if device != "cpu" else torch.float32 model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype) + +# # Use with flash attention +# model = ParlerTTSForConditionalGeneration.from_pretrained( +# repo_id, attn_implementation="flash_attention_2", torch_dtype=torch.float16 +# ).to(device, dtype=torch_dtype) + + +model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype) + + tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1") prompt = "Hey, how are you doing today?" From 1eeca66b87cb91630c8fe783b0e59cc4aed22e9e Mon Sep 17 00:00:00 2001 From: sang-nguyen-ts Date: Wed, 29 May 2024 12:01:05 +0700 Subject: [PATCH 13/62] chore: add benchmark script --- helpers/benchmark/dataset.py | 7 +++ helpers/benchmark/parler_flash_attention.py | 60 +++++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 helpers/benchmark/dataset.py create mode 100644 helpers/benchmark/parler_flash_attention.py diff --git a/helpers/benchmark/dataset.py b/helpers/benchmark/dataset.py new file mode 100644 index 0000000..f923b6b --- /dev/null +++ b/helpers/benchmark/dataset.py @@ -0,0 +1,7 @@ +from datasets import load_dataset + +dataset = load_dataset("parler-tts/libritts_r_tags_tagged_10k_generated", 'clean') + +PROMPTS = dataset['test.clean']['text'] +DESCRIPTIONS = dataset['test.clean']['text_description'] + diff --git a/helpers/benchmark/parler_flash_attention.py b/helpers/benchmark/parler_flash_attention.py new file mode 100644 index 0000000..be9df11 --- /dev/null +++ b/helpers/benchmark/parler_flash_attention.py @@ -0,0 +1,60 @@ +import torch +from parler_tts import ParlerTTSForConditionalGeneration +from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed +from tqdm import tqdm +from dataset import PROMPTS, DESCRIPTIONS +import time + +model = ParlerTTSForConditionalGeneration.from_pretrained( + "parler-tts/parler-tts-mini-expresso", + attn_implementation="eager", + torch_dtype=torch.float16 +).to("cuda:0") + + +for i in range(3): + print(f"Wramming up decoder") + z = torch.empty(1, 1024, 8).uniform_(-10,10).type(torch.FloatTensor).to(model.device).to(model.dtype) + model.audio_encoder.model.decode(z) + + +tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso") + +def generate_speech(prompt, description): + input_ids = tokenizer(description, return_tensors="pt").input_ids.to("cuda:0") + prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0") + + generation_config = model.generation_config + + # Generate first second + generation_config.max_length = 86 # default 2580. 
WTF + + _ = model.generate(input_ids=input_ids, + prompt_input_ids=prompt_input_ids, + generation_config=generation_config, + use_cache=True, + past_key_values = None, + ) + + +if __name__ == "__main__": + NUM_SAMPLE = 20 + + latencies = [] + + for i in tqdm(range(len(PROMPTS[:NUM_SAMPLE]))): + prompt = PROMPTS[i] + description = DESCRIPTIONS[i] + + start = time.perf_counter() + + _ = generate_speech(prompt, description) + + latencies.append(time.perf_counter() - start) + + + print(f"AVG latency = {sum(latencies) / len(latencies)}") + + + + From 69521dd4d523e856f2f69210ff0225beed2d6a52 Mon Sep 17 00:00:00 2001 From: sang-nguyen-ts Date: Wed, 29 May 2024 12:02:02 +0700 Subject: [PATCH 14/62] chore: add benchmark attention approach --- helpers/benchmark/parler_eager_attention.py | 60 +++++++++++++++++++++ helpers/benchmark/parler_flash_attention.py | 2 +- helpers/benchmark/parler_sdpa_attention.py | 60 +++++++++++++++++++++ 3 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 helpers/benchmark/parler_eager_attention.py create mode 100644 helpers/benchmark/parler_sdpa_attention.py diff --git a/helpers/benchmark/parler_eager_attention.py b/helpers/benchmark/parler_eager_attention.py new file mode 100644 index 0000000..be9df11 --- /dev/null +++ b/helpers/benchmark/parler_eager_attention.py @@ -0,0 +1,60 @@ +import torch +from parler_tts import ParlerTTSForConditionalGeneration +from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed +from tqdm import tqdm +from dataset import PROMPTS, DESCRIPTIONS +import time + +model = ParlerTTSForConditionalGeneration.from_pretrained( + "parler-tts/parler-tts-mini-expresso", + attn_implementation="eager", + torch_dtype=torch.float16 +).to("cuda:0") + + +for i in range(3): + print(f"Wramming up decoder") + z = torch.empty(1, 1024, 8).uniform_(-10,10).type(torch.FloatTensor).to(model.device).to(model.dtype) + model.audio_encoder.model.decode(z) + + +tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso") + +def generate_speech(prompt, description): + input_ids = tokenizer(description, return_tensors="pt").input_ids.to("cuda:0") + prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0") + + generation_config = model.generation_config + + # Generate first second + generation_config.max_length = 86 # default 2580. 
WTF + + _ = model.generate(input_ids=input_ids, + prompt_input_ids=prompt_input_ids, + generation_config=generation_config, + use_cache=True, + past_key_values = None, + ) + + +if __name__ == "__main__": + NUM_SAMPLE = 20 + + latencies = [] + + for i in tqdm(range(len(PROMPTS[:NUM_SAMPLE]))): + prompt = PROMPTS[i] + description = DESCRIPTIONS[i] + + start = time.perf_counter() + + _ = generate_speech(prompt, description) + + latencies.append(time.perf_counter() - start) + + + print(f"AVG latency = {sum(latencies) / len(latencies)}") + + + + diff --git a/helpers/benchmark/parler_flash_attention.py b/helpers/benchmark/parler_flash_attention.py index be9df11..b0a9e9f 100644 --- a/helpers/benchmark/parler_flash_attention.py +++ b/helpers/benchmark/parler_flash_attention.py @@ -7,7 +7,7 @@ model = ParlerTTSForConditionalGeneration.from_pretrained( "parler-tts/parler-tts-mini-expresso", - attn_implementation="eager", + attn_implementation="flash_attention_2", torch_dtype=torch.float16 ).to("cuda:0") diff --git a/helpers/benchmark/parler_sdpa_attention.py b/helpers/benchmark/parler_sdpa_attention.py new file mode 100644 index 0000000..73470f5 --- /dev/null +++ b/helpers/benchmark/parler_sdpa_attention.py @@ -0,0 +1,60 @@ +import torch +from parler_tts import ParlerTTSForConditionalGeneration +from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed +from tqdm import tqdm +from dataset import PROMPTS, DESCRIPTIONS +import time + +model = ParlerTTSForConditionalGeneration.from_pretrained( + "parler-tts/parler-tts-mini-expresso", + attn_implementation="sdpa", + torch_dtype=torch.float16 +).to("cuda:0") + + +for i in range(3): + print(f"Wramming up decoder") + z = torch.empty(1, 1024, 8).uniform_(-10,10).type(torch.FloatTensor).to(model.device).to(model.dtype) + model.audio_encoder.model.decode(z) + + +tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso") + +def generate_speech(prompt, description): + input_ids = tokenizer(description, return_tensors="pt").input_ids.to("cuda:0") + prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0") + + generation_config = model.generation_config + + # Generate first second + generation_config.max_length = 86 # default 2580. 
WTF + + _ = model.generate(input_ids=input_ids, + prompt_input_ids=prompt_input_ids, + generation_config=generation_config, + use_cache=True, + past_key_values = None, + ) + + +if __name__ == "__main__": + NUM_SAMPLE = 20 + + latencies = [] + + for i in tqdm(range(len(PROMPTS[:NUM_SAMPLE]))): + prompt = PROMPTS[i] + description = DESCRIPTIONS[i] + + start = time.perf_counter() + + _ = generate_speech(prompt, description) + + latencies.append(time.perf_counter() - start) + + + print(f"AVG latency = {sum(latencies) / len(latencies)}") + + + + From 7539000f017821af1a54e59a60c26e337248d4fb Mon Sep 17 00:00:00 2001 From: "yoach@huggingface.co" Date: Thu, 30 May 2024 11:14:40 +0000 Subject: [PATCH 15/62] multi node and fix wer and fix compile --- training/data.py | 8 +++++--- training/run_parler_tts_training.py | 26 ++++++++++++++------------ 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/training/data.py b/training/data.py index f5c9862..9603b35 100644 --- a/training/data.py +++ b/training/data.py @@ -30,6 +30,8 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> # different padding methods audios = [feature[self.audio_column_name]["array"] for feature in features] len_audio = [len(audio) for audio in audios] + if self.max_length is not None: + audios = [audio[:min(l, self.max_length)] for audio, l in zip(audios, len_audio)] # since resampling has already been performed in the 'load_multiple_datasets' function, # a fixed sampling_rate(44100hz) is passed to the feature_extractor. @@ -81,7 +83,7 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> # (bsz, seq_len, num_codebooks) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) if self.audio_max_length is not None and self.padding == "max_length": - labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0))) + labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100) input_ids = [{"input_ids": feature["input_ids"]} for feature in features] @@ -206,7 +208,7 @@ def load_multiple_datasets( all_datasets = [] # iterate over the datasets we want to interleave for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."): - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): dataset = load_dataset( dataset_dict["name"], dataset_dict["config"], @@ -304,7 +306,7 @@ def load_multiple_datasets( seed=seed, ) else: - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): interleaved_dataset = concatenate_datasets(all_datasets) return interleaved_dataset diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 22e091f..ce584b2 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -97,7 +97,7 @@ def main(): padding = "max_length" if data_args.pad_to_max_length else "longest" ####### A. 
Preparation - kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))] + kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=120))] accelerator = Accelerator( gradient_accumulation_steps=training_args.gradient_accumulation_steps, @@ -341,6 +341,7 @@ def main(): model.freeze_encoders(model_args.freeze_text_encoder) # Test all gather - used for warmout and avoiding timeout + logger.debug(str(accelerator.process_index), main_process_only=False, in_order=True) test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device) gathered_tensor = accelerator.gather(test_tensor) print("gathered_tensor", gathered_tensor) @@ -349,7 +350,7 @@ def main(): if not dataset_was_precomputed: # Filter on text length if description_column_name is not None and data_args.max_text_length is not None: - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): # filter description that is shorter than max_text_length raw_datasets = raw_datasets.filter( lambda x: len(x) < data_args.max_text_length, @@ -367,7 +368,7 @@ def pass_through_processors(description, prompt): return batch - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): # this is a trick to avoid to rewrite the entire audio column which takes ages vectorized_datasets = raw_datasets.map( pass_through_processors, @@ -430,7 +431,7 @@ def apply_audio_decoder(batch): generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0) generate_labels = accelerator.gather_for_metrics(generate_labels) - if accelerator.is_main_process: + if accelerator.is_local_main_process: lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16) rat = generate_labels["ratio"].cpu().squeeze() lens = generate_labels["len_audio"].cpu().squeeze() @@ -448,11 +449,11 @@ def apply_audio_decoder(batch): os.path.join(data_args.temporary_save_to_disk, split), num_proc=1 if split == "eval" else data_args.preprocessing_num_workers, ) - accelerator.wait_for_everyone() del all_generated_labels + accelerator.wait_for_everyone() tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split)) - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1) def postprocess_dataset(labels): @@ -483,7 +484,7 @@ def postprocess_dataset(labels): output = {"labels": labels[:, 1:]} return output - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): vectorized_datasets[split] = vectorized_datasets[split].map( postprocess_dataset, num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if many processor. @@ -494,7 +495,7 @@ def postprocess_dataset(labels): accelerator.free_memory() del generate_labels, all_lens - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets. # That's also why we avoid to concat the processed datasets (vectorized_datasets) with the audio column present in raw_datasets. 
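As an aside to the collator change in this patch (label frames are now padded with -100 instead of the implicit 0), a minimal sketch of why that value matters, assuming the usual cross-entropy training objective and arbitrary toy shapes:

    import torch
    import torch.nn.functional as F

    # Toy batch: 2 sequences of 5 label frames over a 10-symbol codebook vocabulary.
    logits = torch.randn(2, 5, 10)
    labels = torch.randint(0, 10, (2, 5))
    labels[:, 3:] = -100  # padded frames

    # cross_entropy ignores targets equal to ignore_index (default -100),
    # so padded frames contribute nothing to the loss or its gradients.
    loss = F.cross_entropy(logits.transpose(1, 2), labels)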
@@ -510,7 +511,7 @@ def is_audio_in_length_range(length): ) if description_column_name is not None and data_args.max_description_token_length is not None: - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): # filter description that is shorter than max_text_length vectorized_datasets = vectorized_datasets.filter( lambda x: len(x) < data_args.max_description_token_length, @@ -519,7 +520,7 @@ def is_audio_in_length_range(length): ) if data_args.max_prompt_token_length is not None: - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): # filter description that is shorter than max_text_length vectorized_datasets = vectorized_datasets.filter( lambda x: len(x) < data_args.max_prompt_token_length, @@ -538,7 +539,7 @@ def is_audio_in_length_range(length): audio_max_length = None if padding == "max_length": audio_max_length = max(vectorized_datasets["train"]["target_length"]) - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): max_sample = vectorized_datasets["train"].filter( lambda x: x == audio_max_length, num_proc=num_workers, @@ -551,7 +552,7 @@ def is_audio_in_length_range(length): def add_target_lengths(target_length, prompt, description): return {"target_length": target_length + len(prompt) + len(description)} - with accelerator.main_process_first(): + with accelerator.local_main_process_first(): vectorized_datasets = vectorized_datasets.map( add_target_lengths, num_proc=num_workers, @@ -901,6 +902,7 @@ def generate_step(batch): commit_message=f"Saving train state of step {cur_step}", run_as_future=True, ) + accelerator.wait_for_everyone() if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps): train_time += time.time() - train_start From 76dbe87bb456fee62f1f04715e7e51b4f00a01d5 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Thu, 30 May 2024 13:32:16 +0200 Subject: [PATCH 16/62] Update modeling_parler_tts.py --- parler_tts/modeling_parler_tts.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 8bfd282..824c420 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -701,6 +701,12 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -710,7 +716,7 @@ def forward( attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): From fc79c06d8fd5589dc58da8c8510ad49e2bb46f32 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Fri, 31 May 2024 15:19:14 +0200 Subject: [PATCH 17/62] fix FA2, SDPA and add cross-attn MHA and attention type forcing --- parler_tts/configuration_parler_tts.py | 12 +++- parler_tts/modeling_parler_tts.py | 98 +++++++++++++++++--------- 2 files changed, 75 insertions(+), 35 deletions(-) diff --git a/parler_tts/configuration_parler_tts.py b/parler_tts/configuration_parler_tts.py index 5480add..689c7a9 100644 --- a/parler_tts/configuration_parler_tts.py +++ b/parler_tts/configuration_parler_tts.py @@ -55,6 +55,9 @@ class ParlerTTSDecoderConfig(PretrainedConfig): by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. + num_cross_attention_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers. + If it is not specified, will default to `num_attention_heads`. ffn_dim (`int`, *optional*, defaults to 4096): Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block. activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): @@ -86,6 +89,8 @@ class ParlerTTSDecoderConfig(PretrainedConfig): Whether to use ROPE or absolute positional embeddings. rope_theta (`float`, *optional*, defaults to 100000.0): The base period of the RoPE embeddings. + cross_attention_implementation_strategy (`str`, *optional*): + If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation. 
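    Example (an illustrative sketch of the options documented above; the values are arbitrary and the
    module path is the one used in this patch):

        from parler_tts.configuration_parler_tts import ParlerTTSDecoderConfig

        config = ParlerTTSDecoderConfig(
            num_attention_heads=16,
            num_key_value_heads=4,                  # grouped-query self-attention
            num_cross_attention_key_value_heads=1,  # multi-query cross-attention
            cross_attention_implementation_strategy="always_sdpa",
        )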
""" model_type = "parler_tts_decoder" @@ -99,6 +104,7 @@ def __init__( ffn_dim=4096, num_attention_heads=16, num_key_value_heads=None, + num_cross_attention_key_value_heads=None, layerdrop=0.0, use_cache=True, activation_function="gelu", @@ -115,6 +121,7 @@ def __init__( tie_word_embeddings=False, rope_embeddings=False, rope_theta=10_000.0, + cross_attention_implementation_strategy=None, **kwargs, ): self.vocab_size = vocab_size @@ -125,8 +132,10 @@ def __init__( self.num_attention_heads = num_attention_heads if num_key_value_heads is None: num_key_value_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads + if num_cross_attention_key_value_heads is None: + num_cross_attention_key_value_heads = num_attention_heads + self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads self.dropout = dropout self.attention_dropout = attention_dropout self.activation_dropout = activation_dropout @@ -138,6 +147,7 @@ def __init__( self.num_codebooks = num_codebooks self.rope_embeddings = rope_embeddings self.rope_theta = rope_theta + self.cross_attention_implementation_strategy = cross_attention_implementation_strategy super().__init__( pad_token_id=pad_token_id, diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index cf1727c..c85545d 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -505,7 +505,7 @@ def _get_unpad_data(attention_mask): -# Copied from transformers.models.musicgen.modeling_bart.MusicgenFlashAttention2 with Musicgen->ParlerTTS +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenFlashAttention2 with Musicgen->ParlerTTS class ParlerTTSFlashAttention2(ParlerTTSAttention): """ ParlerTTS flash attention module. This module inherits from `ParlerTTSAttention` as the weights of the module stays @@ -522,23 +522,19 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # MusicgenFlashAttention2 attention does not support output_attentions + # ParlerTTSFlashAttention2 attention does not support output_attentions if output_attentions: - raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions") + raise ValueError("ParlerTTSFlashAttention2 attention does not support output_attentions") # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder @@ -547,7 +543,12 @@ def forward( bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + + if self.rope_embeddings: + cos, sin = self.rotary_emb(query_states.transpose(1,2), position_ids) + query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -558,22 +559,26 @@ def forward( and past_key_value[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) + key_states = past_key_value[0].transpose(1,2) + value_states = past_key_value[1].transpose(1,2) elif is_cross_attention: # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) + # cached key states already have rope applied - only apply to new state + key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states + key_states = torch.cat([past_key_value[0].transpose(1,2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1,2), value_states], dim=1) else: # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) + 
value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) + key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states + if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -583,9 +588,9 @@ def forward( # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + past_key_value = (key_states.transpose(1,2), value_states.transpose(1,2)) - kv_seq_len = key_states.shape[-2] + kv_seq_len = key_states.shape[-3] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] @@ -734,6 +739,7 @@ def forward( key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -741,7 +747,7 @@ def forward( if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( - "MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + "ParlerTTSModel is using ParlerTTSSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
) return super().forward( @@ -761,6 +767,12 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) + query_states = self._shape_query(query_states, tgt_len, bsz) + + if self.rope_embeddings: + cos, sin = self.rotary_emb(query_states, position_ids) + query_states = apply_rotary_pos_emb(query_states, cos, sin) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -774,19 +786,22 @@ def forward( key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + # cross_attentions - don't apply rope to the key states, since they already have positional embeddings applied + key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) + # cached key states already have rope applied - only apply to new state + key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) + key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -798,7 +813,11 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - query_states = self._shape(query_states, tgt_len, bsz) + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + query_states = self._shape_query(query_states, tgt_len, bsz) # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
@@ -862,10 +881,15 @@ def __init__(self, config: ParlerTTSDecoderConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PARLERTTS_ATTENTION_CLASSES[config._attn_implementation]( + cross_attn_implementation = config._attn_implementation + if config.cross_attention_implementation_strategy == "always_eager": + cross_attn_implementation = "eager" + elif config.cross_attention_implementation_strategy == "always_sdpa": + cross_attn_implementation = "sdpa" + self.encoder_attn = PARLERTTS_ATTENTION_CLASSES[cross_attn_implementation]( self.embed_dim, config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, + num_key_value_heads=config.num_cross_attention_key_value_heads, dropout=config.attention_dropout, is_decoder=True, bias=False, @@ -2563,12 +2587,15 @@ def forward( # add sinusoidal positional embedding positions = self.embed_positions(prompt_hidden_states, 0) prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) + + if prompt_attention_mask is not None and attention_mask is None: + attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype) + elif attention_mask is not None and prompt_attention_mask is None: + prompt_attention_mask = torch.ones(prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype) # concatenate text description states with prompt description states encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) if prompt_attention_mask is not None: - if attention_mask is None: - attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype) attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) prompt_hidden_states = None @@ -3180,6 +3207,9 @@ def generate( else: output_ids = outputs + # TODO: remove + return + # Apply the pattern mask to the final ids output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) From 780828568535e80a817742f1b2dcbd555629f5ac Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Tue, 4 Jun 2024 17:04:48 +0200 Subject: [PATCH 18/62] better cross_attention key values number of heads default + add training arguments for attn implementation --- parler_tts/configuration_parler_tts.py | 4 ++-- training/arguments.py | 12 ++++++++++++ training/run_parler_tts_training.py | 2 ++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/parler_tts/configuration_parler_tts.py b/parler_tts/configuration_parler_tts.py index 689c7a9..c390d8a 100644 --- a/parler_tts/configuration_parler_tts.py +++ b/parler_tts/configuration_parler_tts.py @@ -57,7 +57,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig): `num_attention_heads`. num_cross_attention_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers. - If it is not specified, will default to `num_attention_heads`. + If it is not specified, will default to `num_cross_attention_key_value_heads`. ffn_dim (`int`, *optional*, defaults to 4096): Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block. 
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): @@ -134,7 +134,7 @@ def __init__( num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads if num_cross_attention_key_value_heads is None: - num_cross_attention_key_value_heads = num_attention_heads + num_cross_attention_key_value_heads = num_key_value_heads self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads self.dropout = dropout self.attention_dropout = attention_dropout diff --git a/training/arguments.py b/training/arguments.py index b806679..ff63bb7 100644 --- a/training/arguments.py +++ b/training/arguments.py @@ -78,6 +78,18 @@ class ModelArguments: "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models" }, ) + attn_implementation: str = field( + default="eager", + metadata={ + "help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`" + }, + ) + cross_attention_implementation_strategy: str = field( + default=None, + metadata={ + "help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation." + }, + ) @dataclass diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index ce584b2..f31c4f4 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -301,6 +301,7 @@ def main(): "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else config.decoder_start_token_id, + "cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy if model_args.cross_attention_implementation_strategy is not None else None, } ) @@ -311,6 +312,7 @@ def main(): config=config, token=data_args.token, trust_remote_code=data_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, ) # enable gradient checkpointing if necessary From 0ce0df2caa3ee17fd21a38305ddd07fd263c1fb4 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Tue, 4 Jun 2024 20:13:44 +0200 Subject: [PATCH 19/62] fix audio padding when torch compile or pad_to_max_length=True --- training/run_parler_tts_training.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index f31c4f4..a0a73ed 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -413,7 +413,10 @@ def apply_audio_decoder(batch): output["len_audio"] = len_audio # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks) output["labels"] = labels.squeeze(0).transpose(1, 2) - output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max() + + # if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate + max_length = len_audio.max() if padding != "max_length" else max_target_length + output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length return output for split in vectorized_datasets: From 8198fd97f7af70185452e3f9ea02779dfec34e73 Mon Sep 17 00:00:00 2001 From: "yoach@huggingface.co" Date: Wed, 5 Jun 2024 11:48:22 +0000 Subject: [PATCH 20/62] correct multi node --- training/run_parler_tts_training.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index a0a73ed..56e988e 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -448,7 +448,7 @@ def apply_audio_decoder(batch): # (1, codebooks, seq_len) where seq_len=1 bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id - if accelerator.is_main_process: + if accelerator.is_local_main_process: tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens}) tmp_labels.save_to_disk( os.path.join(data_args.temporary_save_to_disk, split), @@ -457,8 +457,8 @@ def apply_audio_decoder(batch): del all_generated_labels accelerator.wait_for_everyone() - tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split)) with accelerator.local_main_process_first(): + tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split)) vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1) def postprocess_dataset(labels): @@ -539,6 +539,7 @@ def is_audio_in_length_range(length): data_args.save_to_disk, num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1), ) + accelerator.wait_for_everyone() logger.info(f"Dataset saved at {data_args.save_to_disk}") audio_max_length = None @@ -550,7 +551,7 @@ def is_audio_in_length_range(length): num_proc=num_workers, input_columns=["target_length"], ) - audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1] + audio_max_length = max([len(l[0]) for l in max_sample["labels"]]) if training_args.group_by_length: # apply a simple heuristic to take into account audio and text lengths From 9b48d0a0322ce148da615abc5d073c2e6872c5bc Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Wed, 5 Jun 2024 14:41:57 +0200 Subject: [PATCH 21/62] make rope faster --- parler_tts/modeling_parler_tts.py | 163 ++++++++++++++---------------- 1 file changed, 78 insertions(+), 85 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index c85545d..c89615f 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -266,21 +266,21 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + # Ignore copy @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, device_type, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() + inv_freq_expanded = self.inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :] # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return cos, sin def rotate_half(x): """Rotates half the 
hidden dims of the input.""" @@ -323,6 +323,7 @@ def __init__( is_decoder: bool = False, bias: bool = True, is_causal: bool = False, + rope_embeddings : bool = False, config: Optional[ParlerTTSDecoderConfig] = None, ): super().__init__() @@ -348,13 +349,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.rope_embeddings = config.rope_embeddings - if config.rope_embeddings: - self.rotary_emb = ParlerTTSRotaryEmbedding( - self.head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) + self.rope_embeddings = rope_embeddings def _shape_query(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -368,7 +363,8 @@ def forward( key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + cos: Optional[torch.LongTensor] = None, + sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -378,19 +374,18 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - + bsz, tgt_len = hidden_states.shape[:2] + # get query proj query_states = self.q_proj(hidden_states) * self.scaling query_states = self._shape_query(query_states, tgt_len, bsz) if self.rope_embeddings: - cos, sin = self.rotary_emb(query_states, position_ids) query_states = apply_rotary_pos_emb(query_states, cos, sin) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as + current_states = key_value_states if key_value_states is not None else hidden_states + + # checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning if ( is_cross_attention @@ -400,23 +395,17 @@ def forward( # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - don't apply rope to the key states, since they already have positional embeddings applied - key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) - # cached key states already have rope applied - only apply to new state - key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) else: - # self_attention - key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) - key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states + key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(current_states), -1, 
bsz) + + if not is_cross_attention: + # cached key states already have rope applied - only apply to new state + key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states + + if past_key_value is not None: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -528,7 +517,8 @@ def forward( key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + cos: Optional[torch.LongTensor] = None, + sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -540,18 +530,17 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, q_len, _ = hidden_states.size() + bsz, q_len = hidden_states.shape[:2] # get query proj query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) if self.rope_embeddings: - cos, sin = self.rotary_emb(query_states.transpose(1,2), position_ids) query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as + current_states = key_value_states if key_value_states is not None else hidden_states + + # checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning if ( is_cross_attention @@ -561,24 +550,17 @@ def forward( # reuse k,v, cross_attentions key_states = past_key_value[0].transpose(1,2) value_states = past_key_value[1].transpose(1,2) - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) - value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) - # cached key states already have rope applied - only apply to new state - key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states - key_states = torch.cat([past_key_value[0].transpose(1,2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1,2), value_states], dim=1) else: - # self_attention - key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) - key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states - + key_states = self.k_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) + + if not is_cross_attention: + # cached key states already have rope applied - only apply to new state + key_states = 
apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states + + if past_key_value is not None: + key_states = torch.cat([past_key_value[0].transpose(1,2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1,2), value_states], dim=1) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -739,7 +721,8 @@ def forward( key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + cos: Optional[torch.LongTensor] = None, + sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -763,19 +746,18 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + bsz, tgt_len = hidden_states.shape[:2] # get query proj query_states = self.q_proj(hidden_states) query_states = self._shape_query(query_states, tgt_len, bsz) if self.rope_embeddings: - cos, sin = self.rotary_emb(query_states, position_ids) query_states = apply_rotary_pos_emb(query_states, cos, sin) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as + current_states = key_value_states if key_value_states is not None else hidden_states + + # checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning if ( is_cross_attention @@ -785,23 +767,17 @@ def forward( # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - don't apply rope to the key states, since they already have positional embeddings applied - key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) - # cached key states already have rope applied - only apply to new state - key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) else: - # self_attention - key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) - key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states + key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) + + if not is_cross_attention: + # cached key states already have rope applied - only apply to new state + key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states + + if past_key_value is not None: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) if self.is_decoder: # if cross_attention 
save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -874,6 +850,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): is_decoder=True, is_causal=True, bias=False, + rope_embeddings=config.rope_embeddings, config=config, ) self.dropout = config.dropout @@ -893,6 +870,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): dropout=config.attention_dropout, is_decoder=True, bias=False, + rope_embeddings=config.rope_embeddings, config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -904,7 +882,8 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + cos: Optional[torch.LongTensor] = None, + sin: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, @@ -947,7 +926,8 @@ def forward( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, - position_ids=position_ids, + cos=cos, + sin=sin, layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) @@ -967,7 +947,8 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - position_ids=position_ids, + cos=cos, + sin=sin, layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, @@ -1276,6 +1257,12 @@ def __init__(self, config: ParlerTTSDecoderConfig): config.max_position_embeddings, config.hidden_size, ) + else: + self.rotary_emb = ParlerTTSRotaryEmbedding( + config.hidden_size // config.num_attention_heads, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layer_norm = nn.LayerNorm(config.hidden_size) @@ -1368,6 +1355,7 @@ def forward( ) input_shape = inputs_embeds.size()[:-1] + cos, sin = None, None if not self.rope_embeddings: # embed positions @@ -1394,6 +1382,9 @@ def forward( # Some generation methods already pass only the last input ID if position_ids.shape[1] > input_shape[1]: position_ids = position_ids[:, -input_shape[1]:] + + cos, sin = self.rotary_emb(hidden_states.device.type, position_ids) + cos, sin = cos.to(hidden_states.dtype), sin.to(hidden_states.dtype) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1471,7 +1462,8 @@ def forward( decoder_layer.forward, hidden_states, attention_mask, - position_ids, + cos, + sin, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1484,7 +1476,8 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, - position_ids=position_ids, + cos=cos, + sin=sin, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), From 54b56d93ef71c47bf76a0baf3e64532a8a3a2f80 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Wed, 5 Jun 2024 17:50:08 +0200 Subject: [PATCH 22/62] fix encoder sdpa --- parler_tts/modeling_parler_tts.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index c89615f..281cb8e 100644 --- a/parler_tts/modeling_parler_tts.py +++ 
b/parler_tts/modeling_parler_tts.py @@ -1267,6 +1267,10 @@ def __init__(self, config: ParlerTTSDecoderConfig): self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layer_norm = nn.LayerNorm(config.hidden_size) self.attn_implementation = config._attn_implementation + encoder_attn_implementation = config._attn_implementation + if config.cross_attention_implementation_strategy is not None: + encoder_attn_implementation = "sdpa" if config.cross_attention_implementation_strategy == "always_sdpa" else "eager" + self.encoder_attn_implementation = encoder_attn_implementation self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1400,8 +1404,6 @@ def forward( inputs_embeds, past_key_values_length, ) - - else: attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length @@ -1409,9 +1411,9 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.attn_implementation == "flash_attention_2": + if self.encoder_attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: + elif self.encoder_attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -3200,9 +3202,6 @@ def generate( else: output_ids = outputs - # TODO: remove - return - # Apply the pattern mask to the final ids output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) From 1da55fc858725607ecb377025a3dfa8153a9e8c7 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Wed, 5 Jun 2024 17:50:29 +0200 Subject: [PATCH 23/62] fix training with cross attention + with FAZ --- training/data.py | 5 ----- training/run_parler_tts_training.py | 10 +++++++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/training/data.py b/training/data.py index 9603b35..8f737ad 100644 --- a/training/data.py +++ b/training/data.py @@ -97,11 +97,6 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> batch = {"labels": labels, **input_ids} - if self.audio_max_length is not None and self.padding == "max_length": - # if we do torch.compile, we need to also specify the attention_mask - decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype) - batch["decoder_attention_mask"] = decoder_attention_mask - prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features] prompt_input_ids = self.prompt_tokenizer.pad( prompt_input_ids, diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 56e988e..2031bd9 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -80,10 +80,13 @@ def main(): if training_args.dtype == "float16": mixed_precision = "fp16" + torch_dtype = torch.float16 elif training_args.dtype == "bfloat16": mixed_precision = "bf16" + torch_dtype = torch.bfloat16 else: mixed_precision = "no" + torch_dtype = torch.float32 if data_args.pad_to_max_length and ( 
data_args.max_duration_in_seconds is None @@ -295,13 +298,13 @@ def main(): ) # update pad token id and decoder_start_token_id + config.decoder.update({"cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy if model_args.cross_attention_implementation_strategy is not None else None}) config.update( { "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id, "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else config.decoder_start_token_id, - "cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy if model_args.cross_attention_implementation_strategy is not None else None, } ) @@ -313,6 +316,7 @@ def main(): token=data_args.token, trust_remote_code=data_args.trust_remote_code, attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype if model_args.attn_implementation == "flash_attention_2" else None, ) # enable gradient checkpointing if necessary @@ -586,7 +590,7 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"): input_ids = descriptions texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True) prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True) - audios = [a.cpu().numpy() for a in audios] + audios = [a.float().cpu().numpy() for a in audios] clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device) results["clap"] = clap_score @@ -823,7 +827,7 @@ def eval_step( def generate_step(batch): batch.pop("decoder_attention_mask", None) - eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval() + eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision == "fp32").eval() if training_args.torch_compile: eval_model = model._orig_mod From 8f6047afee2d17978ca02f63f631ae027bbd7851 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 6 Jun 2024 17:27:45 +0200 Subject: [PATCH 24/62] use fp32 as default model dtype + fix generation when using FA2 with autocast --- training/run_parler_tts_training.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 2031bd9..344cfbe 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -316,7 +316,6 @@ def main(): token=data_args.token, trust_remote_code=data_args.trust_remote_code, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype if model_args.attn_implementation == "flash_attention_2" else None, ) # enable gradient checkpointing if necessary @@ -342,6 +341,7 @@ def main(): max_length = model.generation_config.max_length num_codebooks = model.decoder.config.num_codebooks bandwidth = model_args.bandwidth + attn_implementation = model_args.attn_implementation # Freeze Encoders model.freeze_encoders(model_args.freeze_text_encoder) @@ -831,7 +831,9 @@ def generate_step(batch): if training_args.torch_compile: eval_model = model._orig_mod - output_audios = eval_model.generate(**batch, **gen_kwargs) + # since we've might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision. 
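# A minimal, self-contained sketch (not part of the patch) of the autocast guard used
# around generation here: flash_attention_2 kernels require fp16/bf16 inputs, but the
# model may have been loaded in fp32, so autocast is enabled only in that case and the
# compute is downcast on the fly without touching the stored weights, while eager/sdpa
# runs stay in plain fp32. Function and argument names below are illustrative placeholders.
from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

# in the training script the Accelerator is built with the configured mixed precision;
# a bare instance is enough for this sketch
accelerator = Accelerator()

def generate_with_optional_autocast(eval_model, batch, gen_kwargs, attn_implementation):
    # autocast only matters for flash_attention_2, whose kernels need half-precision inputs
    autocast_handler = AutocastKwargs(enabled=(attn_implementation == "flash_attention_2"))
    with accelerator.autocast(autocast_handler=autocast_handler):
        return eval_model.generate(**batch, **gen_kwargs)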
+ with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))): + output_audios = eval_model.generate(**batch, **gen_kwargs) output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0) return output_audios @@ -1040,5 +1042,4 @@ def generate_step(batch): if __name__ == "__main__": - set_start_method("spawn") main() From d056ca52b7954e71e1aa6011383de6c7c06e7dda Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 6 Jun 2024 17:28:36 +0200 Subject: [PATCH 25/62] remove redundant passes in generate + clean and fix attentions --- parler_tts/modeling_parler_tts.py | 129 ++++++++++++++++++++---------- 1 file changed, 87 insertions(+), 42 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 281cb8e..8b4e4cf 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -561,7 +561,7 @@ def forward( if past_key_value is not None: key_states = torch.cat([past_key_value[0].transpose(1,2), key_states], dim=1) value_states = torch.cat([past_key_value[1].transpose(1,2), value_states], dim=1) - + if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention @@ -572,10 +572,6 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states.transpose(1,2), value_states.transpose(1,2)) - kv_seq_len = key_states.shape[-3] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. @@ -793,8 +789,6 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - query_states = self._shape_query(query_states, tgt_len, bsz) - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. @@ -1088,6 +1082,7 @@ def _init_weights(self, module): Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + TODO: it's passed through enc_to_dec_proj and optionnally we concat the prompt hidden states in certain cases. 
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape @@ -1331,7 +1326,10 @@ def forward( if prompt_hidden_states is not None: inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) - # As it is, the masked ids from the prompt will still count in the positions embeddings + # NOTE: 1. As it is, the masked ids from the prompt will still count in the positions embeddings + # NOTE: 2. we want to concatenate the prompt attention mask and the decoder attention mask + # i.i.f `prompt_cross_attention=False`. ParlerTTSForConditionalGeneration's taking care of setting + # `prompt_attention_mask=None` if prompt_attention_mask is not None and attention_mask is not None: attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1) elif prompt_attention_mask is not None: @@ -1347,6 +1345,9 @@ def forward( dim=1, ) else: + # In the generation case of `prompt_cross_attention=True`, we need to recreate an attention mask from scratch + # to be able to prepend the prompt attention mask. + # Since we generate token per token, we can recompute the generated length from the information we have. generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1 attention_mask = torch.cat( [ @@ -2548,6 +2549,9 @@ def forward( argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + if prompt_hidden_states is None: + if prompt_input_ids is not None: + prompt_hidden_states = self.embed_prompts(prompt_input_ids) if encoder_outputs is None: encoder_outputs = self.text_encoder( @@ -2559,42 +2563,42 @@ def forward( return_dict=return_dict, **kwargs_text_encoder, ) - elif isinstance(encoder_outputs, tuple): - encoder_outputs = BaseModelOutput(*encoder_outputs) - - encoder_hidden_states = encoder_outputs[0] - - # optionally project encoder_hidden_states - if ( - self.text_encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + encoder_hidden_states = encoder_outputs[0] - if attention_mask is not None: - encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] + # optionally project encoder_hidden_states + if ( + self.text_encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) - if prompt_hidden_states is None: - if prompt_input_ids is not None: - prompt_hidden_states = self.embed_prompts(prompt_input_ids) + if attention_mask is not None: + encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] - if prompt_hidden_states is not None and self.prompt_cross_attention: - # add sinusoidal positional embedding - positions = self.embed_positions(prompt_hidden_states, 0) - prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) + if prompt_hidden_states is not None and self.prompt_cross_attention: + # add sinusoidal positional embedding + positions = self.embed_positions(prompt_hidden_states, 0) + prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) + + if 
prompt_attention_mask is not None and attention_mask is None: + attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype) + elif attention_mask is not None and prompt_attention_mask is None: + prompt_attention_mask = torch.ones(prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype) + + # concatenate text description states with prompt description states + encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) + if prompt_attention_mask is not None: + attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) + + prompt_hidden_states = None + prompt_attention_mask = None - if prompt_attention_mask is not None and attention_mask is None: - attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype) - elif attention_mask is not None and prompt_attention_mask is None: - prompt_attention_mask = torch.ones(prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype) + encoder_outputs["last_hidden_state"] = encoder_hidden_states - # concatenate text description states with prompt description states - encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) - if prompt_attention_mask is not None: - attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) - prompt_hidden_states = None - prompt_attention_mask = None + encoder_hidden_states = encoder_outputs.last_hidden_state if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( @@ -2706,8 +2710,9 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - # we only want to use prompt signal in the 1st generation step but keeping the attention mask - prompt_hidden_states = prompt_hidden_states if self.prompt_cross_attention else None + # if prompt_cross_attention, + # we only want to use prompt signal in the 1st generation step + prompt_hidden_states = None return { "input_ids": None, # encoder_outputs is defined. 
input_ids not needed @@ -2816,12 +2821,52 @@ def _prepare_text_encoder_kwargs_for_generation( [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0 ) - model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state) + # we optionnally project last_hidden_state to avoid recomputing every time + encoder_hidden_states = last_hidden_state + if ( + self.text_encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + if model_kwargs["attention_mask"] is not None: + encoder_hidden_states = encoder_hidden_states * model_kwargs["attention_mask"][..., None] + + model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=encoder_hidden_states) return model_kwargs def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs): - model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids) + prompt_hidden_states = self.embed_prompts(prompt_input_ids) + + if self.prompt_cross_attention: + # add sinusoidal positional embedding + positions = self.embed_positions(prompt_hidden_states, 0) + prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) + + attention_mask = model_kwargs.get("attention_mask", None) + prompt_attention_mask = model_kwargs.get("prompt_attention_mask", None) + encoder_hidden_states = model_kwargs["encoder_outputs"].last_hidden_state + + if prompt_attention_mask is not None and attention_mask is None: + attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype) + elif attention_mask is not None and prompt_attention_mask is None: + prompt_attention_mask = torch.ones(prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype) + + # concatenate text description states with prompt description states + encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) + if prompt_attention_mask is not None: + attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) + + model_kwargs["encoder_outputs"].last_hidden_state = encoder_hidden_states + model_kwargs["attention_mask"] = attention_mask + + # in this case, since we already concatenated the prompt hidden states and attention mask, we don't need them anymore. 
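# A minimal sketch (not part of the patch) of the shape bookkeeping behind this
# concatenation when prompt_cross_attention=True: the description encoder states and the
# prompt states are merged into a single cross-attention sequence once, before generation
# starts, so nothing has to be re-concatenated at each decoding step. All tensors and
# sizes below are toy placeholders.
import torch

bsz, desc_len, prompt_len, hidden = 2, 20, 8, 16
encoder_hidden_states = torch.randn(bsz, desc_len, hidden)
prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)
attention_mask = torch.ones(bsz, desc_len, dtype=torch.long)
prompt_attention_mask = torch.ones(bsz, prompt_len, dtype=torch.long)

# cross-attention keys/values now span description + prompt as one sequence
encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)
attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)

assert encoder_hidden_states.shape == (bsz, desc_len + prompt_len, hidden)
assert attention_mask.shape == (bsz, desc_len + prompt_len)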
+ model_kwargs["prompt_hidden_states"] = None + model_kwargs["prompt_attention_mask"] = None + else: + model_kwargs["prompt_hidden_states"] = prompt_hidden_states + # we're keeping the prompt attention mask because it has to be prepended to the decoder attention mask on the fly return model_kwargs def _prepare_audio_encoder_kwargs_for_generation( From 7dfbbca347e3939a8fdb65eaaae6d3c9ca155dfa Mon Sep 17 00:00:00 2001 From: "yoach@huggingface.co" Date: Thu, 6 Jun 2024 15:29:36 +0000 Subject: [PATCH 26/62] fix edge case in WER evaluation when longform generation --- training/eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/eval.py b/training/eval.py index 67e01e3..54d33e1 100644 --- a/training/eval.py +++ b/training/eval.py @@ -23,7 +23,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device): def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate): metric = evaluate.load("wer") - asr_pipeline = pipeline(model=asr_model_name_or_path, device=device) + asr_pipeline = pipeline(model=asr_model_name_or_path, device=device, chunk_length_s=25.0) return_language = None if isinstance(asr_pipeline.model, WhisperForConditionalGeneration): From 15edf7c66ad1a83302d8e4975e2082e34df87305 Mon Sep 17 00:00:00 2001 From: "yoach@huggingface.co" Date: Fri, 7 Jun 2024 07:58:22 +0000 Subject: [PATCH 27/62] better multi-node mapping and saving / add eval dataloader num workers --- training/arguments.py | 4 ++++ training/run_parler_tts_training.py | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/training/arguments.py b/training/arguments.py index ff63bb7..1e69702 100644 --- a/training/arguments.py +++ b/training/arguments.py @@ -323,3 +323,7 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments): default=8, metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")}, ) + eval_dataloader_num_workers: Optional[int] = field( + default=0, + metadata={"help": ("Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process.")}, + ) \ No newline at end of file diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 344cfbe..8c2ef67 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -222,7 +222,8 @@ def main(): # assume that the dataset has been saved to `save_to_disk` if the latter is not empty dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0 if dataset_was_precomputed: - vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk) + with accelerator.local_main_process_first(): + vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk) else: raw_datasets = DatasetDict() @@ -285,9 +286,10 @@ def main(): ) if data_args.max_eval_samples is not None: - raw_datasets["eval"] = ( - raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) - ) + with accelerator.local_main_process_first(): + raw_datasets["eval"] = ( + raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) + ) # 3. Next, let's load the config. 
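# A minimal sketch (not part of the patch) of the local_main_process_first() pattern this
# commit applies to dataset loading and shuffling: the local main process of each node
# runs the map/shuffle first and builds the datasets cache, and the remaining ranks on
# that node then enter the block and reuse the cache instead of recomputing it. The tiny
# in-memory dataset and helper function are placeholders.
from accelerate import Accelerator
from datasets import Dataset

accelerator = Accelerator()

def add_length(example):
    example["target_length"] = len(example["text"])
    return example

with accelerator.local_main_process_first():
    # rank 0 of each node executes this block first and populates the cache
    dataset = Dataset.from_dict({"text": ["hello", "world", "parler"]})
    dataset = dataset.map(add_length).shuffle(seed=42)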
config = ParlerTTSConfig.from_pretrained( @@ -743,7 +745,8 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"): steps_trained_progress_bar.update(cur_step) for epoch in range(0, epochs_trained): - vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) + with accelerator.local_main_process_first(): + vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) if training_args.max_steps < 0: # we know exactly the number of steps per epoch, so can skip through the required number of batches @@ -753,7 +756,8 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"): # So we just shuffle the dataset one extra time and start from a fresh epoch # This is "good enough" for our purposes but not fully correct resume_step = None - vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) + with accelerator.local_main_process_first(): + vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) else: resume_step = None @@ -838,7 +842,8 @@ def generate_step(batch): return output_audios for epoch in range(epochs_trained, num_epochs): - vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) + with accelerator.local_main_process_first(): + vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) sampler = None if training_args.group_by_length: sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"]) @@ -933,7 +938,7 @@ def generate_step(batch): collate_fn=data_collator, batch_size=per_device_eval_batch_size, drop_last=False, - num_workers=training_args.dataloader_pin_memory, + num_workers=training_args.eval_dataloader_num_workers, pin_memory=training_args.dataloader_pin_memory, ) validation_dataloader = accelerator.prepare(validation_dataloader) From ef4065428aa045cb5b059884f54da5f5e06bceb1 Mon Sep 17 00:00:00 2001 From: "yoach@huggingface.co" Date: Fri, 7 Jun 2024 08:01:44 +0000 Subject: [PATCH 28/62] remove old benchmarks --- helpers/benchmark/dataset.py | 7 --- helpers/benchmark/parler_eager_attention.py | 60 --------------------- helpers/benchmark/parler_flash_attention.py | 60 --------------------- helpers/benchmark/parler_sdpa_attention.py | 60 --------------------- 4 files changed, 187 deletions(-) delete mode 100644 helpers/benchmark/dataset.py delete mode 100644 helpers/benchmark/parler_eager_attention.py delete mode 100644 helpers/benchmark/parler_flash_attention.py delete mode 100644 helpers/benchmark/parler_sdpa_attention.py diff --git a/helpers/benchmark/dataset.py b/helpers/benchmark/dataset.py deleted file mode 100644 index f923b6b..0000000 --- a/helpers/benchmark/dataset.py +++ /dev/null @@ -1,7 +0,0 @@ -from datasets import load_dataset - -dataset = load_dataset("parler-tts/libritts_r_tags_tagged_10k_generated", 'clean') - -PROMPTS = dataset['test.clean']['text'] -DESCRIPTIONS = dataset['test.clean']['text_description'] - diff --git a/helpers/benchmark/parler_eager_attention.py b/helpers/benchmark/parler_eager_attention.py deleted file mode 100644 index be9df11..0000000 --- a/helpers/benchmark/parler_eager_attention.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from parler_tts import ParlerTTSForConditionalGeneration -from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed -from tqdm import tqdm -from dataset import PROMPTS, DESCRIPTIONS -import time - -model = ParlerTTSForConditionalGeneration.from_pretrained( - 
"parler-tts/parler-tts-mini-expresso", - attn_implementation="eager", - torch_dtype=torch.float16 -).to("cuda:0") - - -for i in range(3): - print(f"Wramming up decoder") - z = torch.empty(1, 1024, 8).uniform_(-10,10).type(torch.FloatTensor).to(model.device).to(model.dtype) - model.audio_encoder.model.decode(z) - - -tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso") - -def generate_speech(prompt, description): - input_ids = tokenizer(description, return_tensors="pt").input_ids.to("cuda:0") - prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0") - - generation_config = model.generation_config - - # Generate first second - generation_config.max_length = 86 # default 2580. WTF - - _ = model.generate(input_ids=input_ids, - prompt_input_ids=prompt_input_ids, - generation_config=generation_config, - use_cache=True, - past_key_values = None, - ) - - -if __name__ == "__main__": - NUM_SAMPLE = 20 - - latencies = [] - - for i in tqdm(range(len(PROMPTS[:NUM_SAMPLE]))): - prompt = PROMPTS[i] - description = DESCRIPTIONS[i] - - start = time.perf_counter() - - _ = generate_speech(prompt, description) - - latencies.append(time.perf_counter() - start) - - - print(f"AVG latency = {sum(latencies) / len(latencies)}") - - - - diff --git a/helpers/benchmark/parler_flash_attention.py b/helpers/benchmark/parler_flash_attention.py deleted file mode 100644 index b0a9e9f..0000000 --- a/helpers/benchmark/parler_flash_attention.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from parler_tts import ParlerTTSForConditionalGeneration -from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed -from tqdm import tqdm -from dataset import PROMPTS, DESCRIPTIONS -import time - -model = ParlerTTSForConditionalGeneration.from_pretrained( - "parler-tts/parler-tts-mini-expresso", - attn_implementation="flash_attention_2", - torch_dtype=torch.float16 -).to("cuda:0") - - -for i in range(3): - print(f"Wramming up decoder") - z = torch.empty(1, 1024, 8).uniform_(-10,10).type(torch.FloatTensor).to(model.device).to(model.dtype) - model.audio_encoder.model.decode(z) - - -tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso") - -def generate_speech(prompt, description): - input_ids = tokenizer(description, return_tensors="pt").input_ids.to("cuda:0") - prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0") - - generation_config = model.generation_config - - # Generate first second - generation_config.max_length = 86 # default 2580. 
WTF - - _ = model.generate(input_ids=input_ids, - prompt_input_ids=prompt_input_ids, - generation_config=generation_config, - use_cache=True, - past_key_values = None, - ) - - -if __name__ == "__main__": - NUM_SAMPLE = 20 - - latencies = [] - - for i in tqdm(range(len(PROMPTS[:NUM_SAMPLE]))): - prompt = PROMPTS[i] - description = DESCRIPTIONS[i] - - start = time.perf_counter() - - _ = generate_speech(prompt, description) - - latencies.append(time.perf_counter() - start) - - - print(f"AVG latency = {sum(latencies) / len(latencies)}") - - - - diff --git a/helpers/benchmark/parler_sdpa_attention.py b/helpers/benchmark/parler_sdpa_attention.py deleted file mode 100644 index 73470f5..0000000 --- a/helpers/benchmark/parler_sdpa_attention.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from parler_tts import ParlerTTSForConditionalGeneration -from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed -from tqdm import tqdm -from dataset import PROMPTS, DESCRIPTIONS -import time - -model = ParlerTTSForConditionalGeneration.from_pretrained( - "parler-tts/parler-tts-mini-expresso", - attn_implementation="sdpa", - torch_dtype=torch.float16 -).to("cuda:0") - - -for i in range(3): - print(f"Wramming up decoder") - z = torch.empty(1, 1024, 8).uniform_(-10,10).type(torch.FloatTensor).to(model.device).to(model.dtype) - model.audio_encoder.model.decode(z) - - -tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso") - -def generate_speech(prompt, description): - input_ids = tokenizer(description, return_tensors="pt").input_ids.to("cuda:0") - prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0") - - generation_config = model.generation_config - - # Generate first second - generation_config.max_length = 86 # default 2580. WTF - - _ = model.generate(input_ids=input_ids, - prompt_input_ids=prompt_input_ids, - generation_config=generation_config, - use_cache=True, - past_key_values = None, - ) - - -if __name__ == "__main__": - NUM_SAMPLE = 20 - - latencies = [] - - for i in tqdm(range(len(PROMPTS[:NUM_SAMPLE]))): - prompt = PROMPTS[i] - description = DESCRIPTIONS[i] - - start = time.perf_counter() - - _ = generate_speech(prompt, description) - - latencies.append(time.perf_counter() - start) - - - print(f"AVG latency = {sum(latencies) / len(latencies)}") - - - - From 954d8c5d3b54b681c63e08526aee90c6afd201eb Mon Sep 17 00:00:00 2001 From: "yoach@huggingface.co" Date: Fri, 21 Jun 2024 09:22:02 +0000 Subject: [PATCH 29/62] faster audio encoding + checkpointing + fix generation step --- parler_tts/modeling_parler_tts.py | 5 +- training/arguments.py | 4 + training/data.py | 4 +- training/eval.py | 12 +- training/run_parler_tts_training.py | 326 ++++++++++++++++------------ training/utils.py | 52 ++++- 6 files changed, 256 insertions(+), 147 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 8b4e4cf..4f63037 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -577,9 +577,8 @@ def forward( # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. 
(LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: + + if query_states.dtype == torch.float32 or value_states.dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized diff --git a/training/arguments.py b/training/arguments.py index 1e69702..79c81e3 100644 --- a/training/arguments.py +++ b/training/arguments.py @@ -302,6 +302,10 @@ class DataTrainingArguments: }, ) temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."}) + save_codec_steps: Optional[int] = field( + default=500, + metadata={"help": "Temporarily save the audio labels every `save_steps`."}, + ) pad_to_multiple_of: Optional[int] = field( default=2, metadata={"help": ("Pad to multiple of for tokenizers.")}, diff --git a/training/data.py b/training/data.py index 8f737ad..2a293e6 100644 --- a/training/data.py +++ b/training/data.py @@ -239,7 +239,7 @@ def load_multiple_datasets( # metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24) # metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}") - if dataset_dict["name"] != "parler-tts/mls_eng_10k": + if dataset_dict["name"] not in {"parler-tts/mls_eng_10k", "parler-tts/mls_eng"}: if id_column_name is not None and id_column_name not in dataset.column_names: raise ValueError( f"id_column_name={id_column_name} but has not been found in the dataset columns" @@ -269,7 +269,7 @@ def load_multiple_datasets( dataset = concatenate_datasets([dataset, metadata_dataset], axis=1) - if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k": + if id_column_name is not None and dataset_dict["name"] not in {"parler-tts/mls_eng_10k", "parler-tts/mls_eng"}: if ( len( dataset.filter( diff --git a/training/eval.py b/training/eval.py index 54d33e1..1116dd0 100644 --- a/training/eval.py +++ b/training/eval.py @@ -1,6 +1,7 @@ import torch import evaluate from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast +from accelerate.utils.memory import release_memory def clap_similarity(clap_model_name_or_path, texts, audios, device): @@ -14,11 +15,13 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device): ) audio_features = clap.get_audio_features(clap_inputs["input_features"]) - cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8) + cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8).mean() + + cosine_sim = cosine_sim.to("cpu") clap.to("cpu") - clap_inputs.to("cpu") - return cosine_sim.mean().to("cpu") + clap, clap_inputs, audio_features, text_features = release_memory(clap, clap_inputs, audio_features, text_features) + return cosine_sim def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate): @@ -55,5 +58,6 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s normalized_references.append(norm_ref) word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references) - + asr_pipeline.model.to("cpu") + asr_pipeline = release_memory(asr_pipeline) return word_error, [t["text"] for t in transcriptions] diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 
8c2ef67..d474112 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -42,7 +42,7 @@ from transformers.utils import send_example_telemetry -from accelerate import Accelerator +from accelerate import Accelerator, skip_first_batches from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin from accelerate.utils.memory import release_memory @@ -52,12 +52,11 @@ build_delay_pattern_mask, ) -from training.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric +from training.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric, load_all_codec_checkpoints, save_codec_checkpoint, get_last_codec_checkpoint_step from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding from training.eval import clap_similarity, wer - logger = logging.getLogger(__name__) @@ -425,6 +424,37 @@ def apply_audio_decoder(batch): output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length return output + # (1, codebooks, seq_len) where seq_len=1 + bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id + + def postprocess_dataset(labels): + # (1, codebooks, seq_len) + labels = torch.tensor(labels).unsqueeze(0) + # add bos + labels = torch.cat([bos_labels, labels], dim=-1) + + labels, delay_pattern_mask = build_delay_pattern_mask( + labels, + bos_token_id=audio_encoder_bos_token_id, + pad_token_id=audio_encoder_eos_token_id, + max_length=labels.shape[-1] + num_codebooks, + num_codebooks=num_codebooks, + ) + + # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask + # to take care of EOS + # we want labels to look like this: + # - [B, a, b, E, E, E, E] + # - [B, B, c, d, E, E, E] + # - [B, B, B, e, f, E, E] + # - [B, B, B, B, g, h, E] + labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask) + + # the first timestamp is associated to a row full of BOS, let's get rid of it + # we also remove the last timestampts (full of PAD) + output = {"labels": labels[:, 1:]} + return output + for split in vectorized_datasets: data_loader = DataLoader( raw_datasets[split], @@ -434,75 +464,69 @@ def apply_audio_decoder(batch): pin_memory=True, ) data_loader = accelerator.prepare(data_loader) + total_inference_steps = len(data_loader) + + start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split)) + accelerator.wait_for_everyone() + if start_step > 0: + logger.info(f"Resuming {split} from step {start_step}") + # efficiently skip the first n batches + start_step += 1 + data_loader = skip_first_batches(data_loader, start_step) all_generated_labels = [] all_lens = [] - for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process): - generate_labels = apply_audio_decoder(batch) - generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0) - generate_labels = accelerator.gather_for_metrics(generate_labels) - - if accelerator.is_local_main_process: - lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16) - rat = generate_labels["ratio"].cpu().squeeze() - lens = generate_labels["len_audio"].cpu().squeeze() - lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)] - - all_generated_labels.extend(lab) - all_lens.extend(lens) - - # (1, 
codebooks, seq_len) where seq_len=1 - bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id + if start_step < total_inference_steps: + for (i, batch) in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)): + cur_step = start_step + i + generate_labels = apply_audio_decoder(batch) + generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0) + generate_labels = accelerator.gather_for_metrics(generate_labels) - if accelerator.is_local_main_process: + if accelerator.is_main_process: + lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16) + rat = generate_labels["ratio"].cpu().squeeze(1) + lens = generate_labels["len_audio"].cpu().squeeze(1) + lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)] + + all_generated_labels.extend(lab) + all_lens.extend(lens) + + if ((cur_step+1) % data_args.save_codec_steps == 0) or (cur_step == total_inference_steps - 1): + tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens}) + tmp_labels = tmp_labels.map( + postprocess_dataset, + num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if many processor. + input_columns=["labels"], + desc="Postprocessing labeling", + ) + save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step) + all_generated_labels = [] + all_lens = [] + + accelerator.wait_for_everyone() + + if accelerator.is_main_process and len(all_generated_labels) > 0: tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens}) - tmp_labels.save_to_disk( - os.path.join(data_args.temporary_save_to_disk, split), - num_proc=1 if split == "eval" else data_args.preprocessing_num_workers, + tmp_labels = tmp_labels.map( + postprocess_dataset, + num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if many processor. 
+ input_columns=["labels"], + desc="Postprocessing labeling", ) + save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step) + all_generated_labels = [] + all_lens = [] + accelerator.wait_for_everyone() + del all_generated_labels accelerator.wait_for_everyone() with accelerator.local_main_process_first(): - tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split)) + tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(range(len(vectorized_datasets[split]))) + logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}") vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1) - def postprocess_dataset(labels): - # (1, codebooks, seq_len) - labels = torch.tensor(labels).unsqueeze(0) - # add bos - labels = torch.cat([bos_labels, labels], dim=-1) - - labels, delay_pattern_mask = build_delay_pattern_mask( - labels, - bos_token_id=audio_encoder_bos_token_id, - pad_token_id=audio_encoder_eos_token_id, - max_length=labels.shape[-1] + num_codebooks, - num_codebooks=num_codebooks, - ) - - # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask - # to take care of EOS - # we want labels to look like this: - # - [B, a, b, E, E, E, E] - # - [B, B, c, d, E, E, E] - # - [B, B, B, e, f, E, E] - # - [B, B, B, B, g, h, E] - labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask) - - # the first timestamp is associated to a row full of BOS, let's get rid of it - # we also remove the last timestampts (full of PAD) - output = {"labels": labels[:, 1:]} - return output - - with accelerator.local_main_process_first(): - vectorized_datasets[split] = vectorized_datasets[split].map( - postprocess_dataset, - num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if many processor. 
- input_columns=["labels"], - desc="Postprocessing labeling", - ) - accelerator.free_memory() del generate_labels, all_lens @@ -521,23 +545,23 @@ def is_audio_in_length_range(length): input_columns=["target_length"], ) - if description_column_name is not None and data_args.max_description_token_length is not None: - with accelerator.local_main_process_first(): - # filter description that is shorter than max_text_length - vectorized_datasets = vectorized_datasets.filter( - lambda x: len(x) < data_args.max_description_token_length, - num_proc=num_workers, - input_columns=["input_ids"], - ) + if description_column_name is not None and data_args.max_description_token_length is not None: + with accelerator.local_main_process_first(): + # filter description that is shorter than max_text_length + vectorized_datasets = vectorized_datasets.filter( + lambda x: len(x) < data_args.max_description_token_length, + num_proc=num_workers, + input_columns=["input_ids"], + ) - if data_args.max_prompt_token_length is not None: - with accelerator.local_main_process_first(): - # filter description that is shorter than max_text_length - vectorized_datasets = vectorized_datasets.filter( - lambda x: len(x) < data_args.max_prompt_token_length, - num_proc=num_workers, - input_columns=["prompt_input_ids"], - ) + if data_args.max_prompt_token_length is not None: + with accelerator.local_main_process_first(): + # filter description that is shorter than max_text_length + vectorized_datasets = vectorized_datasets.filter( + lambda x: len(x) < data_args.max_prompt_token_length, + num_proc=num_workers, + input_columns=["prompt_input_ids"], + ) if data_args.save_to_disk is not None and not dataset_was_precomputed: if accelerator.is_main_process: @@ -606,7 +630,6 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"): sampling_rate, ) results["wer"] = word_error - return results, texts, prompts, audios, transcriptions # Define Training Schedule @@ -711,24 +734,24 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"): # Now save everything to be able to create a single processor later # make sure all processes wait until data is saved - with accelerator.main_process_first(): - # only the main process saves them - if accelerator.is_main_process: - # save feature extractor, tokenizer and config - if ( - model_args.prompt_tokenizer_name is None - and model_args.description_tokenizer_name - or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name) - ): - prompt_tokenizer.save_pretrained(training_args.output_dir) - else: - logger.warning( - f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer." - ) - prompt_tokenizer.save_pretrained(training_args.output_dir) + # only the main process saves them + if accelerator.is_main_process: + # save feature extractor, tokenizer and config + if ( + model_args.prompt_tokenizer_name is None + and model_args.description_tokenizer_name + or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name) + ): + prompt_tokenizer.save_pretrained(training_args.output_dir) + else: + logger.warning( + f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer." 
+ ) + prompt_tokenizer.save_pretrained(training_args.output_dir) - feature_extractor.save_pretrained(training_args.output_dir) - config.save_pretrained(training_args.output_dir) + feature_extractor.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + accelerator.wait_for_everyone() if checkpoint is not None: accelerator.load_state(checkpoint) @@ -777,8 +800,6 @@ def train_step( accelerator, autocast_kwargs, ): - model.train() - if mixed_precision == "fp16": # fp16 doesn't work with T5-like models with accelerator.autocast(autocast_handler=autocast_kwargs): @@ -790,6 +811,18 @@ def train_step( encoder_outputs = model.module.text_encoder( input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None) ) + # we optionnally project last_hidden_state to avoid recomputing every time + encoder_hidden_states = encoder_outputs.last_hidden_state + if ( + config.text_encoder.hidden_size != config.decoder.hidden_size + and config.decoder.cross_attention_hidden_size is None + ): + encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states) if training_args.parallel_mode.value != "distributed" else model.module.enc_to_dec_proj(encoder_hidden_states) + + if batch.get("attention_mask", None) is not None: + encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None] + + encoder_outputs.last_hidden_state = encoder_hidden_states batch["encoder_outputs"] = encoder_outputs outputs = model(**batch) @@ -806,22 +839,33 @@ def eval_step( autocast_kwargs, ): eval_model = model if not training_args.torch_compile else model._orig_mod - eval_model.eval() if mixed_precision == "fp16": # fp16 doesn't work with T5-like models with accelerator.autocast(autocast_handler=autocast_kwargs): - with torch.no_grad(): - if training_args.parallel_mode.value != "distributed" or training_args.torch_compile: - encoder_outputs = eval_model.text_encoder( - input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None) - ) - else: - encoder_outputs = eval_model.module.text_encoder( - input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None) - ) + if training_args.parallel_mode.value != "distributed": + encoder_outputs = model.text_encoder( + input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None) + ) + else: + encoder_outputs = model.module.text_encoder( + input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None) + ) + # we optionnally project last_hidden_state to avoid recomputing every time + encoder_hidden_states = encoder_outputs.last_hidden_state + if ( + config.text_encoder.hidden_size != config.decoder.hidden_size + and config.decoder.cross_attention_hidden_size is None + ): + encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states) if training_args.parallel_mode.value != "distributed" else model.module.enc_to_dec_proj(encoder_hidden_states) + + if batch.get("attention_mask", None) is not None: + encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None] + + encoder_outputs.last_hidden_state = encoder_hidden_states batch["encoder_outputs"] = encoder_outputs + with torch.no_grad(): outputs = eval_model(**batch) # CE (data) loss @@ -829,18 +873,21 @@ def eval_step( metrics = {"loss": ce_loss} return metrics - def generate_step(batch): + def generate_step(batch, accelerator): batch.pop("decoder_attention_mask", None) - eval_model = accelerator.unwrap_model(model, 
keep_fp32_wrapper=mixed_precision == "fp32").eval() + eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True) # (attn_implementation!="flash_attention_2")) if training_args.torch_compile: + # if the model is compiled, we use the original model bc compile is not compatible with .generate eval_model = model._orig_mod # since we've might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision. - with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))): - output_audios = eval_model.generate(**batch, **gen_kwargs) + # with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))): + output_audios = eval_model.generate(**batch, **gen_kwargs) output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0) return output_audios + model.train() + for epoch in range(epochs_trained, num_epochs): with accelerator.local_main_process_first(): vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) @@ -861,8 +908,10 @@ def generate_step(batch): if resume_step is not None: # Skip the first N batches in the dataloader when resuming from a checkpoint + logger.info(f" Skip first {resume_step} batches") train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) resume_step = None + accelerator.wait_for_everyone() for batch in train_dataloader: with accelerator.accumulate(model): @@ -924,6 +973,7 @@ def generate_step(batch): if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps): train_time += time.time() - train_start # ======================== Evaluating ============================== + model.eval() eval_metrics = [] eval_preds = [] eval_descriptions = [] @@ -971,7 +1021,7 @@ def generate_step(batch): position=2, disable=not accelerator.is_local_main_process, ): - generated_audios = generate_step(batch) + generated_audios = generate_step(batch, accelerator) # Gather all predictions and targets generated_audios, input_ids, prompts = accelerator.pad_across_processes( (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0 @@ -986,35 +1036,38 @@ def generate_step(batch): eval_time = time.time() - eval_start # normalize eval metrics eval_metrics = { - key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics])) + key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics])).to("cpu") for key in eval_metrics[0] } # compute metrics metrics_desc = "" if training_args.predict_with_generate: - metric_values, pred_descriptions, pred_prompts, audios, transcriptions = compute_metrics( - eval_preds, eval_descriptions, eval_prompts, accelerator.device - ) - eval_metrics.update(metric_values) - metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()]) - if "wandb" in training_args.report_to: - log_pred( - accelerator, - pred_descriptions, - pred_prompts, - transcriptions, - audios, - sampling_rate=sampling_rate, - step=cur_step, - prefix="eval", + if accelerator.is_local_main_process: + metric_values, pred_descriptions, pred_prompts, audios, transcriptions = compute_metrics( + eval_preds, eval_descriptions, eval_prompts, accelerator.device ) + eval_metrics.update(metric_values) + metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()]) + if "wandb" in training_args.report_to: + log_pred( + accelerator, + pred_descriptions, + pred_prompts, + 
transcriptions, + audios, + sampling_rate=sampling_rate, + step=cur_step, + prefix="eval", + ) + accelerator.wait_for_everyone() # Print metrics and update progress bar - steps_trained_progress_bar.write( - f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |" - f" {metrics_desc})" - ) + if accelerator.is_local_main_process: + steps_trained_progress_bar.write( + f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |" + f" {metrics_desc})" + ) log_metric( accelerator, @@ -1026,11 +1079,12 @@ def generate_step(batch): ) # release eval batch and relax metrics - eval_metrics = [] - eval_preds = [] - eval_descriptions = [] - eval_prompts = [] - batch = release_memory(batch) + eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric) + if training_args.predict_with_generate: + generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts) + + # train mode + model.train() # flush the train metrics train_start = time.time() diff --git a/training/utils.py b/training/utils.py index 2328575..8dc2e0a 100644 --- a/training/utils.py +++ b/training/utils.py @@ -7,14 +7,15 @@ import torch from wandb import Audio - +from datasets import load_from_disk, concatenate_datasets def list_field(default=None, metadata=None): return field(default_factory=lambda: default, metadata=metadata) _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$") - +CHECKPOINT_CODEC_PREFIX = "checkpoint" +_RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$") def get_last_checkpoint(folder): content = os.listdir(folder) @@ -60,6 +61,53 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix shutil.rmtree(checkpoint, ignore_errors=True) +def save_codec_checkpoint(output_dir, dataset, step): + checkpoint_path = f"{CHECKPOINT_CODEC_PREFIX}-{step}" + output_path = os.path.join(output_dir, checkpoint_path) + dataset.save_to_disk(output_path) + +def load_codec_checkpoint(checkpoint_path): + dataset = load_from_disk(checkpoint_path) + return dataset + +def sorted_codec_checkpoints(output_dir=None) -> List[str]: + """Helper function to sort saved checkpoints from oldest to newest.""" + ordering_and_checkpoint_path = [] + + glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_CODEC_PREFIX}-*")] + + for path in glob_checkpoints: + regex_match = re.match(f".*{CHECKPOINT_CODEC_PREFIX}-([0-9]+)", path) + if regex_match is not None and regex_match.groups() is not None: + ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(ordering_and_checkpoint_path) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + return checkpoints_sorted + +def load_all_codec_checkpoints(output_dir=None) -> List[str]: + """Helper function to load and concat all checkpoints.""" + checkpoints_sorted = sorted_codec_checkpoints(output_dir=output_dir) + datasets = [load_from_disk(checkpoint) for checkpoint in checkpoints_sorted] + datasets = concatenate_datasets(datasets, axis=0) + return datasets + + +def get_last_codec_checkpoint_step(folder) -> int: + if not os.path.exists(folder) or not os.path.isdir(folder): + os.makedirs(folder, exist_ok=True) + return 0 + content = os.listdir(folder) + checkpoints = [path for path in content if _RE_CODEC_CHECKPOINT.search(path) is not None] + if len(checkpoints) == 0: + return 0 + 
last_checkpoint = os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CODEC_CHECKPOINT.search(x).groups()[0]))) + # Find num steps saved state string pattern + pattern = r"checkpoint-(\d+)" + match = re.search(pattern, last_checkpoint) + cur_step = int(match.group(1)) + return cur_step + def log_metric( accelerator, metrics: Dict, From f9c36ac3888e5e32617f6d1f4a3f8511698f17cd Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 8 Jul 2024 16:39:58 +0200 Subject: [PATCH 30/62] unpin trfms --- parler_tts/modeling_parler_tts.py | 41 +++++++++++++++++++++++-------- setup.py | 2 +- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 4f63037..528b397 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -1940,6 +1940,8 @@ def generate( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = input_ids.shape[0] // self.num_codebooks + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache @@ -2020,6 +2022,7 @@ def generate( encoder_input_ids=input_ids, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, + device=input_ids.device, ) # 10. prepare stopping criteria @@ -2035,7 +2038,7 @@ def generate( ) # 11. run greedy search - outputs = self._greedy_search( + outputs = self._sample( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, @@ -2047,7 +2050,7 @@ def generate( elif is_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -2601,7 +2604,7 @@ def forward( if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id + labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id ).transpose(1, 2) elif decoder_input_ids is None and decoder_inputs_embeds is None: @@ -2916,7 +2919,7 @@ def _prepare_audio_encoder_kwargs_for_generation( return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2) + return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id).transpose(1, 2) def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( @@ -2953,6 +2956,24 @@ def _maybe_initialize_input_ids_for_generation( break return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + def _get_decoder_start_token_id( + self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None + ) -> int: + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif 
bos_token_id is not None: + return bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + def freeze_encoders(self, freeze_text_encoder=True): if freeze_text_encoder: for param in self.text_encoder.parameters(): @@ -3074,6 +3095,8 @@ def generate( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache @@ -3089,10 +3112,7 @@ def generate( if "encoder_outputs" not in model_kwargs: # encoder_outputs are created and added to `model_kwargs` model_kwargs = self._prepare_text_encoder_kwargs_for_generation( - inputs_tensor, - model_kwargs, - model_input_name, - generation_config, + inputs_tensor, model_kwargs, model_input_name, generation_config ) if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs: @@ -3186,6 +3206,7 @@ def generate( encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, + device=input_ids.device, ) # 10. prepare stopping criteria @@ -3201,7 +3222,7 @@ def generate( ) # 11. run greedy search - outputs = self._greedy_search( + outputs = self._sample( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, @@ -3213,7 +3234,7 @@ def generate( elif is_sample_gen_mode: # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( diff --git a/setup.py b/setup.py index cef187d..cea2faf 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ _deps = [ - "transformers>=4.39.0,<4.41.0", + "transformers>=4.39.0", "torch", "sentencepiece", "descript-audio-codec", From e52a8f051b2a3a978132689109fcdef3b7b5e6ac Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 8 Jul 2024 16:58:25 +0200 Subject: [PATCH 31/62] remove CFG --- parler_tts/modeling_parler_tts.py | 75 +++++-------------------------- 1 file changed, 10 insertions(+), 65 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 528b397..3e6cd0d 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -172,14 +172,10 @@ class ParlerTTSUnconditionalInput(ModelOutput): attention_mask (`torch.LongTensor`) of shape `(batch_size, sequence_length)`, *optional*): Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**. - guidance_scale (`float`, *optional*): - Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted - from the prompts) and the unconditional logits (predicted without prompts). 
""" encoder_outputs: Tuple[torch.FloatTensor] = None attention_mask: torch.LongTensor = None - guidance_scale: float = None # Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right @@ -1744,7 +1740,6 @@ def prepare_inputs_for_generation( past_key_values=None, use_cache=True, delay_pattern_mask=None, - guidance_scale=None, **kwargs, ): if delay_pattern_mask is None: @@ -1758,23 +1753,6 @@ def prepare_inputs_for_generation( # apply the delay pattern mask input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask) - if guidance_scale is not None and guidance_scale > 1: - # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these - # before sampling) - input_ids = input_ids.repeat((2, 1)) - if attention_mask is not None: - attention_mask = attention_mask.repeat((2, 1)) - - if prompt_hidden_states is not None: - prompt_hidden_states = torch.concatenate( - [prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0 - ) - - if prompt_attention_mask is not None: - prompt_attention_mask = torch.concatenate( - [prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0 - ) - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation @@ -1945,7 +1923,6 @@ def generate( # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache - model_kwargs["guidance_scale"] = generation_config.guidance_scale requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: @@ -2010,12 +1987,7 @@ def generate( and generation_config.do_sample is True ) - # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) - if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: - logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) - generation_config.guidance_scale = None - - # 9. prepare distribution pre_processing samplers + # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, @@ -2025,7 +1997,7 @@ def generate( device=input_ids.device, ) - # 10. prepare stopping criteria + # 9. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) @@ -2037,7 +2009,7 @@ def generate( f"but is {generation_config.num_return_sequences}." ) - # 11. run greedy search + # 10. run greedy search outputs = self._sample( input_ids, logits_processor=logits_processor, @@ -2049,7 +2021,7 @@ def generate( ) elif is_sample_gen_mode: - # 11. prepare logits warper + # 10. prepare logits warper logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch @@ -2059,7 +2031,7 @@ def generate( **model_kwargs, ) - # 12. run sample + # 11. 
run sample outputs = self._sample( input_ids, logits_processor=logits_processor, @@ -2675,7 +2647,6 @@ def prepare_inputs_for_generation( use_cache=None, encoder_outputs=None, decoder_delay_pattern_mask=None, - guidance_scale=None, **kwargs, ): if decoder_delay_pattern_mask is None: @@ -2689,17 +2660,6 @@ def prepare_inputs_for_generation( # apply the delay pattern mask decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) - if guidance_scale is not None and guidance_scale > 1: - # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these - # before sampling) - decoder_input_ids = decoder_input_ids.repeat((2, 1)) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) - if prompt_hidden_states is not None: - prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1)) - if prompt_attention_mask is not None: - prompt_attention_mask = prompt_attention_mask.repeat((2, 1)) - if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -2807,7 +2767,6 @@ def _prepare_text_encoder_kwargs_for_generation( } encoder_kwargs["output_attentions"] = generation_config.output_attentions encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states - guidance_scale = generation_config.guidance_scale # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name @@ -2815,14 +2774,6 @@ def _prepare_text_encoder_kwargs_for_generation( encoder_kwargs[model_input_name] = inputs_tensor last_hidden_state = encoder(**encoder_kwargs).last_hidden_state - # for classifier free guidance we need to add a 'null' input to our encoder hidden states - if guidance_scale is not None and guidance_scale > 1: - last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0) - if "attention_mask" in model_kwargs: - model_kwargs["attention_mask"] = torch.concatenate( - [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0 - ) - # we optionnally project last_hidden_state to avoid recomputing every time encoder_hidden_states = last_hidden_state if ( @@ -3100,7 +3051,6 @@ def generate( # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache - model_kwargs["guidance_scale"] = generation_config.guidance_scale requires_attention_mask = "encoder_outputs" not in model_kwargs @@ -3194,12 +3144,7 @@ def generate( and generation_config.do_sample is True ) - # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) - if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: - logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) - generation_config.guidance_scale = None - - # 9. prepare distribution pre_processing samplers + # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, @@ -3209,7 +3154,7 @@ def generate( device=input_ids.device, ) - # 10. prepare stopping criteria + # 9. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) @@ -3221,7 +3166,7 @@ def generate( f"but is {generation_config.num_return_sequences}." ) - # 11. 
run greedy search + # 10. run greedy search outputs = self._sample( input_ids, logits_processor=logits_processor, @@ -3233,7 +3178,7 @@ def generate( ) elif is_sample_gen_mode: - # 11. prepare logits warper + # 10. prepare logits warper logits_warper = self._get_logits_warper(generation_config, device=input_ids.device) # expand input_ids with `num_return_sequences` additional sequences per batch @@ -3244,7 +3189,7 @@ def generate( **model_kwargs, ) - # 12. run sample + # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, From ccae5a9f5fc113e5eaf30ec7da9ae26a4f690397 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 13:59:25 +0200 Subject: [PATCH 32/62] imports and constants Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 3e6cd0d..821300a 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -26,9 +26,20 @@ from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding from transformers.activations import ACT2FN from transformers.generation.configuration_utils import GenerationConfig -from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, AttentionMaskConverter +from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList +from transformers.cache_utils import ( + Cache, + DynamicCache, + EncoderDecoderCache, + HQQQuantizedCache, + HybridCache, + QuantizedCacheConfig, + QuantoQuantizedCache, + SlidingWindowCache, + StaticCache, +) from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -42,6 +53,8 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + is_hqq_available, + is_quanto_available, ) import torch.nn.functional as F from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available @@ -75,6 +88,9 @@ ] +NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache} +QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} + def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): """Apply a delay pattern mask to the decoder input ids, only preserving predictions where the mask is set to -1, and otherwise setting to the value detailed in the mask.""" From a9f75d54aaf1a8515f3bc08888ffbbad4b3d69b3 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:03:11 +0200 Subject: [PATCH 33/62] attention modifications to handle static cach Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 245 ++++++++++++++---------------- 1 file changed, 116 insertions(+), 129 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 821300a..43ec0c7 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -336,6 +336,7 @@ 
def __init__( bias: bool = True, is_causal: bool = False, rope_embeddings : bool = False, + layer_idx: Optional[int] = None, config: Optional[ParlerTTSDecoderConfig] = None, ): super().__init__() @@ -356,6 +357,14 @@ def __init__( self.is_decoder = is_decoder self.is_causal = is_causal + if layer_idx is None and is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.layer_idx = layer_idx + self.k_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -373,87 +382,65 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, cos: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len = hidden_states.shape[:2] - + # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - query_states = self._shape_query(query_states, tgt_len, bsz) + query_states = self._shape_query(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz) if self.rope_embeddings: query_states = apply_rotary_pos_emb(query_states, cos, sin) - current_states = key_value_states if key_value_states is not None else hidden_states + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) - + if not is_cross_attention: # cached key states already have rope applied - only apply to new state key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states - if past_key_value is not None: - key_states = 
torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - # repeat k/v heads if n_kv_heads < n_heads + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -463,35 +450,25 @@ def forward( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, 
tgt_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) @@ -527,62 +504,70 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, cos: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # ParlerTTSFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("ParlerTTSFlashAttention2 attention does not support output_attentions") - + if isinstance(past_key_value, StaticCache): + raise ValueError( + "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. " + "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers" + ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, q_len = hidden_states.shape[:2] + bsz, tgt_len = hidden_states.shape[:2] # get query proj - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim) if self.rope_embeddings: query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) - current_states = key_value_states if key_value_states is not None else hidden_states + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1,2) - value_states = past_key_value[1].transpose(1,2) + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: - 
key_states = self.k_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) - value_states = self.v_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) - + key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) + if not is_cross_attention: # cached key states already have rope applied - only apply to new state - key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states - - if past_key_value is not None: - key_states = torch.cat([past_key_value[0].transpose(1,2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1,2), value_states], dim=1) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1,2), value_states.transpose(1,2)) + key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + # # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] + # # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. 
Hence, we need @@ -610,10 +595,10 @@ def forward( value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + query_states, key_states, value_states, causal_mask, tgt_len, dropout=self.dropout ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1) attn_output = self.out_proj(attn_output) if not output_attentions: @@ -726,12 +711,13 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, cos: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: @@ -747,6 +733,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position ) # if key_value_states are provided this layer is used as a cross-attention layer @@ -762,39 +749,39 @@ def forward( if self.rope_embeddings: query_states = apply_rotary_pos_emb(query_states, cos, sin) - current_states = key_value_states if key_value_states is not None else hidden_states + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) - + if not is_cross_attention: # cached key states already have rope applied - only apply to new state key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states - if past_key_value is not None: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) - if self.is_decoder: - # if 
cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -803,7 +790,7 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, @@ -812,7 +799,7 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
is_causal=is_causal, From 89e50d521aa6798800977d02c76408fd35c2966e Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:03:49 +0200 Subject: [PATCH 34/62] decoder layer modification to handle static cache Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 43ec0c7..2f697df 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -830,7 +830,7 @@ def forward( } class ParlerTTSDecoderLayer(nn.Module): - def __init__(self, config: ParlerTTSDecoderConfig): + def __init__(self, config: ParlerTTSDecoderConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.hidden_size @@ -843,6 +843,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): is_causal=True, bias=False, rope_embeddings=config.rope_embeddings, + layer_idx=layer_idx, config=config, ) self.dropout = config.dropout @@ -863,6 +864,7 @@ def __init__(self, config: ParlerTTSDecoderConfig): is_decoder=True, bias=False, rope_embeddings=config.rope_embeddings, + layer_idx=layer_idx, config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -880,9 +882,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[EncoderDecoderCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.LongTensor] = None, ) -> torch.Tensor: """ Args: @@ -909,32 +912,24 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - - hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, cos=cos, sin=sin, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -942,14 +937,14 @@ def forward( cos=cos, sin=sin, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = 
present_key_value + cross_attn_present_key_value + # add cross-attn to positions 1 of present_key_value tuple + present_key_value = (present_key_value, cross_attn_present_key_value) # Fully Connected residual = hidden_states From fb750fe107b0565a93f85e1c08343fab2d8390ed Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:05:08 +0200 Subject: [PATCH 35/62] ParlerTTSPreTrainedModel modifs to handle static cache Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 2f697df..5bbf977 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -979,6 +979,8 @@ class ParlerTTSPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"] + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -1076,14 +1078,18 @@ def _init_weights(self, module): `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. TODO: it's passed through enc_to_dec_proj and optionnally we concat the prompt hidden states in certain cases. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are + four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and + in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or + when `config.use_cache=True` + + Two formats are allowed: + - An [`~cache_utils.EncoderDecoderCache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. @@ -1129,6 +1135,9 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache + in the correct position and to infer the complete sequence length. 
""" MUSICGEN_DECODER_INPUTS_DOCSTRING = r""" From 41fa4fd7b4018a62c2a2410b3a895449c174491d Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:05:42 +0200 Subject: [PATCH 36/62] ParlerTTSDecoder modifs to handle static cache Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 173 +++++++++++++++++++++++------- 1 file changed, 137 insertions(+), 36 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 5bbf977..cdb085c 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -1260,8 +1260,9 @@ def __init__(self, config: ParlerTTSDecoderConfig): max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, ) - - self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ParlerTTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.layer_norm = nn.LayerNorm(config.hidden_size) self.attn_implementation = config._attn_implementation encoder_attn_implementation = config._attn_implementation @@ -1296,6 +1297,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position=None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1318,15 +1320,47 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) # if prompt_hidden_states, fuse to inputs_embeds and update input shape if prompt_hidden_states is not None: inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) + # this step is done at first generation step + # cache_position is initialized in _get_initial_cache_position (GenerationMixin method), + # yet this method is called before this one. + # It is thus necessary to hardcode cache_position to the correct value after prepending prompt_hidden_states + cur_len = inputs_embeds.shape[1] + cache_position = torch.arange(0, cur_len, device=input_ids.device) + + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) # NOTE: 1. As it is, the masked ids from the prompt will still count in the positions embeddings # NOTE: 2. we want to concatenate the prompt attention mask and the decoder attention mask @@ -1335,9 +1369,6 @@ def forward( if prompt_attention_mask is not None and attention_mask is not None: attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1) elif prompt_attention_mask is not None: - logger.warning_once( - "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." - ) if past_key_values is None: attention_mask = torch.cat( [ @@ -1395,24 +1426,14 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) - if self.attn_implementation == "flash_attention_2": - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: if self.encoder_attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None @@ -1437,12 +1458,10 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
) use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1460,13 +1479,11 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.forward, hidden_states, - attention_mask, + causal_mask, cos, sin, encoder_hidden_states, @@ -1476,11 +1493,12 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, cos=cos, sin=sin, encoder_hidden_states=encoder_hidden_states, @@ -1489,15 +1507,13 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values if use_cache else None, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1510,7 +1526,11 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = past_key_values if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() if not return_dict: return tuple( v @@ -1524,6 +1544,87 @@ def forward( attentions=all_self_attns, cross_attentions=all_cross_attentions, ) + +# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
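With a `Cache` object threaded through the stack, each decoder layer writes its new key/value states into the cache in place, indexed by the `layer_idx` it received in `__init__`, which is why the per-layer `next_decoder_cache` accumulation above disappears. A sketch of the underlying call on a plain `DynamicCache`:

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    # inside layer 0: key/value states for the new tokens, shaped (batch, num_heads, seq_len, head_dim)
    key_states = torch.randn(1, 2, 3, 8)
    value_states = torch.randn(1, 2, 3, 8)
    key_all, value_all = cache.update(key_states, value_states, layer_idx=0)

    print(cache.get_seq_length())   # 3; the caller no longer reassembles a tuple of tuples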
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
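A self-contained walk-through of the mask construction in this hunk, with tiny illustrative sizes: two new query tokens decoded against a cache of four positions, six key/value slots in total, and the first slot padded:

    import torch

    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min
    sequence_length, target_length = 2, 6
    cache_position = torch.tensor([4, 5])          # absolute positions of the new tokens

    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :]    # (bsz, 1, q_len, kv_len)

    # fold the 2D padding mask (1 = keep, 0 = pad) into the 4D mask, as above
    attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1]])
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask == 0, min_dtype
    )
    # query 0 (position 4) can now attend to slots 1-4, query 1 (position 5) to slots 1-5;
    # slot 0 stays masked because it is padding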
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask @add_start_docstrings( From 5a484f8c13e107d3d70683a3e2c61d6b33c5f314 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:07:23 +0200 Subject: [PATCH 37/62] ParlerTTSModel + ParlerTTSForCausalLM modfis Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index cdb085c..36f25d8 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -1660,12 +1660,13 @@ def forward( prompt_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1691,6 +1692,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1761,6 +1763,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): @@ -1788,6 +1791,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1848,6 +1852,8 @@ def prepare_inputs_for_generation( past_key_values=None, use_cache=True, delay_pattern_mask=None, + cache_position=None, + inputs_embeds=None, **kwargs, ): if delay_pattern_mask is None: @@ -1876,7 +1882,7 @@ def prepare_inputs_for_generation( prompt_hidden_states = None return { - "input_ids": input_ids, + "input_ids": input_ids.contiguous(), "attention_mask": attention_mask, "position_ids": position_ids, "encoder_hidden_states": encoder_hidden_states, @@ -1887,6 +1893,8 @@ def prepare_inputs_for_generation( "cross_attn_head_mask": cross_attn_head_mask, "past_key_values": past_key_values, "use_cache": use_cache, + "cache_position": cache_position, + "inputs_embeds": inputs_embeds } # Ignore copy From c5da07e2af674163c45dc5cda118a7b4af2c579c Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:08:26 +0200 Subject: [PATCH 38/62] ParlerTTSForConditionalGeneration modifs Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 36f25d8..6aa0ff0 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -2203,6 
+2203,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + def __init__( self, @@ -2580,7 +2583,7 @@ def forward( decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, - past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, prompt_input_ids: Optional[torch.FloatTensor] = None, @@ -2592,6 +2595,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" @@ -2731,6 +2735,7 @@ def forward( past_key_values=past_key_values, return_dict=return_dict, labels=labels, + cache_position=cache_position, **kwargs_decoder, ) From 8c780efe78d09da2f1455becf1f8460c19262d6d Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:09:39 +0200 Subject: [PATCH 39/62] decoder_attention_mask for static cache Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 51 +++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 6aa0ff0..e405b61 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -2768,6 +2768,8 @@ def prepare_inputs_for_generation( use_cache=None, encoder_outputs=None, decoder_delay_pattern_mask=None, + cache_position=None, + inputs_embeds=None, **kwargs, ): if decoder_delay_pattern_mask is None: @@ -2781,9 +2783,18 @@ def prepare_inputs_for_generation( # apply the delay pattern mask decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) + past_length = 0 if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + if past_key_values.get_seq_length() > 0: + # we only want to use prompt signal in the 1st generation step + prompt_hidden_states = None + else: + past_length = past_key_values[0][0].shape[2] + # we only want to use prompt signal in the 1st generation step + prompt_hidden_states = None + # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: remove_prefix_length = past_length @@ -2792,16 +2803,42 @@ def prepare_inputs_for_generation( remove_prefix_length = decoder_input_ids.shape[1] - 1 decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + if cache_position is None: + cache_position = torch.arange( + past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device + ) + elif use_cache: + cache_position = cache_position[-decoder_input_ids.shape[1] :] + + if decoder_attention_mask is None and prompt_attention_mask is not None: + input = decoder_input_ids.reshape(-1, self.decoder.num_codebooks, decoder_input_ids.shape[-1]) + bsz, _, seq_len = input.shape + input_shape = (bsz, seq_len) - # if prompt_cross_attention, - # we only want to use 
prompt signal in the 1st generation step - prompt_hidden_states = None + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + logger.warning_once( + "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." + ) + if past_key_values is None or (isinstance(past_key_values, EncoderDecoderCache) and past_key_values.get_seq_length() == 0): + decoder_attention_mask = torch.ones(input_shape, device=self.device, dtype=decoder_input_ids.dtype) + elif prompt_attention_mask is not None: + # In the generation case of `prompt_cross_attention=True`, we need to recreate an attention mask from scratch + # to be able to prepend the prompt attention mask. + # Since we generate token per token, we can recompute the generated length from the information we have. + generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1 + decoder_attention_mask = torch.ones((input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype) return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, + "decoder_input_ids": decoder_input_ids.contiguous(), "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, @@ -2810,6 +2847,8 @@ def prepare_inputs_for_generation( "prompt_hidden_states": prompt_hidden_states, "prompt_attention_mask": prompt_attention_mask, "use_cache": use_cache, + "cache_position": cache_position, + "inputs_embeds": inputs_embeds } def _prepare_decoder_input_ids_for_generation( From afa18a3e44b8e915e332a345b7e1478c6d8e91cd Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:11:30 +0200 Subject: [PATCH 40/62] create inputs_embeds early to have a good cache initialization Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index e405b61..a4eddbe 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -2896,6 +2896,16 @@ def _prepare_decoder_input_ids_for_generation( ) model_kwargs["decoder_attention_mask"] = decoder_attention_mask + if not self.prompt_cross_attention: + prompt_hidden_states = model_kwargs["prompt_hidden_states"] + num_codebooks = self.decoder.num_codebooks + input = decoder_input_ids.reshape(-1, num_codebooks, decoder_input_ids.shape[-1]) + inputs_embeds = sum( + [self.decoder.model.decoder.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)] + ) + inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) + model_kwargs["inputs_embeds"] = inputs_embeds + return decoder_input_ids, model_kwargs def _prepare_text_encoder_kwargs_for_generation( From 45d0fbb90daed02217ddcb5f1e87badcbe7ee6c3 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:12:06 +0200 Subject: [PATCH 41/62] _get_cache method Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 122 ++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index a4eddbe..d6ecc98 100644 --- 
a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -3094,6 +3094,128 @@ def _get_decoder_start_token_id( raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) + + def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache: + """ + Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a + new `generate` call requires a larger cache. + + Returns the resulting cache object. + """ + cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + if hasattr(self, "_cache"): + cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache + + if cache_implementation == "sliding_window": + max_cache_len = min(self.config.sliding_window, max_cache_len) + + need_new_cache = ( + not hasattr(self, "_cache") + or (not isinstance(cache_to_check, cache_cls)) + or cache_to_check.max_batch_size != max_batch_size + or cache_to_check.max_cache_len < max_cache_len + ) + + if requires_cross_attention_cache and hasattr(self, "_cache"): + need_new_cache = ( + need_new_cache + or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1] + ) + + if need_new_cache: + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + cache_kwargs = { + "config": self.config.decoder, + "max_batch_size": max_batch_size, + "max_cache_len": max_cache_len, + "device": self.device, + "dtype": cache_dtype, + } + self._cache = cache_cls(**cache_kwargs) + if requires_cross_attention_cache: + encoder_kwargs = cache_kwargs.copy() + encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] + config_cross_attention_cache = copy.deepcopy(self.config.decoder) + config_cross_attention_cache.update( + {"num_key_value_heads": self.config.decoder.num_cross_attention_key_value_heads} + ) + encoder_kwargs["config"] = config_cross_attention_cache + self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs)) + else: + self._cache.reset() + return self._cache + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if ( + self.config.is_encoder_decoder + and hasattr(self, "encoder") + and self.encoder.main_input_name != self.main_input_name + ): + input_name = self.encoder.main_input_name + else: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " + f"Make sure to either pass {inputs} or {input_name}=..." 
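The interesting part of `_get_cache` above is that the cross-attention cache gets its own, much shorter `max_cache_len` (the encoder sequence length taken from `encoder_outputs`) and is paired with the self-attention cache inside an `EncoderDecoderCache`. A heavily simplified sketch; the config below is a stand-in carrying only the fields `StaticCache` reads, and the exact constructor signature is the one from the transformers commit pinned later in this series:

    import torch
    from transformers import PretrainedConfig
    from transformers.cache_utils import EncoderDecoderCache, StaticCache

    # stand-in for self.config.decoder; the real code also builds a copy with
    # num_key_value_heads overridden by num_cross_attention_key_value_heads for the cross cache
    cfg = PretrainedConfig(
        num_attention_heads=2, num_key_value_heads=2, hidden_size=16, num_hidden_layers=2
    )

    self_attn_cache = StaticCache(
        config=cfg, max_batch_size=1, max_cache_len=2580, device="cpu", dtype=torch.float32
    )
    cross_attn_cache = StaticCache(
        config=cfg, max_batch_size=1, max_cache_len=60,   # encoder_outputs[0].shape[1]
        device="cpu", dtype=torch.float32
    )
    past_key_values = EncoderDecoderCache(self_attn_cache, cross_attn_cache)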
+ ) + elif inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. In the presence of `inputs_embeds` for text models: + # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model + # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with + # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) + # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and + # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states. + if input_name == "input_ids" and "inputs_embeds" in model_kwargs: + if not self.config.is_encoder_decoder: + has_inputs_embeds_forwarding = "inputs_embeds" in set( + inspect.signature(self.prepare_inputs_for_generation).parameters.keys() + ) + if not has_inputs_embeds_forwarding: + raise ValueError( + f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " + "doesn't have its forwarding implemented. See the GPT2 implementation for an example " + "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" + ) + # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of + # the attention mask) can rely on the actual model input. + model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs=model_kwargs + ) + else: + if inputs is not None: + raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + + # 4. if `inputs` is still None, try to create `input_ids` from BOS token + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) + + return inputs, input_name, model_kwargs def freeze_encoders(self, freeze_text_encoder=True): if freeze_text_encoder: From 054b7519acb1739f91de99f0a8b073737f7e7f8f Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:12:26 +0200 Subject: [PATCH 42/62] init the cache Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 75 +++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index d6ecc98..d193d69 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -3409,6 +3409,81 @@ def generate( f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" " increasing `max_new_tokens`." ) + + use_dynamic_cache_by_default = False + if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: + raise ValueError( + "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. 
Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + if not self.prompt_cross_attention: + # when we prepend prompt_hidden_state to inputs_embeds, max_cache_len needs to be actualised + # generation_config.max_length has already been increased by input_ids_seq_length which is + # already counted in input_embeds_seq_length so we remove it + input_embeds_seq_length = model_kwargs["inputs_embeds"].shape[1] + max_cache_len = generation_config.max_length + input_embeds_seq_length - input_ids_seq_length + else: + max_cache_len = self.generation_config.max_length + model_kwargs["past_key_values"] = self._get_cache( + generation_config.cache_implementation, + getattr(generation_config, "num_beams", 1) * batch_size, + max_cache_len, + model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs["past_key_values"] = cache_class(cache_config) + # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): + past = model_kwargs.get("past_key_values", None) + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + if past is None: + model_kwargs["past_key_values"] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) + use_dynamic_cache_by_default = True + elif isinstance(past, tuple): + model_kwargs["past_key_values"] = ( + DynamicCache.from_legacy_cache(past) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(past) + ) + use_dynamic_cache_by_default = True # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS) input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( From 11b693fb38b1872d47860a32acf8248197723528 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 14:13:13 +0200 Subject: [PATCH 43/62] ensure good device Co-Authored-By: sang-nguyen-ts --- parler_tts/modeling_parler_tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index d193d69..547cccf 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -252,7 +252,7 @@ def get_embedding(num_embeddings: int, embedding_dim: int): def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): bsz, seq_len, _ = input_ids.size() # Create the position ids from the input token ids. - position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device) + position_ids = (torch.arange(seq_len, device=input_ids.device) + past_key_values_length).to(input_ids.device) # expand embeddings if needed if seq_len > self.weights.size(0): self.make_weights(seq_len + self.offset, self.embedding_dim) From 6af19d646747b3731198aa4f3d94fe6ac7a3775e Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 22 Jul 2024 18:06:51 +0200 Subject: [PATCH 44/62] pin tfrms version Co-Authored-By: sang-nguyen-ts --- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index cea2faf..813a663 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ _deps = [ - "transformers>=4.39.0", + "transformers @ git+https://github.com/huggingface/transformers@72fb02c47dbbe1999ae105319f24631cad6e2e00", "torch", "sentencepiece", "descript-audio-codec", @@ -60,7 +60,7 @@ packages=setuptools.find_packages(), install_requires=_deps, extras_require={ - "dev": [_extras_dev_deps], - "train": [_extras_training_deps], + "dev": _extras_dev_deps, + "train": _extras_training_deps, }, -) +) \ No newline at end of file From 682ca70bb6ee401e629c7a0ab6309e72fe23b36b Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 24 Jul 2024 19:40:13 +0200 Subject: [PATCH 45/62] fix attention_mask FA2 --- parler_tts/modeling_parler_tts.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 547cccf..56a6dff 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -565,10 +565,6 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - causal_mask = 
attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. @@ -595,7 +591,7 @@ def forward( value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( - query_states, key_states, value_states, causal_mask, tgt_len, dropout=self.dropout + query_states, key_states, value_states, attention_mask, tgt_len, dropout=self.dropout ) attn_output = attn_output.reshape(bsz, tgt_len, -1) From 024a354ac9fb8d74e4334bf01601f68d5fbb0f3c Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 24 Jul 2024 21:09:44 +0200 Subject: [PATCH 46/62] remove unnecessary method --- parler_tts/modeling_parler_tts.py | 65 ------------------------------- 1 file changed, 65 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 56a6dff..84f2327 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -3148,71 +3148,6 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l self._cache.reset() return self._cache - def _prepare_model_inputs( - self, - inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[torch.Tensor] = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: - """ - This function extracts the model-specific `inputs` for generation. - """ - # 1. retrieve all kwargs that are non-None or non-model input related. - # some encoder-decoder models have different names for model and encoder - if ( - self.config.is_encoder_decoder - and hasattr(self, "encoder") - and self.encoder.main_input_name != self.main_input_name - ): - input_name = self.encoder.main_input_name - else: - input_name = self.main_input_name - - model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} - - # 2. check whether model_input_name is passed as kwarg - # if yes and `inputs` is None use kwarg inputs - inputs_kwarg = model_kwargs.pop(input_name, None) - if inputs_kwarg is not None and inputs is not None: - raise ValueError( - f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " - f"Make sure to either pass {inputs} or {input_name}=..." - ) - elif inputs_kwarg is not None: - inputs = inputs_kwarg - - # 3. In the presence of `inputs_embeds` for text models: - # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model - # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with - # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) - # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and - # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states. 
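For reference, this is what the cache work in this series enables at call time once the pieces above are in place: opting into the static KV cache goes through the generation config, and the rest is the usual Parler-TTS call. Checkpoint name and texts below are illustrative only:

    import torch
    from transformers import AutoTokenizer
    from parler_tts import ParlerTTSForConditionalGeneration

    model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1")
    tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

    description = "A female speaker with a calm voice and very little background noise."
    prompt = "Hey, how are you doing today?"
    input_ids = tokenizer(description, return_tensors="pt").input_ids
    prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # request the static cache set up by `_get_cache`; a dynamic EncoderDecoderCache is used otherwise
    model.generation_config.cache_implementation = "static"

    with torch.no_grad():
        audio = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)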
- if input_name == "input_ids" and "inputs_embeds" in model_kwargs: - if not self.config.is_encoder_decoder: - has_inputs_embeds_forwarding = "inputs_embeds" in set( - inspect.signature(self.prepare_inputs_for_generation).parameters.keys() - ) - if not has_inputs_embeds_forwarding: - raise ValueError( - f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " - "doesn't have its forwarding implemented. See the GPT2 implementation for an example " - "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" - ) - # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of - # the attention mask) can rely on the actual model input. - model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, model_kwargs=model_kwargs - ) - else: - if inputs is not None: - raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") - inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" - - # 4. if `inputs` is still None, try to create `input_ids` from BOS token - inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) - - return inputs, input_name, model_kwargs - def freeze_encoders(self, freeze_text_encoder=True): if freeze_text_encoder: for param in self.text_encoder.parameters(): From a097aa4c1abccdf2d47849db4833a8b7c0bc4988 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 24 Jul 2024 21:11:45 +0200 Subject: [PATCH 47/62] Update parler_tts/modeling_parler_tts.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- parler_tts/modeling_parler_tts.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 84f2327..f645997 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -3369,31 +3369,10 @@ def generate( model_kwargs, ) elif generation_config.cache_implementation == "quantized": - if not self._supports_quantized_cache: - raise ValueError( + raise ValueError( "This model does not support the quantized cache. If you want your model to support quantized " - "cache, please open an issue." - ) - - cache_config = ( - generation_config.cache_config - if generation_config.cache_config is not None - else QuantizedCacheConfig() - ) - cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] - - if cache_config.backend == "quanto" and not is_quanto_available(): - raise ImportError( - "You need to install `quanto` in order to use KV cache quantization with quanto backend. " - "Please install it via with `pip install quanto`" + "cache, please open an issue on the Parler-TTS repository https://github.com/huggingface/parler-tts" ) - elif cache_config.backend == "HQQ" and not is_hqq_available(): - raise ImportError( - "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " - "Please install it via with `pip install hqq`" - ) - - model_kwargs["past_key_values"] = cache_class(cache_config) # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): From ecd06c14678be35fa642fd0eaf556a735d19c8e1 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 24 Jul 2024 21:12:28 +0200 Subject: [PATCH 48/62] Update parler_tts/modeling_parler_tts.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- parler_tts/modeling_parler_tts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index f645997..e8dfcdd 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -88,8 +88,7 @@ ] -NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache} -QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} +NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): """Apply a delay pattern mask to the decoder input ids, only preserving predictions where From ad25e2ba355ea47a42e85feff77e8dab8039ae1a Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 24 Jul 2024 21:15:16 +0200 Subject: [PATCH 49/62] remove unnecessary imports --- parler_tts/modeling_parler_tts.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index e8dfcdd..dbb0179 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -28,15 +28,11 @@ from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, AttentionMaskConverter -from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList +from transformers.generation.logits_process import LogitsProcessorList from transformers.cache_utils import ( Cache, DynamicCache, EncoderDecoderCache, - HQQQuantizedCache, - HybridCache, - QuantizedCacheConfig, - QuantoQuantizedCache, SlidingWindowCache, StaticCache, ) @@ -53,13 +49,10 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, - is_hqq_available, - is_quanto_available, ) import torch.nn.functional as F from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available - from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig from .dac_wrapper import DACConfig, DACModel from transformers import AutoConfig, AutoModel From f29392bccafece9630eec14c2e0b6da02a990fe5 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 25 Jul 2024 17:27:02 +0200 Subject: [PATCH 50/62] replace the hardcoded cache_position with a more elegant approach --- parler_tts/modeling_parler_tts.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index dbb0179..36f0c92 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -1314,12 +1314,6 @@ def forward( # if prompt_hidden_states, fuse to inputs_embeds and update input shape if prompt_hidden_states 
is not None: inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) - # this step is done at first generation step - # cache_position is initialized in _get_initial_cache_position (GenerationMixin method), - # yet this method is called before this one. - # It is thus necessary to hardcode cache_position to the correct value after prepending prompt_hidden_states - cur_len = inputs_embeds.shape[1] - cache_position = torch.arange(0, cur_len, device=input_ids.device) return_legacy_cache = False return_self_attention_cache = False @@ -2797,7 +2791,12 @@ def prepare_inputs_for_generation( past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device ) elif use_cache: - cache_position = cache_position[-decoder_input_ids.shape[1] :] + cur_len = decoder_input_ids.shape[1] + if prompt_hidden_states is not None and not self.prompt_cross_attention: + # meaning we are in 1st generation step and prompt_hidden_state will be prepended + cur_len += prompt_hidden_states.shape[1] + + cache_position = cache_position[-cur_len :] if decoder_attention_mask is None and prompt_attention_mask is not None: input = decoder_input_ids.reshape(-1, self.decoder.num_codebooks, decoder_input_ids.shape[-1]) From d08e4eb691a87e050a82efb54805e462accf3325 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 31 Jul 2024 16:05:06 +0200 Subject: [PATCH 51/62] make style --- helpers/gradio_demo/app.py | 5 +- .../model_init_scripts/init_dummy_model.py | 12 +- .../init_dummy_model_with_encodec.py | 9 +- helpers/model_init_scripts/init_model_600M.py | 12 +- .../push_to_hub_scripts/push_dac_to_hub.py | 6 +- .../push_trained_parler_tts_to_hub.py | 4 +- parler_tts/__init__.py | 5 +- parler_tts/configuration_parler_tts.py | 2 +- parler_tts/dac_wrapper/configuration_dac.py | 2 +- parler_tts/dac_wrapper/modeling_dac.py | 11 +- parler_tts/modeling_parler_tts.py | 150 ++++++++++-------- setup.py | 3 +- training/arguments.py | 12 +- training/data.py | 24 +-- training/eval.py | 19 ++- training/run_parler_tts_training.py | 114 +++++++------ training/utils.py | 14 +- 17 files changed, 241 insertions(+), 163 deletions(-) diff --git a/helpers/gradio_demo/app.py b/helpers/gradio_demo/app.py index 2aaa31c..42b9c85 100644 --- a/helpers/gradio_demo/app.py +++ b/helpers/gradio_demo/app.py @@ -1,8 +1,9 @@ import gradio as gr import torch +from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed from parler_tts import ParlerTTSForConditionalGeneration -from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed + device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -57,7 +58,7 @@ def gen_tts(text, description): background-color: #000000; justify-content: center; align-items: center; - border-radius: 9999px !important; + border-radius: 9999px !important; width: 13rem; margin-top: 10px; margin-left: auto; diff --git a/helpers/model_init_scripts/init_dummy_model.py b/helpers/model_init_scripts/init_dummy_model.py index 25f18e8..0ce5285 100644 --- a/helpers/model_init_scripts/init_dummy_model.py +++ b/helpers/model_init_scripts/init_dummy_model.py @@ -1,7 +1,9 @@ -from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig -from transformers import AutoConfig -import os import argparse +import os + +from transformers import AutoConfig + +from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration if __name__ == "__main__": @@ -60,8 +62,8 @@ model.generation_config.max_length = int(30 * 
model.audio_encoder.config.frame_rate) model.generation_config.do_sample = True # True model.generation_config.guidance_scale = 1 # 3.0 - + model.config.pad_token_id = encodec_vocab_size - model.config.decoder_start_token_id = encodec_vocab_size+1 + model.config.decoder_start_token_id = encodec_vocab_size + 1 model.save_pretrained(os.path.join(args.save_directory, "tiny-model")) diff --git a/helpers/model_init_scripts/init_dummy_model_with_encodec.py b/helpers/model_init_scripts/init_dummy_model_with_encodec.py index 4e26089..ae7555b 100644 --- a/helpers/model_init_scripts/init_dummy_model_with_encodec.py +++ b/helpers/model_init_scripts/init_dummy_model_with_encodec.py @@ -1,7 +1,10 @@ -from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig -from transformers import AutoConfig -import os import argparse +import os + +from transformers import AutoConfig + +from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/helpers/model_init_scripts/init_model_600M.py b/helpers/model_init_scripts/init_model_600M.py index eae5abc..6c3f122 100644 --- a/helpers/model_init_scripts/init_model_600M.py +++ b/helpers/model_init_scripts/init_model_600M.py @@ -1,7 +1,9 @@ -from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig -from transformers import AutoConfig -import os import argparse +import os + +from transformers import AutoConfig + +from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration if __name__ == "__main__": @@ -60,8 +62,8 @@ model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate) model.generation_config.do_sample = True # True model.generation_config.guidance_scale = 1 # 3.0 - + model.config.pad_token_id = encodec_vocab_size - model.config.decoder_start_token_id = encodec_vocab_size+1 + model.config.decoder_start_token_id = encodec_vocab_size + 1 model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-600M/")) diff --git a/helpers/push_to_hub_scripts/push_dac_to_hub.py b/helpers/push_to_hub_scripts/push_dac_to_hub.py index 0c4000a..961f947 100644 --- a/helpers/push_to_hub_scripts/push_dac_to_hub.py +++ b/helpers/push_to_hub_scripts/push_dac_to_hub.py @@ -1,7 +1,9 @@ import dac +from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor + from parler_tts import DACConfig, DACModel -from transformers import AutoConfig, AutoModel -from transformers import EncodecFeatureExtractor + + AutoConfig.register("dac", DACConfig) AutoModel.register(DACConfig, DACModel) diff --git a/helpers/push_to_hub_scripts/push_trained_parler_tts_to_hub.py b/helpers/push_to_hub_scripts/push_trained_parler_tts_to_hub.py index 5caf54f..42e3953 100644 --- a/helpers/push_to_hub_scripts/push_trained_parler_tts_to_hub.py +++ b/helpers/push_to_hub_scripts/push_trained_parler_tts_to_hub.py @@ -1,5 +1,7 @@ +from transformers import AutoFeatureExtractor, AutoTokenizer + from parler_tts import ParlerTTSForConditionalGeneration -from transformers import AutoTokenizer, AutoFeatureExtractor + path = "TODO" repo_id = "parler_tts_600M" diff --git a/parler_tts/__init__.py b/parler_tts/__init__.py index 655520e..b2d01b8 100644 --- a/parler_tts/__init__.py +++ b/parler_tts/__init__.py @@ -1,7 +1,10 @@ __version__ = "0.1" +from transformers import AutoConfig, AutoModel + from .configuration_parler_tts import 
ParlerTTSConfig, ParlerTTSDecoderConfig +from .dac_wrapper import DACConfig, DACModel from .modeling_parler_tts import ( ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, @@ -9,8 +12,6 @@ build_delay_pattern_mask, ) -from .dac_wrapper import DACConfig, DACModel -from transformers import AutoConfig, AutoModel AutoConfig.register("dac", DACConfig) AutoModel.register(DACConfig, DACModel) diff --git a/parler_tts/configuration_parler_tts.py b/parler_tts/configuration_parler_tts.py index c390d8a..205aca5 100644 --- a/parler_tts/configuration_parler_tts.py +++ b/parler_tts/configuration_parler_tts.py @@ -270,7 +270,7 @@ def from_sub_models_config( # This is a property because you might want to change the codec model on the fly def sampling_rate(self): return self.audio_encoder.sampling_rate - + # Copy from musicgen @property def _attn_implementation(self): diff --git a/parler_tts/dac_wrapper/configuration_dac.py b/parler_tts/dac_wrapper/configuration_dac.py index 86b0b61..2559fdd 100644 --- a/parler_tts/dac_wrapper/configuration_dac.py +++ b/parler_tts/dac_wrapper/configuration_dac.py @@ -1,5 +1,5 @@ + from transformers import PretrainedConfig -from typing import List class DACConfig(PretrainedConfig): diff --git a/parler_tts/dac_wrapper/modeling_dac.py b/parler_tts/dac_wrapper/modeling_dac.py index e45a4e8..14292e8 100644 --- a/parler_tts/dac_wrapper/modeling_dac.py +++ b/parler_tts/dac_wrapper/modeling_dac.py @@ -1,10 +1,9 @@ import torch - +from dac.model import DAC from transformers import PreTrainedModel -from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput -from .configuration_dac import DACConfig +from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput -from dac.model import DAC +from .configuration_dac import DACConfig # model doesn't support batching yet @@ -79,7 +78,7 @@ def encode( ) for offset in range(0, input_length - step, stride): - mask = padding_mask[..., offset : offset + chunk_length].bool() + padding_mask[..., offset : offset + chunk_length].bool() frame = audio_data[:, :, offset : offset + chunk_length] scale = None @@ -134,4 +133,4 @@ def decode( return EncodecDecoderOutput(audio_values) def forward(self, tensor): - raise ValueError(f"`DACModel.forward` not implemented yet") + raise ValueError("`DACModel.forward` not implemented yet") diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 36f0c92..eada0c6 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -18,17 +18,14 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn +import torch.nn.functional as F from torch.nn import CrossEntropyLoss from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding from transformers.activations import ACT2FN -from transformers.generation.configuration_utils import GenerationConfig -from transformers.generation.stopping_criteria import StoppingCriteriaList -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, AttentionMaskConverter -from transformers.generation.logits_process import LogitsProcessorList from transformers.cache_utils import ( Cache, DynamicCache, @@ -36,6 +33,14 @@ SlidingWindowCache, StaticCache, ) +from 
transformers.generation.configuration_utils import GenerationConfig +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -50,12 +55,11 @@ logging, replace_return_docstrings, ) -import torch.nn.functional as F -from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available +from transformers.utils.import_utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig from .dac_wrapper import DACConfig, DACModel -from transformers import AutoConfig, AutoModel + AutoConfig.register("dac", DACConfig) AutoModel.register(DACConfig, DACModel) @@ -83,6 +87,7 @@ NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} + def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): """Apply a delay pattern mask to the decoder input ids, only preserving predictions where the mask is set to -1, and otherwise setting to the value detailed in the mask.""" @@ -156,6 +161,7 @@ def build_delay_pattern_mask( input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) return input_ids, pattern_mask + # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ @@ -169,8 +175,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - @dataclass class ParlerTTSUnconditionalInput(ModelOutput): """ @@ -250,6 +254,7 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): self.make_weights(seq_len + self.offset, self.embedding_dim) return self.weights.index_select(0, position_ids.view(-1)).detach() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->ParlerTTS class ParlerTTSRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): @@ -286,6 +291,7 @@ def forward(self, device_type, position_ids): sin = emb.sin() return cos, sin + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -315,6 +321,7 @@ def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): x_embed = (x * cos) + (rotate_half(x) * sin) return x_embed + class ParlerTTSAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use GQA and MQA.""" @@ -327,7 +334,7 @@ def __init__( is_decoder: bool = False, bias: bool = True, is_causal: bool = False, - rope_embeddings : bool = False, + rope_embeddings: bool = False, layer_idx: Optional[int] = None, config: Optional[ParlerTTSDecoderConfig] = None, ): @@ -427,7 +434,7 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) if attention_mask is not None: # no matter the length, we just slice it @@ -462,6 +469,7 @@ def forward( return attn_output, attn_weights, past_key_value + def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() @@ -474,7 +482,6 @@ def _get_unpad_data(attention_mask): ) - # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenFlashAttention2 with Musicgen->ParlerTTS class ParlerTTSFlashAttention2(ParlerTTSAttention): """ @@ -518,7 +525,7 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim) - + if self.rope_embeddings: query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) @@ -562,7 +569,7 @@ def forward( # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) - + if query_states.dtype == torch.float32 or value_states.dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() @@ -647,7 +654,6 @@ def _flash_attention_forward( attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) @@ -693,6 +699,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen class ParlerTTSSdpaAttention(ParlerTTSAttention): def forward( @@ -721,7 +728,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, - cache_position=cache_position + cache_position=cache_position, ) # if key_value_states are provided this layer is used as a cross-attention layer @@ -780,7 +787,6 @@ def forward( # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -810,13 +816,13 @@ def forward( return attn_output, None, past_key_value - PARLERTTS_ATTENTION_CLASSES = { "eager": ParlerTTSAttention, "sdpa": ParlerTTSSdpaAttention, "flash_attention_2": ParlerTTSFlashAttention2, } + class ParlerTTSDecoderLayer(nn.Module): def __init__(self, config: ParlerTTSDecoderConfig, layer_idx: int = None): super().__init__() @@ -970,7 +976,6 @@ class ParlerTTSPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True - def _init_weights(self, module): std = self.config.initializer_factor if isinstance(module, (nn.Linear, nn.Conv1d)): @@ -1255,7 +1260,9 @@ def __init__(self, config: ParlerTTSDecoderConfig): self.attn_implementation = config._attn_implementation encoder_attn_implementation = config._attn_implementation if config.cross_attention_implementation_strategy is not None: - encoder_attn_implementation = "sdpa" if config.cross_attention_implementation_strategy == "always_sdpa" else "eager" + encoder_attn_implementation = ( + "sdpa" if config.cross_attention_implementation_strategy == "always_sdpa" else "eager" + ) self.encoder_attn_implementation = encoder_attn_implementation self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -1393,16 +1400,17 @@ def forward( position_ids.masked_fill_(attention_mask == 0, 1) else: position_ids = torch.arange( - past_key_values_length, input_shape[1] + past_key_values_length, + past_key_values_length, + input_shape[1] + past_key_values_length, dtype=torch.long, - device=inputs_embeds.device + device=inputs_embeds.device, ) position_ids = position_ids.unsqueeze(0) # Some generation methods already pass only the last input ID if position_ids.shape[1] > input_shape[1]: - position_ids = position_ids[:, -input_shape[1]:] - + position_ids = position_ids[:, -input_shape[1] :] + cos, sin = self.rotary_emb(hidden_states.device.type, position_ids) cos, sin = cos.to(hidden_states.dtype), sin.to(hidden_states.dtype) @@ -1526,8 +1534,8 @@ def forward( attentions=all_self_attns, cross_attentions=all_cross_attentions, ) - -# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1858,7 +1866,7 @@ def prepare_inputs_for_generation( if past_key_values is not None: input_ids = input_ids[:, -1:] if position_ids is not None: - position_ids = position_ids[:, -input_ids.shape[1]:] + position_ids = position_ids[:, -input_ids.shape[1] :] # we only want to use prompt signal in the 1st generation step but keeping the attention mask prompt_hidden_states = None @@ -1876,7 +1884,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "cache_position": cache_position, - "inputs_embeds": inputs_embeds + "inputs_embeds": inputs_embeds, } # Ignore copy @@ -2187,7 +2195,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True - def __init__( self, @@ -2220,6 +2227,7 @@ def __init__( if text_encoder is None: from transformers.models.auto.modeling_auto import AutoModelForTextEncoding + text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder) if audio_encoder is None: @@ -2655,11 +2663,15 @@ def forward( # add sinusoidal positional 
embedding positions = self.embed_positions(prompt_hidden_states, 0) prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) - + if prompt_attention_mask is not None and attention_mask is None: - attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype) + attention_mask = torch.ones( + encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype + ) elif attention_mask is not None and prompt_attention_mask is None: - prompt_attention_mask = torch.ones(prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype) + prompt_attention_mask = torch.ones( + prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype + ) # concatenate text description states with prompt description states encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) @@ -2668,7 +2680,7 @@ def forward( prompt_hidden_states = None prompt_attention_mask = None - + encoder_outputs["last_hidden_state"] = encoder_hidden_states elif isinstance(encoder_outputs, tuple): @@ -2678,7 +2690,7 @@ def forward( if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( - labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id + labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id ).transpose(1, 2) elif decoder_input_ids is None and decoder_inputs_embeds is None: @@ -2771,12 +2783,12 @@ def prepare_inputs_for_generation( past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() if past_key_values.get_seq_length() > 0: # we only want to use prompt signal in the 1st generation step - prompt_hidden_states = None + prompt_hidden_states = None else: past_length = past_key_values[0][0].shape[2] # we only want to use prompt signal in the 1st generation step prompt_hidden_states = None - + # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: remove_prefix_length = past_length @@ -2785,7 +2797,7 @@ def prepare_inputs_for_generation( remove_prefix_length = decoder_input_ids.shape[1] - 1 decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - + if cache_position is None: cache_position = torch.arange( past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device @@ -2796,8 +2808,8 @@ def prepare_inputs_for_generation( # meaning we are in 1st generation step and prompt_hidden_state will be prepended cur_len += prompt_hidden_states.shape[1] - cache_position = cache_position[-cur_len :] - + cache_position = cache_position[-cur_len:] + if decoder_attention_mask is None and prompt_attention_mask is not None: input = decoder_input_ids.reshape(-1, self.decoder.num_codebooks, decoder_input_ids.shape[-1]) bsz, _, seq_len = input.shape @@ -2812,14 +2824,18 @@ def prepare_inputs_for_generation( logger.warning_once( "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." 
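Editorial note, not part of the patch: the hunk above builds a single cross-attention input out of the text-description encoder output and the embedded prompt. If only one of the two attention masks is given, the other is assumed to be all ones, then states and masks are concatenated along the sequence axis. A minimal sketch with toy shapes (all sizes here are made up for the example):

import torch

bsz, desc_len, prompt_len, hidden = 2, 12, 7, 16
encoder_hidden_states = torch.randn(bsz, desc_len, hidden)   # text-description encoder outputs
prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)  # embedded transcript prompt
attention_mask = None
prompt_attention_mask = torch.ones(bsz, prompt_len, dtype=torch.long)

# a missing mask on either side is treated as "attend to everything"
if prompt_attention_mask is not None and attention_mask is None:
    attention_mask = torch.ones(encoder_hidden_states.shape[:2], dtype=prompt_attention_mask.dtype)
elif attention_mask is not None and prompt_attention_mask is None:
    prompt_attention_mask = torch.ones(prompt_hidden_states.shape[:2], dtype=attention_mask.dtype)

# description and prompt states become one sequence, and the masks follow the same layout
encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)
attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)
assert encoder_hidden_states.shape[1] == attention_mask.shape[1] == desc_len + prompt_len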
) - if past_key_values is None or (isinstance(past_key_values, EncoderDecoderCache) and past_key_values.get_seq_length() == 0): + if past_key_values is None or ( + isinstance(past_key_values, EncoderDecoderCache) and past_key_values.get_seq_length() == 0 + ): decoder_attention_mask = torch.ones(input_shape, device=self.device, dtype=decoder_input_ids.dtype) elif prompt_attention_mask is not None: # In the generation case of `prompt_cross_attention=True`, we need to recreate an attention mask from scratch # to be able to prepend the prompt attention mask. # Since we generate token per token, we can recompute the generated length from the information we have. generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1 - decoder_attention_mask = torch.ones((input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype) + decoder_attention_mask = torch.ones( + (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype + ) return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -2835,7 +2851,7 @@ def prepare_inputs_for_generation( "prompt_attention_mask": prompt_attention_mask, "use_cache": use_cache, "cache_position": cache_position, - "inputs_embeds": inputs_embeds + "inputs_embeds": inputs_embeds, } def _prepare_decoder_input_ids_for_generation( @@ -2888,11 +2904,14 @@ def _prepare_decoder_input_ids_for_generation( num_codebooks = self.decoder.num_codebooks input = decoder_input_ids.reshape(-1, num_codebooks, decoder_input_ids.shape[-1]) inputs_embeds = sum( - [self.decoder.model.decoder.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)] + [ + self.decoder.model.decoder.embed_tokens[codebook](input[:, codebook]) + for codebook in range(num_codebooks) + ] ) inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) model_kwargs["inputs_embeds"] = inputs_embeds - + return decoder_input_ids, model_kwargs def _prepare_text_encoder_kwargs_for_generation( @@ -2953,30 +2972,34 @@ def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs): # add sinusoidal positional embedding positions = self.embed_positions(prompt_hidden_states, 0) prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) - + attention_mask = model_kwargs.get("attention_mask", None) prompt_attention_mask = model_kwargs.get("prompt_attention_mask", None) encoder_hidden_states = model_kwargs["encoder_outputs"].last_hidden_state if prompt_attention_mask is not None and attention_mask is None: - attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype) + attention_mask = torch.ones( + encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype + ) elif attention_mask is not None and prompt_attention_mask is None: - prompt_attention_mask = torch.ones(prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype) + prompt_attention_mask = torch.ones( + prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype + ) # concatenate text description states with prompt description states encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) if prompt_attention_mask is not None: attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) - + model_kwargs["encoder_outputs"].last_hidden_state = encoder_hidden_states model_kwargs["attention_mask"] = attention_mask - - # in this case, 
since we already concatenated the prompt hidden states and attention mask, we don't need them anymore. + + # in this case, since we already concatenated the prompt hidden states and attention mask, we don't need them anymore. model_kwargs["prompt_hidden_states"] = None - model_kwargs["prompt_attention_mask"] = None + model_kwargs["prompt_attention_mask"] = None else: model_kwargs["prompt_hidden_states"] = prompt_hidden_states - # we're keeping the prompt attention mask because it has to be prepended to the decoder attention mask on the fly + # we're keeping the prompt attention mask because it has to be prepended to the decoder attention mask on the fly return model_kwargs def _prepare_audio_encoder_kwargs_for_generation( @@ -3027,7 +3050,9 @@ def _prepare_audio_encoder_kwargs_for_generation( return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id).transpose(1, 2) + return shift_tokens_right( + labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id + ).transpose(1, 2) def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( @@ -3081,7 +3106,7 @@ def _get_decoder_start_token_id( raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) - + def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache: """ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a @@ -3129,7 +3154,7 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l if requires_cross_attention_cache: encoder_kwargs = cache_kwargs.copy() encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] - config_cross_attention_cache = copy.deepcopy(self.config.decoder) + config_cross_attention_cache = copy.deepcopy(self.config.decoder) config_cross_attention_cache.update( {"num_key_value_heads": self.config.decoder.num_cross_attention_key_value_heads} ) @@ -3138,7 +3163,7 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l else: self._cache.reset() return self._cache - + def freeze_encoders(self, freeze_text_encoder=True): if freeze_text_encoder: for param in self.text_encoder.parameters(): @@ -3331,8 +3356,7 @@ def generate( f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" " increasing `max_new_tokens`." ) - - use_dynamic_cache_by_default = False + if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: raise ValueError( "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " @@ -3361,9 +3385,9 @@ def generate( ) elif generation_config.cache_implementation == "quantized": raise ValueError( - "This model does not support the quantized cache. If you want your model to support quantized " - "cache, please open an issue on the Parler-TTS repository https://github.com/huggingface/parler-tts" - ) + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue on the Parler-TTS repository https://github.com/huggingface/parler-tts" + ) # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): @@ -3377,14 +3401,12 @@ def generate( if not requires_cross_attention_cache else EncoderDecoderCache(DynamicCache(), DynamicCache()) ) - use_dynamic_cache_by_default = True elif isinstance(past, tuple): model_kwargs["past_key_values"] = ( DynamicCache.from_legacy_cache(past) if not requires_cross_attention_cache else EncoderDecoderCache.from_legacy_cache(past) ) - use_dynamic_cache_by_default = True # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS) input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( diff --git a/setup.py b/setup.py index 813a663..52d541e 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ # limitations under the License. import os + import setuptools @@ -63,4 +64,4 @@ "dev": _extras_dev_deps, "train": _extras_training_deps, }, -) \ No newline at end of file +) diff --git a/training/arguments.py b/training/arguments.py index 79c81e3..5c916fb 100644 --- a/training/arguments.py +++ b/training/arguments.py @@ -80,9 +80,7 @@ class ModelArguments: ) attn_implementation: str = field( default="eager", - metadata={ - "help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`" - }, + metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"}, ) cross_attention_implementation_strategy: str = field( default=None, @@ -329,5 +327,9 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments): ) eval_dataloader_num_workers: Optional[int] = field( default=0, - metadata={"help": ("Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process.")}, - ) \ No newline at end of file + metadata={ + "help": ( + "Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process." + ) + }, + ) diff --git a/training/data.py b/training/data.py index 2a293e6..a1ece1f 100644 --- a/training/data.py +++ b/training/data.py @@ -1,15 +1,14 @@ import logging from dataclasses import dataclass -from typing import Dict, List, Optional, Union, Set +from typing import Dict, List, Optional, Set, Union -import torch -import numpy as np import datasets -from datasets import load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets -from transformers import AutoFeatureExtractor, AutoTokenizer -from tqdm import tqdm - +import numpy as np +import torch from accelerate import Accelerator +from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset +from tqdm import tqdm +from transformers import AutoFeatureExtractor, AutoTokenizer @dataclass @@ -31,7 +30,7 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> audios = [feature[self.audio_column_name]["array"] for feature in features] len_audio = [len(audio) for audio in audios] if self.max_length is not None: - audios = [audio[:min(l, self.max_length)] for audio, l in zip(audios, len_audio)] + audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)] # since resampling has already been performed in the 'load_multiple_datasets' function, # a fixed sampling_rate(44100hz) is passed to the feature_extractor. 
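Editorial note on the collator hunks in training/data.py around this point (audio truncation above, label padding just below): the toy sketch here shows the two operations in isolation, with lengths and codebook counts invented for the example.

import torch

audio_max_samples = 8          # stands in for `self.max_length` in the collator
audios = [torch.randn(5), torch.randn(12)]
# each waveform is truncated to at most `audio_max_samples` samples before feature extraction
audios = [audio[: min(len(audio), audio_max_samples)] for audio in audios]
assert [len(a) for a in audios] == [5, 8]

# label codebook sequences of different lengths are padded to a common length with -100,
# the index that the cross-entropy loss ignores
labels = [torch.zeros(6, 9, dtype=torch.long), torch.zeros(4, 9, dtype=torch.long)]  # (seq_len, codebooks)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
audio_max_length = 10
labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(audio_max_length - labels.shape[1], 0)), value=-100)
assert labels.shape == (2, 10, 9)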
@@ -83,7 +82,9 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> # (bsz, seq_len, num_codebooks) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) if self.audio_max_length is not None and self.padding == "max_length": - labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100) + labels = torch.nn.functional.pad( + labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100 + ) input_ids = [{"input_ids": feature["input_ids"]} for feature in features] @@ -269,7 +270,10 @@ def load_multiple_datasets( dataset = concatenate_datasets([dataset, metadata_dataset], axis=1) - if id_column_name is not None and dataset_dict["name"] not in {"parler-tts/mls_eng_10k", "parler-tts/mls_eng"}: + if id_column_name is not None and dataset_dict["name"] not in { + "parler-tts/mls_eng_10k", + "parler-tts/mls_eng", + }: if ( len( dataset.filter( diff --git a/training/eval.py b/training/eval.py index 1116dd0..57d8d0e 100644 --- a/training/eval.py +++ b/training/eval.py @@ -1,7 +1,14 @@ -import torch import evaluate -from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast +import torch from accelerate.utils.memory import release_memory +from transformers import ( + AutoModel, + AutoProcessor, + WhisperForConditionalGeneration, + WhisperTokenizer, + WhisperTokenizerFast, + pipeline, +) def clap_similarity(clap_model_name_or_path, texts, audios, device): @@ -16,7 +23,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device): audio_features = clap.get_audio_features(clap_inputs["input_features"]) cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8).mean() - + cosine_sim = cosine_sim.to("cpu") clap.to("cpu") @@ -50,7 +57,11 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s normalized_references = [] for pred, ref in zip(transcriptions, prompts): - normalizer = english_normalizer if isinstance(pred.get("chunks", None), list) and pred["chunks"][0].get("language", None) == "english" else basic_normalizer + normalizer = ( + english_normalizer + if isinstance(pred.get("chunks", None), list) and pred["chunks"][0].get("language", None) == "english" + else basic_normalizer + ) norm_ref = normalizer(ref) if len(norm_ref) > 0: norm_pred = normalizer(pred["text"]) diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index d474112..758b169 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -21,41 +21,42 @@ import re import sys import time -from multiprocess import set_start_method from datetime import timedelta - -from tqdm import tqdm from pathlib import Path -import torch -from torch.utils.data import DataLoader - import datasets -from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets - -from huggingface_hub import HfApi - +import torch import transformers +from accelerate import Accelerator, skip_first_batches +from accelerate.utils import AutocastKwargs, InitProcessGroupKwargs, set_seed +from accelerate.utils.memory import release_memory +from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets +from huggingface_hub import HfApi +from torch.utils.data import DataLoader +from tqdm import tqdm from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser -from 
transformers.trainer_pt_utils import LengthGroupedSampler from transformers.optimization import get_scheduler +from transformers.trainer_pt_utils import LengthGroupedSampler from transformers.utils import send_example_telemetry - -from accelerate import Accelerator, skip_first_batches -from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin -from accelerate.utils.memory import release_memory - from parler_tts import ( ParlerTTSConfig, ParlerTTSForConditionalGeneration, build_delay_pattern_mask, ) - -from training.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric, load_all_codec_checkpoints, save_codec_checkpoint, get_last_codec_checkpoint_step -from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments -from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding +from training.arguments import DataTrainingArguments, ModelArguments, ParlerTTSTrainingArguments +from training.data import DataCollatorEncodecWithPadding, DataCollatorParlerTTSWithPadding, load_multiple_datasets from training.eval import clap_similarity, wer +from training.utils import ( + get_last_checkpoint, + get_last_codec_checkpoint_step, + load_all_codec_checkpoints, + log_metric, + log_pred, + rotate_checkpoints, + save_codec_checkpoint, +) + logger = logging.getLogger(__name__) @@ -79,13 +80,10 @@ def main(): if training_args.dtype == "float16": mixed_precision = "fp16" - torch_dtype = torch.float16 elif training_args.dtype == "bfloat16": mixed_precision = "bf16" - torch_dtype = torch.bfloat16 else: mixed_precision = "no" - torch_dtype = torch.float32 if data_args.pad_to_max_length and ( data_args.max_duration_in_seconds is None @@ -299,7 +297,13 @@ def main(): ) # update pad token id and decoder_start_token_id - config.decoder.update({"cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy if model_args.cross_attention_implementation_strategy is not None else None}) + config.decoder.update( + { + "cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy + if model_args.cross_attention_implementation_strategy is not None + else None + } + ) config.update( { "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id, @@ -336,13 +340,10 @@ def main(): description_column_name = data_args.description_column_name prompt_column_name = data_args.prompt_column_name feature_extractor_input_name = feature_extractor.model_input_names[0] - audio_encoder_pad_token_id = config.decoder.pad_token_id audio_encoder_eos_token_id = config.decoder.eos_token_id audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id - max_length = model.generation_config.max_length num_codebooks = model.decoder.config.num_codebooks bandwidth = model_args.bandwidth - attn_implementation = model_args.attn_implementation # Freeze Encoders model.freeze_encoders(model_args.freeze_text_encoder) @@ -418,7 +419,7 @@ def apply_audio_decoder(batch): output["len_audio"] = len_audio # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks) output["labels"] = labels.squeeze(0).transpose(1, 2) - + # if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate max_length = len_audio.max() if padding != "max_length" else max_target_length output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length @@ -426,7 
+427,7 @@ def apply_audio_decoder(batch): # (1, codebooks, seq_len) where seq_len=1 bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id - + def postprocess_dataset(labels): # (1, codebooks, seq_len) labels = torch.tensor(labels).unsqueeze(0) @@ -454,7 +455,7 @@ def postprocess_dataset(labels): # we also remove the last timestampts (full of PAD) output = {"labels": labels[:, 1:]} return output - + for split in vectorized_datasets: data_loader = DataLoader( raw_datasets[split], @@ -465,7 +466,7 @@ def postprocess_dataset(labels): ) data_loader = accelerator.prepare(data_loader) total_inference_steps = len(data_loader) - + start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split)) accelerator.wait_for_everyone() if start_step > 0: @@ -477,7 +478,7 @@ def postprocess_dataset(labels): all_generated_labels = [] all_lens = [] if start_step < total_inference_steps: - for (i, batch) in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)): + for i, batch in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)): cur_step = start_step + i generate_labels = apply_audio_decoder(batch) generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0) @@ -491,8 +492,10 @@ def postprocess_dataset(labels): all_generated_labels.extend(lab) all_lens.extend(lens) - - if ((cur_step+1) % data_args.save_codec_steps == 0) or (cur_step == total_inference_steps - 1): + + if ((cur_step + 1) % data_args.save_codec_steps == 0) or ( + cur_step == total_inference_steps - 1 + ): tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens}) tmp_labels = tmp_labels.map( postprocess_dataset, @@ -500,13 +503,15 @@ def postprocess_dataset(labels): input_columns=["labels"], desc="Postprocessing labeling", ) - save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step) + save_codec_checkpoint( + os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step + ) all_generated_labels = [] all_lens = [] - + accelerator.wait_for_everyone() - - if accelerator.is_main_process and len(all_generated_labels) > 0: + + if accelerator.is_main_process and len(all_generated_labels) > 0: tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens}) tmp_labels = tmp_labels.map( postprocess_dataset, @@ -523,7 +528,9 @@ def postprocess_dataset(labels): accelerator.wait_for_everyone() with accelerator.local_main_process_first(): - tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(range(len(vectorized_datasets[split]))) + tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select( + range(len(vectorized_datasets[split])) + ) logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}") vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1) @@ -651,7 +658,7 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"): steps_per_epoch = total_train_steps if training_args.eval_steps is None: - logger.info(f"eval_steps is not set, evaluating at the end of each epoch") + logger.info("eval_steps is not set, evaluating at the end of each epoch") eval_steps = steps_per_epoch else: eval_steps = training_args.eval_steps @@ -817,7 +824,11 @@ def train_step( config.text_encoder.hidden_size != config.decoder.hidden_size and 
config.decoder.cross_attention_hidden_size is None ): - encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states) if training_args.parallel_mode.value != "distributed" else model.module.enc_to_dec_proj(encoder_hidden_states) + encoder_hidden_states = ( + model.enc_to_dec_proj(encoder_hidden_states) + if training_args.parallel_mode.value != "distributed" + else model.module.enc_to_dec_proj(encoder_hidden_states) + ) if batch.get("attention_mask", None) is not None: encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None] @@ -857,7 +868,11 @@ def eval_step( config.text_encoder.hidden_size != config.decoder.hidden_size and config.decoder.cross_attention_hidden_size is None ): - encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states) if training_args.parallel_mode.value != "distributed" else model.module.enc_to_dec_proj(encoder_hidden_states) + encoder_hidden_states = ( + model.enc_to_dec_proj(encoder_hidden_states) + if training_args.parallel_mode.value != "distributed" + else model.module.enc_to_dec_proj(encoder_hidden_states) + ) if batch.get("attention_mask", None) is not None: encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None] @@ -865,7 +880,6 @@ def eval_step( encoder_outputs.last_hidden_state = encoder_hidden_states batch["encoder_outputs"] = encoder_outputs - with torch.no_grad(): outputs = eval_model(**batch) # CE (data) loss @@ -875,12 +889,14 @@ def eval_step( def generate_step(batch, accelerator): batch.pop("decoder_attention_mask", None) - eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True) # (attn_implementation!="flash_attention_2")) + eval_model = accelerator.unwrap_model( + model, keep_fp32_wrapper=True + ) # (attn_implementation!="flash_attention_2")) if training_args.torch_compile: # if the model is compiled, we use the original model bc compile is not compatible with .generate eval_model = model._orig_mod - # since we've might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision. + # since we've might have loaded the weights in fp32, we have to autocast to ensure FA2 weights are in half-precision. 
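Editorial note, not part of the patch: the train_step/eval_step hunks above project the text-encoder states into the decoder width with enc_to_dec_proj and zero out padded positions so they cannot leak into cross-attention. A small standalone illustration with assumed dimensions (the nn.Linear below merely stands in for model.enc_to_dec_proj):

import torch
import torch.nn as nn

bsz, seq_len, enc_hidden, dec_hidden = 2, 5, 768, 1024
encoder_hidden_states = torch.randn(bsz, seq_len, enc_hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

enc_to_dec_proj = nn.Linear(enc_hidden, dec_hidden)      # placeholder for model.enc_to_dec_proj
encoder_hidden_states = enc_to_dec_proj(encoder_hidden_states)
# mask out the states of padded description tokens
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
assert encoder_hidden_states.shape == (bsz, seq_len, dec_hidden)
assert torch.all(encoder_hidden_states[0, 3:] == 0)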
# with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))): output_audios = eval_model.generate(**batch, **gen_kwargs) output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0) @@ -995,7 +1011,7 @@ def generate_step(batch, accelerator): for batch in tqdm( validation_dataloader, - desc=f"Evaluating - Inference ...", + desc="Evaluating - Inference ...", position=2, disable=not accelerator.is_local_main_process, ): @@ -1017,7 +1033,7 @@ def generate_step(batch, accelerator): # generation for batch in tqdm( validation_dataloader, - desc=f"Evaluating - Generation ...", + desc="Evaluating - Generation ...", position=2, disable=not accelerator.is_local_main_process, ): @@ -1079,7 +1095,9 @@ def generate_step(batch, accelerator): ) # release eval batch and relax metrics - eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric) + eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory( + eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric + ) if training_args.predict_with_generate: generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts) diff --git a/training/utils.py b/training/utils.py index 8dc2e0a..75d28c1 100644 --- a/training/utils.py +++ b/training/utils.py @@ -1,13 +1,14 @@ import os import re import shutil -from pathlib import Path from dataclasses import field +from pathlib import Path from typing import Dict, List import torch +from datasets import concatenate_datasets, load_from_disk from wandb import Audio -from datasets import load_from_disk, concatenate_datasets + def list_field(default=None, metadata=None): return field(default_factory=lambda: default, metadata=metadata) @@ -17,6 +18,7 @@ def list_field(default=None, metadata=None): CHECKPOINT_CODEC_PREFIX = "checkpoint" _RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$") + def get_last_checkpoint(folder): content = os.listdir(folder) checkpoints = [ @@ -66,10 +68,12 @@ def save_codec_checkpoint(output_dir, dataset, step): output_path = os.path.join(output_dir, checkpoint_path) dataset.save_to_disk(output_path) + def load_codec_checkpoint(checkpoint_path): dataset = load_from_disk(checkpoint_path) return dataset + def sorted_codec_checkpoints(output_dir=None) -> List[str]: """Helper function to sort saved checkpoints from oldest to newest.""" ordering_and_checkpoint_path = [] @@ -85,6 +89,7 @@ def sorted_codec_checkpoints(output_dir=None) -> List[str]: checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] return checkpoints_sorted + def load_all_codec_checkpoints(output_dir=None) -> List[str]: """Helper function to load and concat all checkpoints.""" checkpoints_sorted = sorted_codec_checkpoints(output_dir=output_dir) @@ -101,13 +106,16 @@ def get_last_codec_checkpoint_step(folder) -> int: checkpoints = [path for path in content if _RE_CODEC_CHECKPOINT.search(path) is not None] if len(checkpoints) == 0: return 0 - last_checkpoint = os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CODEC_CHECKPOINT.search(x).groups()[0]))) + last_checkpoint = os.path.join( + folder, max(checkpoints, key=lambda x: int(_RE_CODEC_CHECKPOINT.search(x).groups()[0])) + ) # Find num steps saved state string pattern pattern = r"checkpoint-(\d+)" match = re.search(pattern, last_checkpoint) cur_step = int(match.group(1)) return 
cur_step

+
 def log_metric(
     accelerator,
     metrics: Dict,

From 0e46d0bb8d5cce15b42ce03bb568bda7164e7bcf Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Thu, 1 Aug 2024 12:56:43 +0200
Subject: [PATCH 52/62] unpin transformers

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 52d541e..fad9be7 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@


 _deps = [
-    "transformers @ git+https://github.com/huggingface/transformers@72fb02c47dbbe1999ae105319f24631cad6e2e00",
+    "transformers",
     "torch",
     "sentencepiece",
     "descript-audio-codec",

From 43764cd8ad7f39fc9a0c14aeed82a538c5c1376b Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Thu, 1 Aug 2024 14:06:08 +0200
Subject: [PATCH 53/62] pin transformers

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index fad9be7..c5d768e 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@


 _deps = [
-    "transformers",
+    "transformers>=4.43.0,<=4.43.3",
     "torch",
     "sentencepiece",
     "descript-audio-codec",

From b5e25f01e112f81f3639aab1de0981bb998387c2 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Thu, 1 Aug 2024 16:09:49 +0200
Subject: [PATCH 54/62] pin torch

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index c5d768e..4cbb076 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@

 _deps = [
     "transformers>=4.43.0,<=4.43.3",
-    "torch",
+    "torch>=2.3.0",
     "sentencepiece",
     "descript-audio-codec",
 ]

From 441bafd3eaa509d4cfa30f36f90e009c713b8f11 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Thu, 1 Aug 2024 16:23:07 +0200
Subject: [PATCH 55/62] refactor + unpin torch

---
 parler_tts/modeling_parler_tts.py | 8 ++++----
 setup.py                          | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py
index 63dd2ed..a70b3bb 100644
--- a/parler_tts/modeling_parler_tts.py
+++ b/parler_tts/modeling_parler_tts.py
@@ -548,9 +548,9 @@ def forward(
             key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
             value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)

-            if not is_cross_attention:
+            if not is_cross_attention and self.rope_embeddings:
                 # cached key states already have rope applied - only apply to new state
-                key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states
+                key_states = apply_rotary_pos_emb(key_states, cos, sin)

             if past_key_value is not None:
                 # save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -763,9 +763,9 @@ def forward(
             key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
             value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)

-            if not is_cross_attention:
+            if not is_cross_attention and self.rope_embeddings:
                 # cached key states already have rope applied - only apply to new state
-                key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states
+                key_states = apply_rotary_pos_emb(key_states, cos, sin)

             if past_key_value is not None:
                 # save all key/value_states to cache to be re-used for fast auto-regressive generation

diff --git a/setup.py b/setup.py
index 4cbb076..c5d768e 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@

 _deps = [
     "transformers>=4.43.0,<=4.43.3",
-    "torch>=2.3.0",
+    "torch",
     "sentencepiece",
     "descript-audio-codec",
 ]

From 45ee62ef5f79a0cd9b80e6e9076e590a9f299e56 Mon Sep 17 00:00:00 2001
From: eustlb
<94853470+eustlb@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:26:09 +0200 Subject: [PATCH 56/62] Update parler_tts/modeling_parler_tts.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- parler_tts/modeling_parler_tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index a70b3bb..d9f1edb 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -1878,7 +1878,7 @@ def prepare_inputs_for_generation( prompt_hidden_states = None return { - "input_ids": input_ids.contiguous(), + "input_ids": input_ids.contiguous(), # `contiguous()` needed for compilation use cases "attention_mask": attention_mask, "position_ids": position_ids, "encoder_hidden_states": encoder_hidden_states, From a294fb5654337a2906b8c2412b696932bb2209f9 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 1 Aug 2024 16:35:53 +0200 Subject: [PATCH 57/62] update training script to match 11b209e --- training/run_parler_tts_training.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 8be5250..4295ebc 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -21,22 +21,24 @@ import re import sys import time +from multiprocess import set_start_method from datetime import timedelta + +from tqdm import tqdm from pathlib import Path -import datasets import torch -import transformers -from accelerate import Accelerator, skip_first_batches -from accelerate.utils import AutocastKwargs, InitProcessGroupKwargs, set_seed -from accelerate.utils.memory import release_memory -from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets -from huggingface_hub import HfApi from torch.utils.data import DataLoader -from tqdm import tqdm + +import datasets +from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets + +from huggingface_hub import HfApi + +import transformers from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser -from transformers.optimization import get_scheduler from transformers.trainer_pt_utils import LengthGroupedSampler +from transformers.optimization import get_scheduler from transformers.utils import send_example_telemetry @@ -348,8 +350,10 @@ def main(): description_column_name = data_args.description_column_name prompt_column_name = data_args.prompt_column_name feature_extractor_input_name = feature_extractor.model_input_names[0] + audio_encoder_pad_token_id = config.decoder.pad_token_id audio_encoder_eos_token_id = config.decoder.eos_token_id audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id + max_length = model.generation_config.max_length num_codebooks = model.decoder.config.num_codebooks bandwidth = model_args.bandwidth attn_implementation = model_args.attn_implementation @@ -707,7 +711,7 @@ def compute_metrics( steps_per_epoch = total_train_steps if training_args.eval_steps is None: - logger.info("eval_steps is not set, evaluating at the end of each epoch") + logger.info(f"eval_steps is not set, evaluating at the end of each epoch") eval_steps = steps_per_epoch else: eval_steps = training_args.eval_steps @@ -1058,7 +1062,7 @@ def generate_step(batch, accelerator): for batch in tqdm( validation_dataloader, - desc="Evaluating - Inference ...", + desc=f"Evaluating - Inference ...", position=2, disable=not 
accelerator.is_local_main_process, ): @@ -1080,7 +1084,7 @@ def generate_step(batch, accelerator): # generation for batch in tqdm( validation_dataloader, - desc="Evaluating - Generation ...", + desc=f"Evaluating - Generation ...", position=2, disable=not accelerator.is_local_main_process, ): From 824b183dda6a29af0bf40e6d72d7a82e5d782bf1 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:37:15 +0200 Subject: [PATCH 58/62] Update parler_tts/modeling_parler_tts.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- parler_tts/modeling_parler_tts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index d9f1edb..2c59896 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -397,8 +397,8 @@ def forward( bsz, tgt_len = hidden_states.shape[:2] # get query proj - query_states = self._shape_query(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz) - + query_states = self.q_proj(hidden_states) * self.scaling + query_states = self._shape_query(query_states, tgt_len, bsz) if self.rope_embeddings: query_states = apply_rotary_pos_emb(query_states, cos, sin) From 66fdbde5694e7bd5a867398f9b3e6e83b2cee7c2 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Fri, 2 Aug 2024 00:18:00 +0200 Subject: [PATCH 59/62] ensure compatibility with trfms 4.43.3, changes taken from #31980 on trfms --- parler_tts/modeling_parler_tts.py | 73 +++++++++---------------------- 1 file changed, 20 insertions(+), 53 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 2c59896..6b65e0b 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -3270,38 +3270,22 @@ def generate( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) # 4. 
Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache - requires_attention_mask = "encoder_outputs" not in model_kwargs - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) if "encoder_outputs" not in model_kwargs: @@ -3328,40 +3312,23 @@ def generate( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + bos_token_id=generation_config._bos_token_tensor, device=inputs_tensor.device, ) # 6. Prepare `max_length` depending on other stopping criteria. - input_ids_seq_length = input_ids.shape[-1] + input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - logger.warning( - f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " - "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation." - ) - elif generation_config.max_new_tokens is not None: - if not has_default_max_length: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: - raise ValueError( - f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" - f" the maximum length ({generation_config.max_length})" - ) - if input_ids_seq_length >= generation_config.max_length: - logger.warning( - f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." 
- ) + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: raise ValueError( @@ -3417,8 +3384,8 @@ def generate( # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS) input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids, - bos_token_id=generation_config.bos_token_id, - pad_token_id=generation_config.pad_token_id, + bos_token_id=generation_config._bos_token_tensor, + pad_token_id=generation_config._pad_token_tensor, max_length=generation_config.max_length, ) # stash the delay mask so that we don't have to recompute in each forward pass @@ -3443,7 +3410,7 @@ def generate( # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, + input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, @@ -3514,8 +3481,8 @@ def generate( # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask _, mask = self.decoder.build_delay_pattern_mask( input_ids, - bos_token_id=generation_config.bos_token_id, - pad_token_id=generation_config.pad_token_id, + bos_token_id=generation_config._bos_token_tensor, + pad_token_id=generation_config._pad_token_tensor, max_length=output_ids.shape[1], ) From dc21c875f826f8ca9f7d36fffd4b009755847c16 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Fri, 2 Aug 2024 11:10:36 +0200 Subject: [PATCH 60/62] fix input_ids_length --- parler_tts/modeling_parler_tts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index 6b65e0b..bc5b6a9 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -3344,10 +3344,10 @@ def generate( ) if not self.prompt_cross_attention: # when we prepend prompt_hidden_state to inputs_embeds, max_cache_len needs to be actualised - # generation_config.max_length has already been increased by input_ids_seq_length which is + # generation_config.max_length has already been increased by input_ids_length which is # already counted in input_embeds_seq_length so we remove it input_embeds_seq_length = model_kwargs["inputs_embeds"].shape[1] - max_cache_len = generation_config.max_length + input_embeds_seq_length - input_ids_seq_length + max_cache_len = generation_config.max_length + input_embeds_seq_length - input_ids_length else: max_cache_len = self.generation_config.max_length model_kwargs["past_key_values"] = self._get_cache( From 650f276d812639c40151e24117e9359eb9861cc3 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 5 Aug 2024 11:52:55 +0200 Subject: [PATCH 61/62] warning full attention mask creation --- parler_tts/modeling_parler_tts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index bc5b6a9..805caee 100644 --- a/parler_tts/modeling_parler_tts.py +++ 
b/parler_tts/modeling_parler_tts.py
@@ -1358,6 +1358,9 @@ def forward(
         if prompt_attention_mask is not None and attention_mask is not None:
             attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
         elif prompt_attention_mask is not None:
+            logger.warning_once(
+                "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
+            )
             if past_key_values is None:
                 attention_mask = torch.cat(
                     [

From 41edc2aeca462e7de48fe8f11f9896343a8f7ae2 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Wed, 7 Aug 2024 12:19:20 +0200
Subject: [PATCH 62/62] changes for training compatibility

---
 parler_tts/dac_wrapper/modeling_dac.py |  2 +-
 parler_tts/modeling_parler_tts.py      | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/parler_tts/dac_wrapper/modeling_dac.py b/parler_tts/dac_wrapper/modeling_dac.py
index 14292e8..d3d5a44 100644
--- a/parler_tts/dac_wrapper/modeling_dac.py
+++ b/parler_tts/dac_wrapper/modeling_dac.py
@@ -78,7 +78,7 @@ def encode(
         )

         for offset in range(0, input_length - step, stride):
-            padding_mask[..., offset : offset + chunk_length].bool()
+            mask = padding_mask[..., offset : offset + chunk_length].bool()
             frame = audio_data[:, :, offset : offset + chunk_length]

             scale = None

diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py
index 805caee..00de4c3 100644
--- a/parler_tts/modeling_parler_tts.py
+++ b/parler_tts/modeling_parler_tts.py
@@ -248,7 +248,7 @@ def get_embedding(num_embeddings: int, embedding_dim: int):
     def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
         bsz, seq_len, _ = input_ids.size()
         # Create the position ids from the input token ids.
-        position_ids = (torch.arange(seq_len, device=input_ids.device) + past_key_values_length).to(input_ids.device)
+        position_ids = torch.arange(seq_len, device=input_ids.device) + past_key_values_length
         # expand embeddings if needed
         if seq_len > self.weights.size(0):
             self.make_weights(seq_len + self.offset, self.embedding_dim)
@@ -1318,8 +1318,10 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])

+        prepended_sequence_length = 0
         # if prompt_hidden_states, fuse to inputs_embeds and update input shape
         if prompt_hidden_states is not None:
+            prepended_sequence_length = prompt_hidden_states.shape[-2]
             inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)

         return_legacy_cache = False
@@ -1345,7 +1347,7 @@ def forward(

         if cache_position is None:
             cache_position = torch.arange(
-                past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
+                past_key_values_length, past_key_values_length + input_shape[1] + prepended_sequence_length, device=inputs_embeds.device
             )

         if position_ids is None:
@@ -1361,7 +1363,7 @@ def forward(
             logger.warning_once(
                 "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
             )
-            if past_key_values is None:
+            if past_key_values_length == 0:
                 attention_mask = torch.cat(
                     [
                         prompt_attention_mask,
@@ -2699,7 +2701,7 @@ def forward(

         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             decoder_input_ids = shift_tokens_right(
-                labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
             ).transpose(1, 2)

         elif decoder_input_ids is None and decoder_inputs_embeds is None:
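Editorial note on the delay pattern mask referenced in the generate() hunks above: each codebook is offset by one extra step so that codebook k is only predicted once codebooks 0..k-1 are available for the same timestep. The sketch below is a simplified toy version of that idea, not the actual build_delay_pattern_mask implementation; the fill value stands in for the BOS/PAD ids used by the real mask.

import torch

def toy_delay_pattern(codes: torch.Tensor, bos_token_id: int) -> torch.Tensor:
    # codes: (num_codebooks, seq_len) -> (num_codebooks, seq_len + num_codebooks), codebook k shifted by k+1
    num_codebooks, seq_len = codes.shape
    out = torch.full((num_codebooks, seq_len + num_codebooks), bos_token_id, dtype=codes.dtype)
    for k in range(num_codebooks):
        out[k, k + 1 : k + 1 + seq_len] = codes[k]
    return out

codes = torch.arange(1, 13).reshape(4, 3)      # 4 codebooks, 3 timesteps
delayed = toy_delay_pattern(codes, bos_token_id=0)
# codebook 0 starts after 1 BOS, codebook 1 after 2, and so on, producing the staircase layout
print(delayed)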