From b9357cd3718bbb1a27cac2f20b891e95891284d1 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Thu, 23 Mar 2023 16:16:05 +0530 Subject: [PATCH] tokenization tests added --- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/pop2piano.mdx | 33 +- src/transformers/__init__.py | 3 +- src/transformers/models/pop2piano/__init__.py | 2 + .../pop2piano/feature_extraction_pop2piano.py | 18 - .../pop2piano/tokenization_pop2piano.py | 423 ++++++++++++++++++ src/transformers/utils/dummy_music_objects.py | 7 + src/transformers/utils/import_utils.py | 4 +- .../test_feature_extraction_pop2piano.py | 9 +- .../pop2piano/test_modeling_pop2piano.py | 11 +- .../pop2piano/test_tokenization_pop2piano.py | 127 ++++++ 11 files changed, 604 insertions(+), 35 deletions(-) create mode 100644 src/transformers/models/pop2piano/tokenization_pop2piano.py create mode 100644 tests/models/pop2piano/test_tokenization_pop2piano.py diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 62bea151104743..60aa39c650b419 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -369,7 +369,7 @@ Flax), PyTorch, and/or TensorFlow. | Pix2Struct | ❌ | ❌ | ✅ | ❌ | ❌ | | PLBart | ✅ | ❌ | ✅ | ❌ | ❌ | | PoolFormer | ❌ | ❌ | ✅ | ❌ | ❌ | -| Pop2Piano | ❌ | ❌ | ✅ | ❌ | ❌ | +| Pop2Piano | ✅ | ❌ | ✅ | ❌ | ❌ | | ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | | QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ | | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/pop2piano.mdx b/docs/source/en/model_doc/pop2piano.mdx index 61722c4475b956..e1e1b2f079f597 100644 --- a/docs/source/en/model_doc/pop2piano.mdx +++ b/docs/source/en/model_doc/pop2piano.mdx @@ -37,12 +37,38 @@ Tips: 1. Pop2Piano is an Encoder-Decoder based model like T5. 2. Pop2Piano can be used to generate midi-audio files for a given audio sequence. This HuggingFace implementation allows to save midi_output as well as stereo-mix output of the audio sequence. -3. Choosing different composers in Pop2PianoForConditionalGeneration.generate can lead to variety of different results. +3. Choosing different composers in `Pop2PianoForConditionalGeneration.generate()` can lead to variety of different results. 4. Please note that HuggingFace implementation of Pop2Piano(both Pop2PianoForConditionalGeneration and Pop2PianoFeatureExtractor) can only work with one raw_audio sequence at a time. So if you want to process multiple files, please feed them one by one. This model was contributed by [Susnato Dhar](https://huggingface.co/susnato). The original code can be found [here](https://github.com/sweetcocoa/pop2piano). 
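+To run the example below you need the optional music dependencies (`pretty_midi`, `soundfile`, `essentia`, `librosa`, `scipy` and `torchaudio`), which can be installed with
+`pip install pretty-midi==0.2.9 soundfile essentia==2.1b6.dev609 librosa scipy torchaudio`.
+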
+Example:
+```python
+import librosa
+from transformers import Pop2PianoFeatureExtractor, Pop2PianoForConditionalGeneration, Pop2PianoTokenizer
+
+raw_audio, sr = librosa.load("audio.mp3", sr=44100)
+model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
+feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
+tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
+
+model.eval()
+
+feature_extractor_outputs = feature_extractor(raw_audio=raw_audio, audio_sr=sr, return_tensors="pt")
+model_outputs = model.generate(feature_extractor_outputs, composer="composer1")
+
+opt_postprocess = tokenizer(relative_tokens=model_outputs,
+                            beatsteps=feature_extractor_outputs["beatsteps"],
+                            ext_beatstep=feature_extractor_outputs["ext_beatstep"],
+                            raw_audio=raw_audio,
+                            sampling_rate=sr,
+                            save_path="./Music/Outputs/",
+                            audio_file_name="filename",
+                            save_midi=True
+                            )
+```
+
 ## Pop2PianoConfig
 
@@ -59,3 +85,8 @@ The original code can be found [here](https://github.com/sweetcocoa/pop2piano).
     - forward
     - generate
 
+## Pop2PianoTokenizer
+
+[[autodoc]] Pop2PianoTokenizer
+    - __call__
+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 557591a4e24d93..beee0e539dc671 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -3443,6 +3443,7 @@
     ]
 else:
     _import_structure["models.pop2piano"].append("Pop2PianoFeatureExtractor")
+    _import_structure["models.pop2piano"].append("Pop2PianoTokenizer")
 
 
 # FLAX-backed objects
@@ -6554,7 +6555,7 @@
     except OptionalDependencyNotAvailable:
         from .utils.dummy_music_objects import *
     else:
-        from .models.pop2piano import Pop2PianoFeatureExtractor
+        from .models.pop2piano import Pop2PianoFeatureExtractor, Pop2PianoTokenizer
 
 try:
     if not is_flax_available():
diff --git a/src/transformers/models/pop2piano/__init__.py b/src/transformers/models/pop2piano/__init__.py
index 22cc45865ed410..90703043b173fd 100644
--- a/src/transformers/models/pop2piano/__init__.py
+++ b/src/transformers/models/pop2piano/__init__.py
@@ -57,6 +57,7 @@
     pass
 else:
     _import_structure["feature_extraction_pop2piano"] = ["Pop2PianoFeatureExtractor"]
+    _import_structure["tokenization_pop2piano"] = ["Pop2PianoTokenizer"]
 
 
 if TYPE_CHECKING:
@@ -89,6 +90,7 @@
         pass
     else:
         from .feature_extraction_pop2piano import Pop2PianoFeatureExtractor
+        from .tokenization_pop2piano import Pop2PianoTokenizer
 
 else:
     import sys
diff --git a/src/transformers/models/pop2piano/feature_extraction_pop2piano.py b/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
index 2690ce1a2a1b3e..4353038e3c590a 100644
--- a/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
+++ b/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
@@ -51,16 +51,6 @@ class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
             Whether to preprocess for `LogMelSpectrogram` or not. For the current implementation this must be `True`.
         padding_value (`int`, *optional*, defaults to 0):
             Padding value used to pad the audio. Should correspond to silences.
-        vocab_size_special (`int`, *optional*, defaults to 4):
-            Number of special values.
-        vocab_size_note (`int`, *optional*, defaults to 128):
-            This represents the number of Note Values. Note values indicate a pitch event for one of the MIDI pitches.
-            But only the 88 pitches corresponding to piano keys are actually used.
-        vocab_size_velocity (`int`, *optional*, defaults to 2):
-            Number of Velocity tokens.
-        vocab_size_time (`int`, *optional*, defaults to 100):
-            This represents the number of Beat Shifts. Beat Shift [100 values] Indicates the relative time shift within
-            the segment quantized into 8th-note beats(half-beats).
         n_fft (`int`, *optional*, defaults to 4096):
             Size of Fast Fourier Transform, creates n_fft // 2 + 1 bins.
         hop_length (`int`, *optional*, defaults to 1024):
@@ -78,10 +68,6 @@ def __init__(
         sampling_rate: int = 22050,
         use_mel: int = True,
         padding_value: int = 0,
-        vocab_size_special: int = 4,
-        vocab_size_note: int = 128,
-        vocab_size_velocity: int = 2,
-        vocab_size_time: int = 100,
         n_fft: int = 4096,
         hop_length: int = 1024,
         f_min: float = 10.0,
@@ -99,10 +85,6 @@ def __init__(
         self.sampling_rate = sampling_rate
         self.use_mel = use_mel
         self.padding_value = padding_value
-        self.vocab_size_special = vocab_size_special
-        self.vocab_size_note = vocab_size_note
-        self.vocab_size_velocity = vocab_size_velocity
-        self.vocab_size_time = vocab_size_time
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.f_min = f_min
diff --git a/src/transformers/models/pop2piano/tokenization_pop2piano.py b/src/transformers/models/pop2piano/tokenization_pop2piano.py
new file mode 100644
index 00000000000000..36f564c7e5bbe4
--- /dev/null
+++ b/src/transformers/models/pop2piano/tokenization_pop2piano.py
@@ -0,0 +1,423 @@
+# coding=utf-8
+# Copyright 2023 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Pop2Piano."""
+
+import json
+import os
+from typing import List, Optional, Tuple, Union
+
+import librosa
+import numpy as np
+import pretty_midi
+import soundfile as sf
+import torch
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "tokenizer_file": "tokenizer.json",
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "tokenizer_file": {
+        "susnato/pop2piano_dev": "https://huggingface.co/susnato/pop2piano_dev/resolve/main/tokenizer.json",
+    },
+}
+
+
+class Pop2PianoTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a Pop2Piano tokenizer. This tokenizer does not require training. The `Pop2PianoTokenizer.__call__()`
+    method can only process one sequence at a time.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods. Note that most of the usual text tokenization
+    methods do not apply here: this tokenizer only converts the relative token ids generated by the model into MIDI
+    notes.
+
+    Args:
+        tokenizer_file (`str`):
+            Path to the tokenizer file which contains the token type ids and special values such as `TOKEN_SPECIAL`,
+            `DEFAULT_VELOCITY`, `PAD` and `EOS`.
+        vocab_size_special (`int`, *optional*, defaults to 4):
+            Number of special values.
+        vocab_size_note (`int`, *optional*, defaults to 128):
+            This represents the number of Note Values. Note values indicate a pitch event for one of the MIDI pitches.
+            But only the 88 pitches corresponding to piano keys are actually used.
+ vocab_size_velocity (`int`, *optional*, defaults to 2): + Number of Velocity tokens. + vocab_size_time (`int`, *optional*, defaults to 100): + This represents the number of Beat Shifts. Beat Shift [100 values] Indicates the relative time shift within + the segment quantized into 8th-note beats(half-beats). + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + + def __init__( + self, + tokenizer_file, + vocab_size_special: int = 4, + vocab_size_note: int = 128, + vocab_size_velocity: int = 2, + vocab_size_time: int = 100, + n_bars: int = 2, + **kwargs, + ): + super().__init__( + **kwargs, + ) + with open(tokenizer_file, "rb") as t_file: + self.tokenizer_config = json.load(t_file) + + self.TIE = self.tokenizer_config["TIE"] + self.EOS = self.tokenizer_config["EOS"] + self.PAD = self.tokenizer_config["PAD"] + self.TOKEN_NOTE = self.tokenizer_config["TOKEN_NOTE"] + self.TOKEN_TIME = self.tokenizer_config["TOKEN_TIME"] + self.TOKEN_SPECIAL = self.tokenizer_config["TOKEN_SPECIAL"] + self.TOKEN_VELOCITY = self.tokenizer_config["TOKEN_VELOCITY"] + self.DEFAULT_VELOCITY = self.tokenizer_config["DEFAULT_VELOCITY"] + + self.vocab_size_special = vocab_size_special + self.vocab_size_note = vocab_size_note + self.vocab_size_velocity = vocab_size_velocity + self.vocab_size_time = vocab_size_time + self.n_bars = n_bars + + @property + def vocab_size(self): + return self.vocab_size_special + self.vocab_size_note + self.vocab_size_time + self.vocab_size_velocity + + def get_vocab(self): + return self.tokenizer_config + + def _convert_id_to_token(self, token, time_idx_offset): + """Decodes the tokens generated by the transformer""" + + if token >= (self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity): + type, value = self.TOKEN_TIME, ( + (token - (self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity)) + time_idx_offset + ) + elif token >= (self.vocab_size_special + self.vocab_size_note): + type, value = self.TOKEN_VELOCITY, (token - (self.vocab_size_special + self.vocab_size_note)) + value = int(value) + elif token >= self.vocab_size_special: + type, value = self.TOKEN_NOTE, (token - self.vocab_size_special) + value = int(value) + else: + type, value = self.TOKEN_SPECIAL, token + value = int(value) + + return [type, value] + + def _convert_token_to_id(self, token, token_type): + if token_type == self.TOKEN_TIME: + return self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity + token + elif token_type == self.TOKEN_VELOCITY: + return self.vocab_size_special + self.vocab_size_note + token + elif token_type == self.TOKEN_NOTE: + return self.vocab_size_special + token + elif token_type == self.TOKEN_SPECIAL: + return token + else: + return -1 + + def relative_batch_tokens_to_midi( + self, + tokens, + beatstep, + beat_offset_idx=None, + bars_per_batch=None, + cutoff_time_idx=None, + ): + """Converts tokens to midi""" + + beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx + notes = None + bars_per_batch = 2 if bars_per_batch is None else bars_per_batch + + N = len(tokens) + for n in range(N): + _tokens = tokens[n] + _start_idx = beat_offset_idx + n * bars_per_batch * 4 + _cutoff_time_idx = cutoff_time_idx + _start_idx + _notes = self.relative_tokens_to_notes( + _tokens, + start_idx=_start_idx, + cutoff_time_idx=_cutoff_time_idx, + ) + + if len(_notes) == 0: + pass + elif notes is None: + notes = _notes + else: + notes = np.concatenate((notes, _notes), axis=0) + + if notes is None: 
+ notes = [] + midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx]) + return midi, notes + + def relative_tokens_to_notes(self, tokens, start_idx, cutoff_time_idx=None): + # decoding If the first token is an arranger + if tokens[0] >= ( + self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity + self.vocab_size_time + ): + tokens = tokens[1:] + + words = [self._convert_id_to_token(token, time_idx_offset=0) for token in tokens] + + if hasattr(start_idx, "item"): + """if numpy or torch tensor""" + start_idx = start_idx.item() + + current_idx = start_idx + current_velocity = 0 + note_onsets_ready = [None for i in range(self.vocab_size_note + 1)] + notes = [] + for type, number in words: + if type == self.TOKEN_SPECIAL: + if number == self.EOS: + break + elif type == self.TOKEN_TIME: + current_idx += number + if cutoff_time_idx is not None: + current_idx = min(current_idx, cutoff_time_idx) + + elif type == self.TOKEN_VELOCITY: + current_velocity = number + elif type == self.TOKEN_NOTE: + pitch = number + if current_velocity == 0: + # note_offset + if note_onsets_ready[pitch] is None: + # offset without onset + pass + else: + onset_idx = note_onsets_ready[pitch] + if onset_idx >= current_idx: + # No time shift after previous note_on + pass + else: + offset_idx = current_idx + notes.append([onset_idx, offset_idx, pitch, self.DEFAULT_VELOCITY]) + note_onsets_ready[pitch] = None + else: + # note_on + if note_onsets_ready[pitch] is None: + note_onsets_ready[pitch] = current_idx + else: + # note-on already exists + onset_idx = note_onsets_ready[pitch] + if onset_idx >= current_idx: + # No time shift after previous note_on + pass + else: + offset_idx = current_idx + notes.append([onset_idx, offset_idx, pitch, self.DEFAULT_VELOCITY]) + note_onsets_ready[pitch] = current_idx + else: + raise ValueError + + for pitch, note_on in enumerate(note_onsets_ready): + # force offset if no offset for each pitch + if note_on is not None: + if cutoff_time_idx is None: + cutoff = note_on + 1 + else: + cutoff = max(cutoff_time_idx, note_on + 1) + + offset_idx = max(current_idx, cutoff) + notes.append([note_on, offset_idx, pitch, self.DEFAULT_VELOCITY]) + + if len(notes) == 0: + return [] + else: + notes = np.array(notes) + note_order = notes[:, 0] * 128 + notes[:, 1] + notes = notes[note_order.argsort()] + return notes + + def notes_to_midi(self, notes, beatstep, offset_sec=None): + """Converts notes to midi""" + + new_pm = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=120.0) + new_inst = pretty_midi.Instrument(program=0) + new_notes = [] + if offset_sec is None: + offset_sec = 0.0 + + for onset_idx, offset_idx, pitch, velocity in notes: + new_note = pretty_midi.Note( + velocity=velocity, + pitch=pitch, + start=beatstep[onset_idx] - offset_sec, + end=beatstep[offset_idx] - offset_sec, + ) + new_notes.append(new_note) + new_inst.notes = new_notes + new_pm.instruments.append(new_inst) + new_pm.remove_invalid_notes() + return new_pm + + def get_stereo(self, pop_y, midi_y, pop_scale=0.99): + """Generates stereo audio using `pop audio(`pop_y`)` and `generated midi audio(`midi_y`)`""" + + if len(pop_y) > len(midi_y): + midi_y = np.pad(midi_y, (0, len(pop_y) - len(midi_y))) + elif len(pop_y) < len(midi_y): + pop_y = np.pad(pop_y, (0, -len(pop_y) + len(midi_y))) + stereo = np.stack((midi_y, pop_y * pop_scale)) + return stereo + + def _to_np(self, tensor): + """Converts pytorch tensor to np.ndarray.""" + if isinstance(tensor, np.ndarray): + return tensor + elif 
isinstance(tensor, torch.Tensor):
+            return tensor.cpu().numpy()
+        else:
+            raise ValueError("dtype not understood! Please use either torch.Tensor or np.ndarray")
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        """
+        Saves the tokenizer's vocabulary dictionary to the provided `save_directory`.
+
+        Args:
+            save_directory (`str`):
+                A path to the directory where the vocabulary will be saved. It will be created if it doesn't exist.
+            filename_prefix (`Optional[str]`, *optional*):
+                A prefix to add to the names of the files saved by the tokenizer.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+
+        tokenizer_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["tokenizer_file"]
+        )
+        with open(tokenizer_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.tokenizer_config))
+
+        return (tokenizer_file,)
+
+    def __call__(
+        self,
+        relative_tokens: Union[np.ndarray, torch.Tensor],
+        beatsteps: Union[np.ndarray, torch.Tensor],
+        ext_beatstep: Union[np.ndarray, torch.Tensor],
+        raw_audio: Union[np.ndarray, List[float], List[np.ndarray]],
+        sampling_rate: int,
+        mix_sampling_rate=None,
+        save_path: str = None,
+        audio_file_name: str = None,
+        save_midi: bool = False,
+        save_mix: bool = False,
+        click_amp: float = 0.2,
+        stereo_amp: float = 0.5,
+        add_click: bool = False,
+    ):
+        r"""
+        This is the `__call__` method for `Pop2PianoTokenizer`. It converts the relative tokens generated by the
+        transformer into a `pretty_midi.PrettyMIDI` object and can also save the `"generated midi audio"` and the
+        `"stereo-mix"`.
+
+        Args:
+            relative_tokens ([`~utils.TensorType`]):
+                Output of `Pop2PianoForConditionalGeneration` model.
+            beatsteps ([`~utils.TensorType`]):
+                beatsteps returned by `Pop2PianoFeatureExtractor.__call__`
+            ext_beatstep ([`~utils.TensorType`]):
+                ext_beatstep returned by `Pop2PianoFeatureExtractor.__call__`
+            raw_audio (`np.ndarray`, `List`):
+                Denotes the raw_audio.
+            sampling_rate (`int`):
+                Denotes the Sampling Rate of `raw_audio`.
+            mix_sampling_rate (`int`, *optional*):
+                Denotes the Sampling Rate for `stereo-mix`.
+            audio_file_name (`str`, *optional*):
+                Name of the file to be saved.
+            save_path (`str`, *optional*):
+                Path where the `stereo-mix` and `midi-audio` are to be saved.
+            save_midi (`bool`, *optional*):
+                Whether to save `midi-audio` or not.
+            save_mix (`bool`, *optional*):
+                Whether to save `stereo-mix` or not.
+            add_click (`bool`, *optional*, defaults to `False`):
+                Constructs a `"click track"`.
+            click_amp (`float`, *optional*, defaults to 0.2):
+                Amplitude for `"click track"`.
+            stereo_amp (`float`, *optional*, defaults to 0.5):
+                Scale applied to the original audio when building the `stereo-mix`.
+        Returns:
+            `pretty_midi.pretty_midi.PrettyMIDI` : returns pretty_midi object.
+        """
+
+        relative_tokens = self._to_np(relative_tokens)
+        beatsteps = self._to_np(beatsteps)
+        ext_beatstep = self._to_np(ext_beatstep)
+
+        if (save_midi or save_mix) and save_path is None:
+            raise ValueError("If you want to save any mix or midi file then you must define save_path.")
+
+        if save_path and (not save_midi and not save_mix):
+            raise ValueError(
+                "You are setting save_path but not saving anything, use save_midi=True to "
+                "save the midi file and use save_mix=True to save the mix file or do both!"
+ ) + + mix_sampling_rate = sampling_rate if mix_sampling_rate is None else mix_sampling_rate + + if save_path is not None: + if os.path.isdir(save_path): + midi_path = os.path.join(save_path, f"midi_output_{audio_file_name}.mid") + mix_path = os.path.join(save_path, f"mix_output_{audio_file_name}.wav") + else: + raise ValueError(f"Is {save_path} a directory?") + + pm, notes = self.relative_batch_tokens_to_midi( + tokens=relative_tokens, + beatstep=ext_beatstep, + bars_per_batch=self.n_bars, + cutoff_time_idx=(self.n_bars + 1) * 4, + ) + for n in pm.instruments[0].notes: + n.start += beatsteps[0] + n.end += beatsteps[0] + + if save_midi: + pm.write(midi_path) + print(f"midi file saved at {midi_path}!") + + if save_mix: + if mix_sampling_rate != sampling_rate: + raw_audio = librosa.core.resample(raw_audio, orig_sr=sampling_rate, target_sr=mix_sampling_rate) + sampling_rate = mix_sampling_rate + if add_click: + clicks = librosa.clicks(times=beatsteps, sr=sampling_rate, length=len(raw_audio)) * click_amp + raw_audio = raw_audio + clicks + pm_raw_audio = pm.fluidsynth(sampling_rate) + stereo = self.get_stereo(raw_audio, pm_raw_audio, pop_scale=stereo_amp) + + sf.write( + file=mix_path, + data=stereo.T, + samplerate=sampling_rate, + format="wav", + ) + print(f"stereo-mix file saved at {mix_path}!") + + return pm diff --git a/src/transformers/utils/dummy_music_objects.py b/src/transformers/utils/dummy_music_objects.py index bdf88958551d12..89052be47c1d32 100644 --- a/src/transformers/utils/dummy_music_objects.py +++ b/src/transformers/utils/dummy_music_objects.py @@ -7,3 +7,10 @@ class Pop2PianoFeatureExtractor(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["music"]) + + +class Pop2PianoTokenizer(metaclass=DummyObject): + _backends = ["music"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["music"]) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 196fc619861ad8..24f2057f26755b 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1018,10 +1018,10 @@ def is_cython_available(): # docstyle-ignore MUSIC_IMPORT_ERROR = """ -{0} requires these libraries - pretty_midi, soundfile, essentia, librosa, scipy, torchaudio. But at least +{0} requires these libraries - pretty_midi, soundfile, essentia, librosa, scipy, torchaudio. But at least one of them were not found in your environment. You can install them with pip: `pip install pretty-midi==0.2.9 soundfile essentia==2.1b6.dev609 librosa scipy torchaudio` -Please note that you may need to restart your runtime after installation. +Please note that you may need to restart your runtime after installation. """ DECORD_IMPORT_ERROR = """ diff --git a/tests/models/pop2piano/test_feature_extraction_pop2piano.py b/tests/models/pop2piano/test_feature_extraction_pop2piano.py index c45496aefa4262..259722765d211f 100644 --- a/tests/models/pop2piano/test_feature_extraction_pop2piano.py +++ b/tests/models/pop2piano/test_feature_extraction_pop2piano.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import os import tempfile import unittest @@ -21,9 +20,9 @@ import numpy as np from datasets import load_dataset -from transformers import is_speech_available from transformers.testing_utils import ( check_json_file_has_correct_format, + is_torchaudio_available, require_essentia, require_librosa, require_pretty_midi, @@ -43,7 +42,7 @@ requirements = ( - is_speech_available() + is_torchaudio_available() and is_torch_available() and is_essentia_available() and is_scipy_available() @@ -52,10 +51,10 @@ ) if requirements: - from transformers import Pop2PianoFeatureExtractor -if is_torch_available(): import torch + from transformers import Pop2PianoFeatureExtractor + @require_torch @require_essentia diff --git a/tests/models/pop2piano/test_modeling_pop2piano.py b/tests/models/pop2piano/test_modeling_pop2piano.py index 31d326a57ec0d6..3ffc041fa8ad31 100644 --- a/tests/models/pop2piano/test_modeling_pop2piano.py +++ b/tests/models/pop2piano/test_modeling_pop2piano.py @@ -20,15 +20,15 @@ from transformers import Pop2PianoConfig from transformers.feature_extraction_utils import BatchFeature -from transformers.testing_utils import require_torch, require_torchaudio, slow, torch_device -from transformers.utils import is_torch_available, is_torchaudio_available +from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils import is_torch_available from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor -if is_torch_available() and is_torchaudio_available(): +if is_torch_available(): import torch from transformers import Pop2PianoForConditionalGeneration @@ -36,7 +36,6 @@ @require_torch -@require_torchaudio class Pop2PianoModelTester: def __init__( self, @@ -162,7 +161,7 @@ def check_prepare_lm_labels_via_shift_left( # make sure that lm_labels are correctly padded from the right lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id) - # add casaul pad token mask + # add causal pad token mask triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not() lm_labels.masked_fill_(triangular_mask, self.pad_token_id) decoder_input_ids = model._shift_right(lm_labels) @@ -504,7 +503,6 @@ def prepare_config_and_inputs_for_common(self): @require_torch -@require_torchaudio class Pop2PianoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (Pop2PianoForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = () @@ -635,7 +633,6 @@ def test_generate_with_past_key_values(self): @require_torch -@require_torchaudio class Pop2PianoModelIntegrationTests(unittest.TestCase): @slow def test_mel_conditioner_integration(self): diff --git a/tests/models/pop2piano/test_tokenization_pop2piano.py b/tests/models/pop2piano/test_tokenization_pop2piano.py new file mode 100644 index 00000000000000..adc669e93b3ec7 --- /dev/null +++ b/tests/models/pop2piano/test_tokenization_pop2piano.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+import unittest
+
+
+from datasets import load_dataset
+
+from transformers.testing_utils import (
+    is_essentia_available,
+    is_librosa_available,
+    is_pretty_midi_available,
+    is_scipy_available,
+    is_soundfile_availble,
+    is_torch_available,
+    is_torchaudio_available,
+    require_essentia,
+    require_librosa,
+    require_pretty_midi,
+    require_scipy,
+    require_soundfile,
+    require_torch,
+    slow,
+)
+
+
+if is_torch_available():
+    import torch
+
+    from transformers import Pop2PianoForConditionalGeneration
+
+requirements = (
+    is_torch_available()
+    and is_torchaudio_available()
+    and is_essentia_available()
+    and is_scipy_available()
+    and is_librosa_available()
+    and is_soundfile_availble()
+    and is_pretty_midi_available()
+)
+if requirements:
+    import pretty_midi
+    from transformers import Pop2PianoFeatureExtractor, Pop2PianoTokenizer
+
+
+@require_torch
+@require_librosa
+@require_soundfile
+@require_pretty_midi
+class Pop2PianoTokenizerTest(unittest.TestCase):
+    def test_call(self):
+        tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
+        input = {
+            "relative_tokens": torch.ones([120, 96]),
+            "beatsteps": torch.ones(
+                [
+                    955,
+                ]
+            ),
+            "ext_beatstep": torch.ones(
+                [
+                    958,
+                ]
+            ),
+            "raw_audio": torch.zeros(
+                [
+                    141301,
+                ]
+            ),
+            "sampling_rate": 44100,
+            "save_midi": False,
+            "save_mix": False,
+        }
+
+        output = tokenizer(**input)
+        self.assertTrue(isinstance(output, pretty_midi.pretty_midi.PrettyMIDI))
+
+    def _load_datasamples(self, num_samples):
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        # automatic decoding with librispeech
+        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
+
+        return [x["array"] for x in speech_samples], [x["sampling_rate"] for x in speech_samples]
+
+    @slow
+    @require_scipy
+    @require_essentia
+    @require_librosa
+    def test_midi_saving(self):
+        tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
+        feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
+        model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
+
+        input_speech, sampling_rate = self._load_datasamples(1)
+        fe_outputs = feature_extractor(input_speech, audio_sr=sampling_rate[0], return_tensors="pt")
+        model_outputs = model.generate(fe_outputs)
+        filename = "tmp-file"
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tokenizer(
+                relative_tokens=model_outputs,
+                beatsteps=fe_outputs["beatsteps"],
+                ext_beatstep=fe_outputs["ext_beatstep"],
+                raw_audio=input_speech,
+                sampling_rate=sampling_rate,
+                mix_sampling_rate=sampling_rate,
+                save_path=tmpdirname,
+                audio_file_name=filename,
+                save_midi=True,
+            )
+
+            # check that the midi file was saved
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, f"midi_output_{filename}.mid")))
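+
+    def test_token_id_round_trip(self):
+        # A minimal sketch of an extra check, assuming the vocabulary layout described in the tokenizer
+        # docstring: ids are laid out as [special | note | velocity | time].
+        tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
+
+        # a MIDI pitch (note) token: its id is offset by the number of special tokens
+        note_id = tokenizer._convert_token_to_id(60, tokenizer.TOKEN_NOTE)
+        self.assertEqual(note_id, tokenizer.vocab_size_special + 60)
+        self.assertEqual(tokenizer._convert_id_to_token(note_id, time_idx_offset=0), [tokenizer.TOKEN_NOTE, 60])
+
+        # a beat-shift (time) token: its id comes after the special, note and velocity ranges
+        time_id = tokenizer._convert_token_to_id(10, tokenizer.TOKEN_TIME)
+        self.assertEqual(
+            time_id,
+            tokenizer.vocab_size_special + tokenizer.vocab_size_note + tokenizer.vocab_size_velocity + 10,
+        )
+        self.assertEqual(tokenizer._convert_id_to_token(time_id, time_idx_offset=0), [tokenizer.TOKEN_TIME, 10])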