diff --git a/docs/source/en/model_doc/pop2piano.mdx b/docs/source/en/model_doc/pop2piano.mdx
index 0b1f44a0571feb..6f64b04419e6b0 100644
--- a/docs/source/en/model_doc/pop2piano.mdx
+++ b/docs/source/en/model_doc/pop2piano.mdx
@@ -36,14 +36,14 @@ of producing plausible piano covers.*
 Tips:
 
 1. Pop2Piano is an Encoder-Decoder based model like T5.
-2. Pop2Piano can be used to generate midi-audio files for a given audio sequence. This HuggingFace implementation allows to save midi_output as well as stereo-mix output of the audio sequence.
+2. Pop2Piano can be used to generate midi-audio files for a given audio sequence.
 3. Choosing different composers in `Pop2PianoForConditionalGeneration.generate()` can lead to variety of different results.
-4. Please note that HuggingFace implementation of Pop2Piano(both Pop2PianoForConditionalGeneration and Pop2PianoFeatureExtractor) can only work with one raw_audio sequence at a time. So if you want to process multiple files, please feed them one by one.
+4. Please note that HuggingFace implementation of Pop2Piano(both Pop2PianoForConditionalGeneration and Pop2PianoFeatureExtractor) can only work with one raw_audio sequence at a time. So if you want to process multiple files, please feed them one by one.
 
 This model was contributed by [Susnato Dhar](https://huggingface.co/susnato).
 The original code can be found [here](https://github.com/sweetcocoa/pop2piano).
 
-###Example using HuggingFace Dataset:###
+### Example using HuggingFace Dataset:
 ```python
 >>> from datasets import load_dataset
 >>> from transformers import Pop2PianoForConditionalGeneration, Pop2PianoTokenizer, Pop2PianoFeatureExtractor
@@ -64,7 +64,7 @@ The original code can be found [here](https://github.com/sweetcocoa/pop2piano).
 >>> tokenizer_output.write("./Outputs/midi_output.mid")
 ```
 
-###Example using Your own Audio:###
+### Example using Your own Audio:
 ```python
 >>> import librosa
 >>> from transformers import Pop2PianoFeatureExtractor, Pop2PianoForConditionalGeneration, Pop2PianoTokenizer
diff --git a/src/transformers/models/pop2piano/configuration_pop2piano.py b/src/transformers/models/pop2piano/configuration_pop2piano.py
index 9cc8b7ea1d35ef..058636aa78e714 100644
--- a/src/transformers/models/pop2piano/configuration_pop2piano.py
+++ b/src/transformers/models/pop2piano/configuration_pop2piano.py
@@ -62,8 +62,8 @@ class Pop2PianoConfig(PretrainedConfig):
 
     Arguments:
         vocab_size (`int`, *optional*, defaults to 2400):
-            Vocabulary size of the Pop2PianoForConditionalGeneration model. Defines the number of different tokens that
-            can be represented by the `inputs_ids` passed when calling [`Pop2PianoForConditionalGeneration`].
+            Vocabulary size of the `Pop2PianoForConditionalGeneration` model. Defines the number of different tokens
+            that can be represented by the `inputs_ids` passed when calling [`Pop2PianoForConditionalGeneration`].
         d_model (`int`, *optional*, defaults to 512):
             Size of the encoder layers and the pooler layer.
         d_kv (`int`, *optional*, defaults to 64):
@@ -94,10 +94,6 @@ class Pop2PianoConfig(PretrainedConfig):
             Whether or not the model should return the last key/values attentions (not used by all models).
         dense_act_fn (`string`, *optional*, defaults to `"relu"`):
             Type of Activation Function to be used in `Pop2PianoDenseActDense` and in `Pop2PianoDenseGatedActDense`.
-        dataset_target_length (`int`, *optional*, defaults to 256):
-            Determines `max_length` for transformer `generate` function along with `dataset_n_bars`.
-        dataset_n_bars (`int`, *optional*, defaults to 2):
-            Determines `max_length` for transformer `generate` function along with `dataset_target_length`.
     """
 
     model_type = "pop2piano"
@@ -124,8 +120,6 @@ def __init__(
         pad_token_id=0,
         eos_token_id=1,
         dense_act_fn="relu",
-        dataset_target_length=256,
-        dataset_n_bars=2,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -148,9 +142,6 @@ def __init__(
         self.is_gated_act = act_info[0] == "gated"
         self.composer_to_feature_token = COMPOSER_TO_FEATURE_TOKEN
 
-        self.dataset_target_length = dataset_target_length
-        self.dataset_n_bars = dataset_n_bars
-
         super().__init__(
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
diff --git a/src/transformers/models/pop2piano/feature_extraction_pop2piano.py b/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
index 48f90be51d9eed..d704006d74ef46 100644
--- a/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
+++ b/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
@@ -17,8 +17,6 @@
 import warnings
 from typing import List, Optional, Union
 
-import essentia
-import essentia.standard
 import librosa
 import numpy as np
 import scipy
@@ -28,7 +26,17 @@
 from ...audio_utils import fram_wave, get_mel_filter_banks, stft
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, logging
+from ...utils import OptionalDependencyNotAvailable, TensorType, is_essentia_available, logging
+
+
+try:
+    if not is_essentia_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    raise ImportError("There was an error while importing essentia!")
+else:
+    import essentia
+    import essentia.standard
 
 
 logger = logging.get_logger(__name__)
@@ -43,35 +51,35 @@ class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
     Args:
     This class extracts rhythm and does preprocesses before being passed through the transformer model.
-        n_bars (`int`, *optional*, defaults to 2):
-            Determines `n_steps` in method `preprocess_mel`.
         sampling_rate (`int`, *optional*, defaults to 22050):
             Sample rate of audio signal.
         use_mel (`bool`, *optional*, defaults to `True`):
             Whether to preprocess for `LogMelSpectrogram` or not. For the current implementation this must be `True`.
         padding_value (`int`, *optional*, defaults to 0):
             Padding value used to pad the audio. Should correspond to silences.
-        n_fft (`int`, *optional*, defaults to 4096):
-            Size of Fast Fourier Transform, creates n_fft // 2 + 1 bins.
+        fft_window_size (`int`, *optional*, defaults to 4096):
+            Size of the window on which the Fourier transform is applied.
         hop_length (`int`, *optional*, defaults to 1024):
-            Length of hop between Short-Time Fourier Transform windows.
-        f_min (`float`, *optional*, defaults to 10.0):
+            Step between each window of the waveform.
+        frequency_min (`float`, *optional*, defaults to 10.0):
             Minimum frequency.
-        n_mels (`int`, *optional*, defaults to 512):
-            Number of mel filterbanks.
+        nb_mel_filters (`int`, *optional*, defaults to 512):
+            Number of Mel filters to generate.
+        n_bars (`int`, *optional*, defaults to 2):
+            Determines `n_steps` in the `preprocess_mel` method; the audio is split every `n_steps` beats.
""" model_input_names = ["input_features"] def __init__( self, - n_bars: int = 2, sampling_rate: int = 22050, use_mel: int = True, padding_value: int = 0, - n_fft: int = 4096, + fft_window_size: int = 4096, hop_length: int = 1024, - f_min: float = 10.0, - n_mels: int = 512, + frequency_min: float = 10.0, + nb_mel_filters: int = 512, + n_bars: int = 2, feature_size=None, **kwargs, ): @@ -85,18 +93,18 @@ def __init__( self.sampling_rate = sampling_rate self.use_mel = use_mel self.padding_value = padding_value - self.n_fft = n_fft + self.fft_window_size = fft_window_size self.hop_length = hop_length - self.f_min = f_min - self.n_mels = n_mels + self.frequency_min = frequency_min + self.nb_mel_filters = nb_mel_filters def log_mel_spectogram(self, sequence): """Generates MelSpectrogram then applies log base e.""" mel_fb = get_mel_filter_banks( - nb_frequency_bins=(self.n_fft // 2) + 1, - nb_mel_filters=self.n_mels, - frequency_min=self.f_min, + nb_frequency_bins=(self.fft_window_size // 2) + 1, + nb_mel_filters=self.nb_mel_filters, + frequency_min=self.frequency_min, frequency_max=float(self.sampling_rate // 2), sample_rate=self.sampling_rate, norm=None, @@ -105,9 +113,9 @@ def log_mel_spectogram(self, sequence): spectrogram = [] for seq in sequence: - window = np.hanning(self.n_fft + 1)[:-1] - framed_audio = fram_wave(seq, self.hop_length, self.n_fft) - spec = stft(framed_audio, window, fft_window_size=self.n_fft) + window = np.hanning(self.fft_window_size + 1)[:-1] + framed_audio = fram_wave(seq, self.hop_length, self.fft_window_size) + spec = stft(framed_audio, window, fft_window_size=self.fft_window_size) spec = np.abs(spec) ** 2.0 spectrogram.append(spec) @@ -172,25 +180,15 @@ def preprocess_mel( n_target_step = len(beatstep) ext_beatstep = self.extrapolate_beat_times(beatstep, (n_bars + 1) * 4 + 1) - def split_audio(audio): - """ - Split audio corresponding beat intervals. Each audio's lengths are different. Because each corresponding - beat interval times are different. - """ - - batch = [] - - for i in range(0, n_target_step, n_steps): - start_idx = i - end_idx = min(i + n_steps, n_target_step) - - start_sample = int(ext_beatstep[start_idx] * self.sampling_rate) - end_sample = int(ext_beatstep[end_idx] * self.sampling_rate) - feature = audio[start_sample:end_sample] - batch.append(feature) - return batch + batch = [] + for i in range(0, n_target_step, n_steps): + start_idx = i + end_idx = min(i + n_steps, n_target_step) - batch = split_audio(audio) + start_sample = int(ext_beatstep[start_idx] * self.sampling_rate) + end_sample = int(ext_beatstep[end_idx] * self.sampling_rate) + feature = audio[start_sample:end_sample] + batch.append(feature) batch = pad_sequence(batch, batch_first=True, padding_value=padding_value) return batch, ext_beatstep @@ -253,7 +251,7 @@ def __call__( - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return Numpy `np.ndarray` objects. """ - warnings.warn("Please make sure to have the audio sampling_rate as 44100, to get the optimal performence!") + warnings.warn("If you are not getting optimal performence, please try audio sampling_rate as 44100.") warnings.warn( "Pop2PianoFeatureExtractor only takes one raw_audio at a time, if you want to extract features from more than a single audio then you might need to call it multiple times." 
         )
diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py
index f99978c1a22f2b..2c0e76aac15a47 100644
--- a/src/transformers/models/pop2piano/modeling_pop2piano.py
+++ b/src/transformers/models/pop2piano/modeling_pop2piano.py
@@ -47,6 +47,22 @@
 
 logger = logging.get_logger(__name__)
 
+_load_pop2piano_layer_norm = True
+
+try:
+    from apex.normalization import FusedRMSNorm
+
+    _load_pop2piano_layer_norm = False
+
+    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm")
+except ImportError:
+    # using the normal Pop2PianoLayerNorm
+    pass
+except Exception:
+    logger.warning("discovered apex but it failed to load, falling back to Pop2PianoLayerNorm")
+    pass
+
+
 _CONFIG_FOR_DOC = "Pop2PianoConfig"
 _CHECKPOINT_FOR_DOC = "susnato/pop2piano_dev"
 
@@ -153,19 +169,9 @@ def forward(self, hidden_states):
         return self.weight * hidden_states
 
 
-try:
-    from apex.normalization import FusedRMSNorm
-
+if not _load_pop2piano_layer_norm:
     Pop2PianoLayerNorm = FusedRMSNorm  # noqa
 
-    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm")
-except ImportError:
-    # using the normal Pop2PianoLayerNorm
-    pass
-except Exception:
-    logger.warning("discovered apex but it failed to load, falling back to Pop2PianoLayerNorm")
-    pass
-
 ALL_LAYERNORM_LAYERS.append(Pop2PianoLayerNorm)
 
 
@@ -997,14 +1003,15 @@ def forward(self, feature, index_value):
 
 
 Pop2Piano_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
     Parameters:
-    The Pop2PianoForConditionalGeneration model was proposed in [POP2PIANO : POP AUDIO-BASED PIANO COVER
-    GENERATION](https://arxiv.org/pdf/2211.00895) by Jongho Choi, Kyogu Lee. It's an encoder decoder transformer
-    pre-trained in a text-to-text denoising generative setting. This model inherits from [`PreTrainedModel`]. Check the:
-    superclass documentation for the generic methods the library implements for all its model (such as downloading or
-    saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch
-    [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch
-    Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
     config ([`Pop2PianoConfig`]):
         Model configuration class with all the parameters of the model. Initializing with a config file does not
         load the weights associated with the model, only the configuration. Check out the
         [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1086,6 +1093,20 @@ def get_encoder(self):
     def get_decoder(self):
         return self.decoder
 
+    def get_mel_conditioner_outputs(self, input_features, composer):
+        # select composer randomly if not already given
+        composer_to_feature_token = self.config.composer_to_feature_token
+        if composer is None:
+            composer = np.random.choice(list(composer_to_feature_token.keys()), size=1)[0]
+        elif composer not in composer_to_feature_token.keys():
+            raise ValueError(
+                f"Composer not found in the list, please choose from {list(composer_to_feature_token.keys())}"
+            )
+        composer_value = composer_to_feature_token[composer]
+        composer_value = torch.tensor(composer_value, device=self.device)
+        composer_value = composer_value.repeat(input_features.shape[0])
+        return self.mel_conditioner(input_features, composer_value)
+
     @add_start_docstrings_to_model_forward(Pop2Piano_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1201,7 +1222,7 @@ def generate(
         inputs_embeds=None,
         composer="composer1",
         n_bars: int = 2,
-        max_length: int = None,
+        max_length: int = 256,
         inputs: Optional[torch.Tensor] = None,
         generation_config=None,
         **kwargs,
@@ -1282,27 +1303,10 @@ def generate(
         if generation_config is None:
             generation_config = self.generation_config
 
-        # select composer randomly if not already given
-        composer_to_feature_token = self.config.composer_to_feature_token
-        if composer is None:
-            composer = np.random.choice(list(composer_to_feature_token.keys()), size=1)[0]
-        elif composer not in composer_to_feature_token.keys():
-            raise ValueError(
-                f"Composer not found in list, Please choose from {list(composer_to_feature_token.keys())}"
-            )
-
-        n_bars = self.config.dataset_n_bars if n_bars is None else n_bars
-        max_length = (
-            self.config.dataset_target_length * max(1, (n_bars // self.config.dataset_n_bars))
-            if max_length is None
-            else max_length
+        inputs_embeds = self.get_mel_conditioner_outputs(
+            input_features=input_features["input_features"], composer=composer
         )
-        composer_value = composer_to_feature_token[composer]
-        composer_value = torch.tensor(composer_value, device=self.device)
-        composer_value = composer_value.repeat(input_features["input_features"].shape[0])
-        inputs_embeds = self.mel_conditioner(input_features["input_features"], composer_value)
-
         return super().generate(
             inputs,
             generation_config,
diff --git a/src/transformers/models/pop2piano/tokenization_pop2piano.py b/src/transformers/models/pop2piano/tokenization_pop2piano.py
index 0c4cb6e1669adb..b1f13a25d30d7b 100644
--- a/src/transformers/models/pop2piano/tokenization_pop2piano.py
+++ b/src/transformers/models/pop2piano/tokenization_pop2piano.py
@@ -75,11 +75,21 @@ def __init__(
         vocab_size_time: int = 100,
         n_bars: int = 2,
         unk_token="-1",
+        eos_token="1",
+        pad_token="0",
+        bos_token="2",
         **kwargs,
     ):
         unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+
         super().__init__(
             unk_token=unk_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            bos_token=bos_token,
             vocab_size_special=vocab_size_special,
             vocab_size_note=vocab_size_note,
             vocab_size_velocity=vocab_size_velocity,
@@ -89,12 +99,7 @@ def __init__(
         )
 
         with open(vocab_file, "rb") as t_file:
-            self.vocab_config = json.load(t_file)
-        self.token_note = self.vocab_config["TOKEN_NOTE"]
-        self.token_time = self.vocab_config["TOKEN_TIME"]
-        self.token_special = self.vocab_config["TOKEN_SPECIAL"]
-        self.token_velocity = self.vocab_config["TOKEN_VELOCITY"]
-        self.default_velocity = self.vocab_config["DEFAULT_VELOCITY"]
+            self.encoder = json.load(t_file)
 
         self.vocab_size_special = vocab_size_special
         self.vocab_size_note = vocab_size_note
@@ -107,36 +112,36 @@ def vocab_size(self):
         return self.vocab_size_special + self.vocab_size_note + self.vocab_size_time + self.vocab_size_velocity
 
     def get_vocab(self):
-        return self.vocab_config
+        return self.encoder
 
     def _convert_id_to_token(self, token, time_idx_offset):
         """Decodes the tokens generated by the transformer"""
         if token >= (self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity):
-            type = self.token_time
+            type = self.encoder["TOKEN_TIME"]
             value = (
                 token - (self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity)
             ) + time_idx_offset
         elif token >= (self.vocab_size_special + self.vocab_size_note):
-            type = self.token_velocity
+            type = self.encoder["TOKEN_VELOCITY"]
             value = int(token - (self.vocab_size_special + self.vocab_size_note))
         elif token >= self.vocab_size_special:
-            type = self.token_note
+            type = self.encoder["TOKEN_NOTE"]
             value = int(token - self.vocab_size_special)
         else:
-            type = self.token_special
+            type = self.encoder["TOKEN_SPECIAL"]
             value = int(token)
         return [type, value]
 
-    def _convert_token_to_id(self, token, token_type):
-        if token_type == self.token_time:
+    def _convert_token_to_id(self, token, token_type="3"):
+        if token_type == self.encoder["TOKEN_TIME"]:
             return self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity + token
-        elif token_type == self.token_velocity:
+        elif token_type == self.encoder["TOKEN_VELOCITY"]:
             return self.vocab_size_special + self.vocab_size_note + token
-        elif token_type == self.token_note:
+        elif token_type == self.encoder["TOKEN_NOTE"]:
             return self.vocab_size_special + token
-        elif token_type == self.token_special:
+        elif token_type == self.encoder["TOKEN_SPECIAL"]:
             return token
         else:
             return -1
@@ -195,17 +200,17 @@ def relative_tokens_to_notes(self, tokens, start_idx, cutoff_time_idx=None):
         note_onsets_ready = [None for i in range(self.vocab_size_note + 1)]
         notes = []
         for type, number in words:
-            if type == self.token_special:
+            if type == self.encoder["TOKEN_SPECIAL"]:
                 if number == 1:
                     break
-            elif type == self.token_time:
+            elif type == self.encoder["TOKEN_TIME"]:
                 current_idx += number
                 if cutoff_time_idx is not None:
                     current_idx = min(current_idx, cutoff_time_idx)
-            elif type == self.token_velocity:
+            elif type == self.encoder["TOKEN_VELOCITY"]:
                 current_velocity = number
-            elif type == self.token_note:
+            elif type == self.encoder["TOKEN_NOTE"]:
                 pitch = number
                 if current_velocity == 0:
                     # note_offset
@@ -219,7 +224,7 @@ def relative_tokens_to_notes(self, tokens, start_idx, cutoff_time_idx=None):
                         pass
                     else:
                         offset_idx = current_idx
-                        notes.append([onset_idx, offset_idx, pitch, self.default_velocity])
+                        notes.append([onset_idx, offset_idx, pitch, self.encoder["DEFAULT_VELOCITY"]])
                         note_onsets_ready[pitch] = None
                 else:
                     # note_on
@@ -233,10 +238,10 @@ def relative_tokens_to_notes(self, tokens, start_idx, cutoff_time_idx=None):
                         pass
                     else:
                         offset_idx = current_idx
-                        notes.append([onset_idx, offset_idx, pitch, self.default_velocity])
+                        notes.append([onset_idx, offset_idx, pitch, self.encoder["DEFAULT_VELOCITY"]])
                         note_onsets_ready[pitch] = current_idx
             else:
-                raise ValueError
+                raise ValueError("Token type not understood!")
 
         for pitch, note_on in enumerate(note_onsets_ready):
             # force offset if no offset for each pitch
@@ -247,7 +252,7 @@ def relative_tokens_to_notes(self, tokens, start_idx, cutoff_time_idx=None):
                     cutoff = max(cutoff_time_idx, note_on + 1)
 
                 offset_idx = max(current_idx, cutoff)
-                notes.append([note_on, offset_idx, pitch, self.default_velocity])
+                notes.append([note_on, offset_idx, pitch, self.encoder["DEFAULT_VELOCITY"]])
 
         if len(notes) == 0:
             return []
@@ -296,7 +301,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
         with open(vocab_file, "w", encoding="utf-8") as f:
-            f.write(json.dumps(self.vocab_config))
+            f.write(json.dumps(self.encoder))
 
         return (vocab_file,)