Skip to content

Commit

Permalink
support whisper long-form generation (#469)
Browse files Browse the repository at this point in the history
* fix long form asr accuracy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test input pad issue

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: chen, suyue <[email protected]>
  • Loading branch information
3 people authored Aug 14, 2024
1 parent 7aee7e4 commit daec680
Showing 1 changed file with 104 additions and 18 deletions.
122 changes: 104 additions & 18 deletions comps/asr/whisper/whisper_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class WhisperModel:
"""Convert audio to text."""

def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu"):
def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu", hpu_max_len=8192):
if device == "hpu":
# Explicitly link HPU with Torch
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
Expand All @@ -31,12 +31,11 @@ def __init__(self, model_name_or_path="openai/whisper-small", language="english"
self.model.eval()

self.language = language
self.hpu_max_len = hpu_max_len

if device == "hpu":
# do hpu graph warmup with a long enough input audio
# whisper has a receptive field of 30 seconds
# here we select a relatively long audio (~15 sec) to quickly warmup
self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")
self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_60s_audio.wav")
self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_30s_audio.wav")

def _audiosegment_to_librosawav(self, audiosegment):
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
Expand All @@ -59,11 +58,54 @@ def _warmup_whisper_hpu_graph(self, url):
print("[ASR] warmup...")
waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
waveform = self._audiosegment_to_librosawav(waveform)
# pylint: disable=E1101
inputs = self.processor.feature_extractor(
waveform, return_tensors="pt", sampling_rate=16_000
).input_features.to(self.device)
_ = self.model.generate(inputs, language="chinese")

try:
processed_inputs = self.processor(
waveform,
return_tensors="pt",
truncation=False,
padding="longest",
return_attention_mask=True,
sampling_rate=16000,
)
except RuntimeError as e:
if "Padding size should be less than" in str(e):
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
else:
raise e

if processed_inputs.input_features.shape[-1] < 3000:
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
else:
processed_inputs["input_features"] = torch.nn.functional.pad(
processed_inputs.input_features,
(0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
value=-1.5,
)
processed_inputs["attention_mask"] = torch.nn.functional.pad(
processed_inputs.attention_mask,
(0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
value=0,
)

_ = self.model.generate(
**(
processed_inputs.to(
self.device,
)
),
language=self.language,
)

def audio2text(self, audio_path):
"""Convert audio to text.
Expand All @@ -80,11 +122,52 @@ def audio2text(self, audio_path):
audio_dataset = Dataset.from_dict({"audio": [audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
waveform = audio_dataset[0]["audio"]["array"]

# pylint: disable=E1101
inputs = self.processor.feature_extractor(
waveform, return_tensors="pt", sampling_rate=16_000
).input_features.to(self.device)
predicted_ids = self.model.generate(inputs, language=self.language)
try:
processed_inputs = self.processor(
waveform,
return_tensors="pt",
truncation=False,
padding="longest",
return_attention_mask=True,
sampling_rate=16000,
)
except RuntimeError as e:
if "Padding size should be less than" in str(e):
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
else:
raise e
if processed_inputs.input_features.shape[-1] < 3000:
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
elif self.device == "hpu":
processed_inputs["input_features"] = torch.nn.functional.pad(
processed_inputs.input_features,
(0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
value=-1.5,
)
processed_inputs["attention_mask"] = torch.nn.functional.pad(
processed_inputs.attention_mask,
(0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
value=0,
)

predicted_ids = self.model.generate(
**(
processed_inputs.to(
self.device,
)
),
language=self.language,
)
# pylint: disable=E1101
result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
if self.language in ["chinese", "mandarin"]:
Expand All @@ -96,20 +179,23 @@ def audio2text(self, audio_path):


if __name__ == "__main__":
asr = WhisperModel(language="english")
asr = WhisperModel(model_name_or_path="openai/whisper-small", language="english", device="cpu")

# Test multilanguage asr
asr.language = "chinese"
urllib.request.urlretrieve(
"https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
"sample.wav",
)
asr.language = "chinese"
text = asr.audio2text("sample.wav")

asr.language = "english"
urllib.request.urlretrieve(
"https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
"sample.wav",
)
text = asr.audio2text("sample.wav")

os.remove("sample.wav")
for i in [5, 10, 30, 60]:
urllib.request.urlretrieve(f"https://github.com/Spycsh/assets/raw/main/ljspeech_{i}s_audio.wav", "sample.wav")
text = asr.audio2text("sample.wav")

0 comments on commit daec680

Please sign in to comment.