Porting inference to Transformers 4.42 #84

Closed · wants to merge 1 commit
49 changes: 40 additions & 9 deletions parler_tts/modeling_parler_tts.py
@@ -18,7 +18,7 @@
import math
import random
from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -164,10 +164,10 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
- shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
- shifted_input_ids[:, 0] = decoder_start_token_id
+ shifted_input_ids[..., 0] = decoder_start_token_id

if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
@@ -650,6 +650,10 @@ def _init_weights(self, module):
Optionally, instead of passing `prompt_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `prompt_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@@ -1121,6 +1125,9 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+ if (labels is not None) and (input_ids is None and inputs_embeds is None):
+ input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)

outputs = self.model(
input_ids,
attention_mask=attention_mask,
@@ -1464,6 +1471,7 @@ def generate(
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
+ device=input_ids.device,
)

# 10. prepare stopping criteria
@@ -1479,7 +1487,7 @@
)

# 11. run greedy search
- outputs = self._greedy_search(
+ outputs = self._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
@@ -1491,7 +1499,7 @@

elif is_sample_gen_mode:
# 11. prepare logits warper
- logits_warper = self._get_logits_warper(generation_config)
+ logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)

# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -2014,7 +2022,7 @@ def forward(

if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
).transpose(1, 2)

elif decoder_input_ids is None and decoder_inputs_embeds is None:
@@ -2056,7 +2064,7 @@ def forward(
)

if not return_dict:
- return decoder_outputs + (encoder_hidden_states,)
+ return decoder_outputs + encoder_outputs

return Seq2SeqLMOutput(
loss=decoder_outputs.loss,
@@ -2186,6 +2194,25 @@ def _prepare_decoder_input_ids_for_generation(

return decoder_input_ids, model_kwargs

+ # This method is copied from musicgen_modeling.py for compatibility with Transformers 4.42.
+ def _get_decoder_start_token_id(
+ self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
+ ) -> int:
+ decoder_start_token_id = (
+ decoder_start_token_id
+ if decoder_start_token_id is not None
+ else self.generation_config.decoder_start_token_id
+ )
+ bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
+
+ if decoder_start_token_id is not None:
+ return decoder_start_token_id
+ elif bos_token_id is not None:
+ return bos_token_id
+ raise ValueError(
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
+ )

def _prepare_text_encoder_kwargs_for_generation(
self,
inputs_tensor: torch.Tensor,
@@ -2287,7 +2314,7 @@ def _prepare_audio_encoder_kwargs_for_generation(
return model_kwargs

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
- return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
+ return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id).transpose(1, 2)

def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
@@ -2446,6 +2473,9 @@ def generate(
)
batch_size = inputs_tensor.shape[0]

+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)

# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
@@ -2557,6 +2587,7 @@ def generate(
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
+ device=input_ids.device,
)

# 10. prepare stopping criteria
@@ -2584,7 +2615,7 @@

elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)

# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
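For context on the `shift_tokens_right` change above (switching from `[:, ...]` to `[..., ...]` indexing), here is a minimal standalone sketch, not the library code: with ellipsis indexing the shift happens along the last dimension for tensors of any rank. The toy shapes and the `shift_right_last_dim` name are illustrative assumptions only.

```python
import torch

# Toy sketch (not the parler_tts implementation): ellipsis indexing shifts
# along the *last* dimension, so the same logic handles both 2-D
# (batch, seq_len) tensors and 3-D tensors such as (batch, num_codebooks, seq_len).
def shift_right_last_dim(x: torch.Tensor, start_id: int) -> torch.Tensor:
    shifted = x.new_zeros(x.shape)
    shifted[..., 1:] = x[..., :-1].clone()  # drop the last position, shift the rest right
    shifted[..., 0] = start_id              # place the start token at position 0
    return shifted

x = torch.arange(6).reshape(1, 2, 3)  # pretend (batch=1, codebooks=2, seq_len=3)
print(shift_right_last_dim(x, start_id=99))
# tensor([[[99,  0,  1],
#          [99,  3,  4]]])
```

With the old `[:, 1:]` indexing, the shift would always run along dimension 1 regardless of rank, which is only the sequence dimension for 2-D inputs.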
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@


_deps = [
"transformers>=4.39.0,<4.41.0",
"transformers>=4.42.0",
"torch",
"sentencepiece",
"descript-audio-codec",
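As a quick sanity check after this version bump, a smoke test along the lines of the project's README usage could be run against transformers>=4.42. This is a hedged sketch, not part of the PR: the Mini v0.1 checkpoint name, the prompt text, and the use of `soundfile` for writing the output are assumptions taken from the README-style usage.

```python
import torch
import soundfile as sf
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Checkpoint name assumed from the project README; swap in whichever checkpoint you use.
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

prompt = "Hey, how are you doing today?"
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Exercises the generate() path touched by this PR (logits processors, sampling, stopping criteria).
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
```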