From 6ae0dd44f1a45f246f53df510c22d13e8c6b9923 Mon Sep 17 00:00:00 2001
From: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
Date: Wed, 14 Aug 2024 11:23:07 +0800
Subject: [PATCH] support whisper long-form generation (#469)

* fix long form asr accuracy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test input pad issue

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: chen, suyue
---
 comps/asr/whisper/whisper_model.py | 122 ++++++++++++++++++++++++-----
 1 file changed, 104 insertions(+), 18 deletions(-)

diff --git a/comps/asr/whisper/whisper_model.py b/comps/asr/whisper/whisper_model.py
index 0af9ebfcb3..c5f16e1121 100644
--- a/comps/asr/whisper/whisper_model.py
+++ b/comps/asr/whisper/whisper_model.py
@@ -16,7 +16,7 @@ class WhisperModel:
     """Convert audio to text."""
 
-    def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu"):
+    def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu", hpu_max_len=8192):
         if device == "hpu":
             # Explicitly link HPU with Torch
             from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -31,12 +31,11 @@ def __init__(self, model_name_or_path="openai/whisper-small", language="english"
         self.model.eval()
 
         self.language = language
+        self.hpu_max_len = hpu_max_len
 
         if device == "hpu":
-            # do hpu graph warmup with a long enough input audio
-            # whisper has a receptive field of 30 seconds
-            # here we select a relatively long audio (~15 sec) to quickly warmup
-            self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")
+            self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_60s_audio.wav")
+            self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_30s_audio.wav")
 
     def _audiosegment_to_librosawav(self, audiosegment):
         # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
@@ -59,11 +58,54 @@ def _warmup_whisper_hpu_graph(self, url):
         print("[ASR] warmup...")
         waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
         waveform = self._audiosegment_to_librosawav(waveform)
-        # pylint: disable=E1101
-        inputs = self.processor.feature_extractor(
-            waveform, return_tensors="pt", sampling_rate=16_000
-        ).input_features.to(self.device)
-        _ = self.model.generate(inputs, language="chinese")
+
+        try:
+            processed_inputs = self.processor(
+                waveform,
+                return_tensors="pt",
+                truncation=False,
+                padding="longest",
+                return_attention_mask=True,
+                sampling_rate=16000,
+            )
+        except RuntimeError as e:
+            if "Padding size should be less than" in str(e):
+                # short-form
+                processed_inputs = self.processor(
+                    waveform,
+                    return_tensors="pt",
+                    sampling_rate=16000,
+                )
+            else:
+                raise e
+
+        if processed_inputs.input_features.shape[-1] < 3000:
+            # short-form
+            processed_inputs = self.processor(
+                waveform,
+                return_tensors="pt",
+                sampling_rate=16000,
+            )
+        else:
+            processed_inputs["input_features"] = torch.nn.functional.pad(
+                processed_inputs.input_features,
+                (0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
+                value=-1.5,
+            )
+            processed_inputs["attention_mask"] = torch.nn.functional.pad(
+                processed_inputs.attention_mask,
+                (0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
+                value=0,
+            )
+
+        _ = self.model.generate(
+            **(
+                processed_inputs.to(
+                    self.device,
+                )
+            ),
+            language=self.language,
+        )
 
     def audio2text(self, audio_path):
         """Convert audio to text.
@@ -80,11 +122,52 @@ def audio2text(self, audio_path):
         audio_dataset = Dataset.from_dict({"audio": [audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
         waveform = audio_dataset[0]["audio"]["array"]
 
-        # pylint: disable=E1101
-        inputs = self.processor.feature_extractor(
-            waveform, return_tensors="pt", sampling_rate=16_000
-        ).input_features.to(self.device)
-        predicted_ids = self.model.generate(inputs, language=self.language)
+        try:
+            processed_inputs = self.processor(
+                waveform,
+                return_tensors="pt",
+                truncation=False,
+                padding="longest",
+                return_attention_mask=True,
+                sampling_rate=16000,
+            )
+        except RuntimeError as e:
+            if "Padding size should be less than" in str(e):
+                # short-form
+                processed_inputs = self.processor(
+                    waveform,
+                    return_tensors="pt",
+                    sampling_rate=16000,
+                )
+            else:
+                raise e
+
+        if processed_inputs.input_features.shape[-1] < 3000:
+            # short-form
+            processed_inputs = self.processor(
+                waveform,
+                return_tensors="pt",
+                sampling_rate=16000,
+            )
+        elif self.device == "hpu":
+            processed_inputs["input_features"] = torch.nn.functional.pad(
+                processed_inputs.input_features,
+                (0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
+                value=-1.5,
+            )
+            processed_inputs["attention_mask"] = torch.nn.functional.pad(
+                processed_inputs.attention_mask,
+                (0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
+                value=0,
+            )
+        predicted_ids = self.model.generate(
+            **(
+                processed_inputs.to(
+                    self.device,
+                )
+            ),
+            language=self.language,
+        )
         # pylint: disable=E1101
         result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
         if self.language in ["chinese", "mandarin"]:
@@ -96,20 +179,23 @@ def audio2text(self, audio_path):
 
 
 if __name__ == "__main__":
-    asr = WhisperModel(language="english")
+    asr = WhisperModel(model_name_or_path="openai/whisper-small", language="english", device="cpu")
 
     # Test multilanguage asr
+    asr.language = "chinese"
     urllib.request.urlretrieve(
         "https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
         "sample.wav",
     )
-    asr.language = "chinese"
     text = asr.audio2text("sample.wav")
 
+    asr.language = "english"
     urllib.request.urlretrieve(
         "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
         "sample.wav",
    )
     text = asr.audio2text("sample.wav")
 
-    os.remove("sample.wav")
+    for i in [5, 10, 30, 60]:
+        urllib.request.urlretrieve(f"https://github.com/Spycsh/assets/raw/main/ljspeech_{i}s_audio.wav", "sample.wav")
+        text = asr.audio2text("sample.wav")
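
Note (not part of the patch): the long-form path above relies on the stock transformers Whisper API rather than anything Gaudi-specific. Calling the processor with truncation=False, padding="longest", and return_attention_mask=True keeps the full mel spectrogram instead of cutting it at 30 seconds, and recent transformers releases then switch generate() to chunked long-form decoding once the features exceed 3000 frames. A minimal standalone sketch of that path, assuming librosa is installed and "long_audio.wav" is a hypothetical local recording longer than 30 seconds:

    # Standalone long-form transcription sketch; "long_audio.wav" and the use of
    # librosa are assumptions for illustration, not part of the committed file.
    import librosa
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    model.eval()

    # Load the clip as a 16 kHz mono float array, the format Whisper expects.
    waveform, _ = librosa.load("long_audio.wav", sr=16000)

    # Keep the full feature sequence and return an attention mask; features wider
    # than 3000 frames (~30 s) make generate() run long-form decoding.
    inputs = processor(
        waveform,
        return_tensors="pt",
        truncation=False,
        padding="longest",
        return_attention_mask=True,
        sampling_rate=16000,
    )

    predicted_ids = model.generate(**inputs, language="english")
    print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])

On HPU the patch additionally right-pads input_features to the fixed hpu_max_len (and the attention mask to hpu_max_len + 1), presumably so the graph shapes exercised by the 30 s/60 s warmup clips can be reused for arbitrary input lengths instead of triggering a recompilation per audio duration.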