more comments and nits
susnato committed Apr 28, 2023
1 parent 861305b commit 6d28b21
Showing 5 changed files with 119 additions and 121 deletions.
8 changes: 4 additions & 4 deletions docs/source/en/model_doc/pop2piano.mdx
@@ -36,14 +36,14 @@ of producing plausible piano covers.*
Tips:

1. Pop2Piano is an Encoder-Decoder based model like T5.
2. Pop2Piano can be used to generate midi-audio files for a given audio sequence. This HuggingFace implementation allows saving the midi_output as well as the stereo-mix output of the audio sequence.
2. Pop2Piano can be used to generate midi-audio files for a given audio sequence.
3. Choosing different composers in `Pop2PianoForConditionalGeneration.generate()` can lead to a variety of different results.
4. Please note that the HuggingFace implementation of Pop2Piano (both Pop2PianoForConditionalGeneration and Pop2PianoFeatureExtractor) can only work with one raw_audio sequence at a time, so if you want to process multiple files, please feed them one by one, as shown in the sketch below.
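
A minimal sketch of that one-file-at-a-time loop (the file names, the exact feature-extractor call signature, and loading the feature extractor from the `susnato/pop2piano_dev` checkpoint are assumptions for illustration; see the full examples below for the documented calls):

```python
>>> import librosa
>>> from transformers import Pop2PianoFeatureExtractor, Pop2PianoForConditionalGeneration

>>> feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
>>> model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")

>>> # hypothetical list of audio files; each one is processed separately because the
>>> # feature extractor and the model handle only a single raw_audio sequence per call
>>> for path in ["song1.wav", "song2.wav"]:
...     audio, sr = librosa.load(path, sr=44100)  # 44100 Hz is the recommended sampling rate
...     inputs = feature_extractor(audio, sr)  # argument order here is an assumption
...     output_tokens = model.generate(input_features=inputs, composer="composer1")
```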

This model was contributed by [Susnato Dhar](https://huggingface.co/susnato).
The original code can be found [here](https://github.com/sweetcocoa/pop2piano).

###Example using HuggingFace Dataset:###
### Example using HuggingFace Dataset:
```python
>>> from datasets import load_dataset
>>> from transformers import Pop2PianoForConditionalGeneration, Pop2PianoTokenizer, Pop2PianoFeatureExtractor
@@ -64,7 +64,7 @@ The original code can be found [here](https://github.com/sweetcocoa/pop2piano).
>>> tokenizer_output.write("./Outputs/midi_output.mid")
```
###Example using Your own Audio:###
### Example using Your own Audio:
```python
>>> import librosa
>>> from transformers import Pop2PianoFeatureExtractor, Pop2PianoForConditionalGeneration, Pop2PianoTokenizer
13 changes: 2 additions & 11 deletions src/transformers/models/pop2piano/configuration_pop2piano.py
@@ -62,8 +62,8 @@ class Pop2PianoConfig(PretrainedConfig):
Arguments:
vocab_size (`int`, *optional*, defaults to 2400):
Vocabulary size of the Pop2PianoForConditionalGeneration model. Defines the number of different tokens that
can be represented by the `inputs_ids` passed when calling [`Pop2PianoForConditionalGeneration`].
Vocabulary size of the `Pop2PianoForConditionalGeneration` model. Defines the number of different tokens
that can be represented by the `input_ids` passed when calling [`Pop2PianoForConditionalGeneration`].
d_model (`int`, *optional*, defaults to 512):
Size of the encoder layers and the pooler layer.
d_kv (`int`, *optional*, defaults to 64):
@@ -94,10 +94,6 @@ class Pop2PianoConfig(PretrainedConfig):
Whether or not the model should return the last key/values attentions (not used by all models).
dense_act_fn (`string`, *optional*, defaults to `"relu"`):
Type of Activation Function to be used in `Pop2PianoDenseActDense` and in `Pop2PianoDenseGatedActDense`.
dataset_target_length (`int`, *optional*, defaults to 256):
Determines `max_length` for transformer `generate` function along with `dataset_n_bars`.
dataset_n_bars (`int`, *optional*, defaults to 2):
Determines `max_length` for transformer `generate` function along with `dataset_target_length`.
"""

model_type = "pop2piano"
@@ -124,8 +120,6 @@ def __init__(
pad_token_id=0,
eos_token_id=1,
dense_act_fn="relu",
dataset_target_length=256,
dataset_n_bars=2,
**kwargs,
):
self.vocab_size = vocab_size
@@ -148,9 +142,6 @@ def __init__(
self.is_gated_act = act_info[0] == "gated"
self.composer_to_feature_token = COMPOSER_TO_FEATURE_TOKEN

self.dataset_target_length = dataset_target_length
self.dataset_n_bars = dataset_n_bars

super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
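
For context, a minimal sketch of instantiating the config with its documented defaults (this assumes `Pop2PianoConfig` and `Pop2PianoForConditionalGeneration` are importable from `transformers` at this point):

```python
from transformers import Pop2PianoConfig, Pop2PianoForConditionalGeneration

# Build a config from the documented defaults; generation length is controlled via the
# max_length argument of generate() rather than by config fields.
config = Pop2PianoConfig(vocab_size=2400, d_model=512, d_kv=64, dense_act_fn="relu")
model = Pop2PianoForConditionalGeneration(config)  # randomly initialised weights
```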
84 changes: 41 additions & 43 deletions src/transformers/models/pop2piano/feature_extraction_pop2piano.py
@@ -17,8 +17,6 @@
import warnings
from typing import List, Optional, Union

import essentia
import essentia.standard
import librosa
import numpy as np
import scipy
@@ -28,7 +26,17 @@
from ...audio_utils import fram_wave, get_mel_filter_banks, stft
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
from ...utils import OptionalDependencyNotAvailable, TensorType, is_essentia_available, logging


try:
    if not is_essentia_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    raise ImportError("Pop2PianoFeatureExtractor requires the essentia library, but it was not found in your environment.")
else:
import essentia
import essentia.standard


logger = logging.get_logger(__name__)
@@ -43,35 +51,35 @@ class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
This class extracts rhythm and preprocesses the audio before it is passed through the transformer model.

Args:
n_bars (`int`, *optional*, defaults to 2):
Determines `n_steps` in method `preprocess_mel`.
sampling_rate (`int`, *optional*, defaults to 22050):
Sample rate of audio signal.
use_mel (`bool`, *optional*, defaults to `True`):
Whether to preprocess for `LogMelSpectrogram` or not. For the current implementation this must be `True`.
padding_value (`int`, *optional*, defaults to 0):
Padding value used to pad the audio. Should correspond to silences.
n_fft (`int`, *optional*, defaults to 4096):
Size of Fast Fourier Transform, creates n_fft // 2 + 1 bins.
fft_window_size (`int`, *optional*, defaults to 4096):
Size of the window on which the Fourier transform is applied.
hop_length (`int`, *optional*, defaults to 1024):
Length of hop between Short-Time Fourier Transform windows.
f_min (`float`, *optional*, defaults to 10.0):
Step between each window of the waveform.
frequency_min (`float`, *optional*, defaults to 10.0):
Minimum frequency.
n_mels (`int`, *optional*, defaults to 512):
Number of mel filterbanks.
nb_mel_filters (`int`, *optional*, defaults to 512):
Number of Mel filters to generate.
n_bars (`int`, *optional*, defaults to 2):
Determines `n_steps` in the `preprocess_mel` method; the audio is split into chunks of `n_steps` beat steps each.
"""
model_input_names = ["input_features"]

def __init__(
self,
n_bars: int = 2,
sampling_rate: int = 22050,
use_mel: bool = True,
padding_value: int = 0,
n_fft: int = 4096,
fft_window_size: int = 4096,
hop_length: int = 1024,
f_min: float = 10.0,
n_mels: int = 512,
frequency_min: float = 10.0,
nb_mel_filters: int = 512,
n_bars: int = 2,
feature_size=None,
**kwargs,
):
@@ -85,18 +93,18 @@ def __init__(
self.sampling_rate = sampling_rate
self.use_mel = use_mel
self.padding_value = padding_value
self.n_fft = n_fft
self.fft_window_size = fft_window_size
self.hop_length = hop_length
self.f_min = f_min
self.n_mels = n_mels
self.frequency_min = frequency_min
self.nb_mel_filters = nb_mel_filters

def log_mel_spectogram(self, sequence):
"""Generates MelSpectrogram then applies log base e."""

mel_fb = get_mel_filter_banks(
nb_frequency_bins=(self.n_fft // 2) + 1,
nb_mel_filters=self.n_mels,
frequency_min=self.f_min,
nb_frequency_bins=(self.fft_window_size // 2) + 1,
nb_mel_filters=self.nb_mel_filters,
frequency_min=self.frequency_min,
frequency_max=float(self.sampling_rate // 2),
sample_rate=self.sampling_rate,
norm=None,
@@ -105,9 +113,9 @@ def log_mel_spectogram(self, sequence):

spectrogram = []
for seq in sequence:
window = np.hanning(self.n_fft + 1)[:-1]
framed_audio = fram_wave(seq, self.hop_length, self.n_fft)
spec = stft(framed_audio, window, fft_window_size=self.n_fft)
window = np.hanning(self.fft_window_size + 1)[:-1]
framed_audio = fram_wave(seq, self.hop_length, self.fft_window_size)
spec = stft(framed_audio, window, fft_window_size=self.fft_window_size)
spec = np.abs(spec) ** 2.0
spectrogram.append(spec)

@@ -172,25 +180,15 @@ def preprocess_mel(
n_target_step = len(beatstep)
ext_beatstep = self.extrapolate_beat_times(beatstep, (n_bars + 1) * 4 + 1)

def split_audio(audio):
"""
Split audio corresponding beat intervals. Each audio's lengths are different. Because each corresponding
beat interval times are different.
"""

batch = []

for i in range(0, n_target_step, n_steps):
start_idx = i
end_idx = min(i + n_steps, n_target_step)

start_sample = int(ext_beatstep[start_idx] * self.sampling_rate)
end_sample = int(ext_beatstep[end_idx] * self.sampling_rate)
feature = audio[start_sample:end_sample]
batch.append(feature)
return batch
batch = []
for i in range(0, n_target_step, n_steps):
start_idx = i
end_idx = min(i + n_steps, n_target_step)

batch = split_audio(audio)
start_sample = int(ext_beatstep[start_idx] * self.sampling_rate)
end_sample = int(ext_beatstep[end_idx] * self.sampling_rate)
feature = audio[start_sample:end_sample]
batch.append(feature)
batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)

return batch, ext_beatstep
@@ -253,7 +251,7 @@ def __call__(
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
"""
warnings.warn("Please make sure to have the audio sampling_rate as 44100, to get the optimal performence!")
warnings.warn("If you are not getting optimal performence, please try audio sampling_rate as 44100.")
warnings.warn(
"Pop2PianoFeatureExtractor only takes one raw_audio at a time, if you want to extract features from more than a single audio then you might need to call it multiple times."
)
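
As a rough stand-alone illustration of what `log_mel_spectogram` computes with the renamed parameters, here is an approximate librosa equivalent (a sketch only, not the Hugging Face implementation; the filter-bank construction and normalisation may differ slightly, and the file path is a placeholder):

```python
import librosa
import numpy as np

# Mirror the feature-extractor defaults: fft_window_size=4096, hop_length=1024,
# nb_mel_filters=512, frequency_min=10.0, sampling_rate=22050.
audio, sr = librosa.load("pop_song.wav", sr=22050)
mel = librosa.feature.melspectrogram(
    y=audio,
    sr=sr,
    n_fft=4096,
    hop_length=1024,
    n_mels=512,
    fmin=10.0,
    fmax=sr // 2,
    power=2.0,
    norm=None,
)
log_mel = np.log(np.clip(mel, 1e-6, None))  # natural log, clipped to avoid log(0)
```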
80 changes: 42 additions & 38 deletions src/transformers/models/pop2piano/modeling_pop2piano.py
@@ -47,6 +47,22 @@

logger = logging.get_logger(__name__)

_load_pop2piano_layer_norm = True

try:
from apex.normalization import FusedRMSNorm

_load_pop2piano_layer_norm = False

logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm")
except ImportError:
# using the normal Pop2PianoLayerNorm
pass
except Exception:
logger.warning("discovered apex but it failed to load, falling back to Pop2PianoLayerNorm")
pass


_CONFIG_FOR_DOC = "Pop2PianoConfig"
_CHECKPOINT_FOR_DOC = "susnato/pop2piano_dev"

@@ -153,19 +169,9 @@ def forward(self, hidden_states):
return self.weight * hidden_states


try:
from apex.normalization import FusedRMSNorm

if not _load_pop2piano_layer_norm:
Pop2PianoLayerNorm = FusedRMSNorm # noqa

logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm")
except ImportError:
# using the normal Pop2PianoLayerNorm
pass
except Exception:
logger.warning("discovered apex but it failed to load, falling back to Pop2PianoLayerNorm")
pass

ALL_LAYERNORM_LAYERS.append(Pop2PianoLayerNorm)
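
For readers unfamiliar with the RMS-style layer norm that `FusedRMSNorm` provides a fused kernel for, here is a minimal PyTorch sketch of the same computation (illustrative only, not the exact `Pop2PianoLayerNorm` code):

```python
import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """T5-style layer norm: scale by the root mean square, with no mean subtraction and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # compute the mean of squares in float32 for numerical stability
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(self.weight.dtype)
```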


@@ -997,14 +1003,15 @@ def forward(self, feature, index_value):


Pop2Piano_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
The Pop2PianoForConditionalGeneration model was proposed in [POP2PIANO : POP AUDIO-BASED PIANO COVER
GENERATION](https://arxiv.org/pdf/2211.00895) by Jongho Choi and Kyogu Lee. It's an encoder-decoder Transformer
pre-trained in a text-to-text denoising generative setting. This model inherits from [`PreTrainedModel`]. Check the
superclass documentation for the generic methods the library implements for all its models (such as downloading or
saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch
[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch
Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
config ([`Pop2PianoConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1086,6 +1093,20 @@ def get_encoder(self):
def get_decoder(self):
return self.decoder

def get_mel_conditioner_outputs(self, input_features, composer):
    """Conditions the extracted mel features on the chosen composer by passing the matching composer token through `self.mel_conditioner`."""
    # select composer randomly if not already given
composer_to_feature_token = self.config.composer_to_feature_token
if composer is None:
composer = np.random.choice(list(composer_to_feature_token.keys()), size=1)[0]
elif composer not in composer_to_feature_token.keys():
raise ValueError(
f"Composer not found in list, Please choose from {list(composer_to_feature_token.keys())}"
)
composer_value = composer_to_feature_token[composer]
composer_value = torch.tensor(composer_value, device=self.device)
composer_value = composer_value.repeat(input_features.shape[0])
return self.mel_conditioner(input_features, composer_value)

@add_start_docstrings_to_model_forward(Pop2Piano_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
Expand Down Expand Up @@ -1201,7 +1222,7 @@ def generate(
inputs_embeds=None,
composer="composer1",
n_bars: int = 2,
max_length: int = None,
max_length: int = 256,
inputs: Optional[torch.Tensor] = None,
generation_config=None,
**kwargs,
@@ -1282,27 +1303,10 @@
if generation_config is None:
generation_config = self.generation_config

# select composer randomly if not already given
composer_to_feature_token = self.config.composer_to_feature_token
if composer is None:
composer = np.random.choice(list(composer_to_feature_token.keys()), size=1)[0]
elif composer not in composer_to_feature_token.keys():
raise ValueError(
f"Composer not found in list, Please choose from {list(composer_to_feature_token.keys())}"
)

n_bars = self.config.dataset_n_bars if n_bars is None else n_bars
max_length = (
self.config.dataset_target_length * max(1, (n_bars // self.config.dataset_n_bars))
if max_length is None
else max_length
inputs_embeds = self.get_mel_conditioner_outputs(
input_features=input_features["input_features"], composer=composer
)

composer_value = composer_to_feature_token[composer]
composer_value = torch.tensor(composer_value, device=self.device)
composer_value = composer_value.repeat(input_features["input_features"].shape[0])
inputs_embeds = self.mel_conditioner(input_features["input_features"], composer_value)

return super().generate(
inputs,
generation_config,
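
To make the composer conditioning concrete, here is a toy sketch of the concat-embedding-to-mel idea that `get_mel_conditioner_outputs` relies on (the class name, vocabulary size, embedding size and shapes are illustrative, not the model's real dimensions):

```python
import torch
from torch import nn


class ToyConcatEmbeddingToMel(nn.Module):
    """Embeds a composer token id and prepends it to the mel feature sequence."""

    def __init__(self, num_composer_tokens=21, embedding_size=512):
        super().__init__()
        self.embedding = nn.Embedding(num_composer_tokens, embedding_size)

    def forward(self, feature, index_value):
        # feature: (batch, time, embedding_size) mel features
        # index_value: (batch,) composer token ids
        composer_embedding = self.embedding(index_value).unsqueeze(1)  # (batch, 1, embedding_size)
        return torch.cat([composer_embedding, feature], dim=1)


features = torch.randn(1, 100, 512)      # fake mel features
composer_id = torch.tensor([3])          # fake composer token id
conditioned = ToyConcatEmbeddingToMel()(features, composer_id)
print(conditioned.shape)                 # torch.Size([1, 101, 512])
```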