diff --git a/buzz/transcriber/recording_transcriber.py b/buzz/transcriber/recording_transcriber.py
index 962ed64d3..64a9c400d 100644
--- a/buzz/transcriber/recording_transcriber.py
+++ b/buzz/transcriber/recording_transcriber.py
@@ -81,7 +81,7 @@ def start(self):
             logging.debug("Will use whisper API on %s, %s",
                           custom_openai_base_url, self.whisper_api_model)
         else:  # ModelType.HUGGING_FACE
-            model = transformers_whisper.load_model(model_path)
+            model = TransformersWhisper(model_path)
 
         initial_prompt = self.transcription_options.initial_prompt
 
diff --git a/buzz/transcriber/whisper_file_transcriber.py b/buzz/transcriber/whisper_file_transcriber.py
index abb34d862..0a8d68c21 100644
--- a/buzz/transcriber/whisper_file_transcriber.py
+++ b/buzz/transcriber/whisper_file_transcriber.py
@@ -12,9 +12,9 @@
 import tqdm
 from PyQt6.QtCore import QObject
 
-from buzz import transformers_whisper
 from buzz.conn import pipe_stderr
 from buzz.model_loader import ModelType
+from buzz.transformers_whisper import TransformersWhisper
 from buzz.transcriber.file_transcriber import FileTranscriber
 from buzz.transcriber.transcriber import FileTranscriptionTask, Segment
 
@@ -87,7 +87,10 @@ def transcribe_whisper(
     ) -> None:
         with pipe_stderr(stderr_conn):
             if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
+                # TODO Find a way to emit real progress
+                sys.stderr.write("0%\n")
                 segments = cls.transcribe_hugging_face(task)
+                sys.stderr.write("100%\n")
             elif (
                 task.transcription_options.model.model_type == ModelType.FASTER_WHISPER
             ):
@@ -105,7 +108,7 @@ def transcribe_whisper(
 
     @classmethod
     def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
-        model = transformers_whisper.load_model(task.model_path)
+        model = TransformersWhisper(task.model_path)
         language = (
             task.transcription_options.language
             if task.transcription_options.language is not None
@@ -115,7 +118,6 @@ def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
             audio=task.file_path,
             language=language,
             task=task.transcription_options.task.value,
-            verbose=False,
         )
         return [
             Segment(
diff --git a/buzz/transformers_whisper.py b/buzz/transformers_whisper.py
index f26da17b2..270c84405 100644
--- a/buzz/transformers_whisper.py
+++ b/buzz/transformers_whisper.py
@@ -1,98 +1,61 @@
-import sys
-import logging
 from typing import Optional, Union
 
 import numpy as np
-from tqdm import tqdm
-
-import whisper
 import torch
-from transformers import WhisperProcessor, WhisperForConditionalGeneration
-
-def cuda_is_viable(min_vram_gb=10):
-    if not torch.cuda.is_available():
-        return False
-
-    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # Convert bytes to GB
-    if total_memory < min_vram_gb:
-        return False
-
-    return True
-
-
-def load_model(model_name_or_path: str):
-    processor = WhisperProcessor.from_pretrained(model_name_or_path)
-    model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)
-
-    if cuda_is_viable():
-        logging.debug("CUDA is available and has enough VRAM, moving model to GPU.")
-        model.to("cuda")
-
-    return TransformersWhisper(processor, model)
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
 
 
 class TransformersWhisper:
     def __init__(
-        self, processor: WhisperProcessor, model: WhisperForConditionalGeneration
+        self, model_id: str
     ):
-        self.processor = processor
-        self.model = model
-        self.SAMPLE_RATE = whisper.audio.SAMPLE_RATE
-        self.N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES
+        self.model_id = model_id
 
-    # Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and
-    # timestamps are available. See: https://github.com/huggingface/transformers/issues/19887,
-    # https://github.com/huggingface/transformers/pull/20620.
     def transcribe(
         self,
         audio: Union[str, np.ndarray],
         language: str,
         task: str,
-        verbose: Optional[bool] = None,
     ):
-        if isinstance(audio, str):
-            audio = whisper.load_audio(audio, sr=self.SAMPLE_RATE)
-
-        self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
-            task=task, language=language
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            self.model_id, torch_dtype=torch_dtype, use_safetensors=True
         )
-
-        segments = []
-        all_predicted_ids = []
+        model.generation_config.language = language
+        model.to(device)
 
-        num_samples = audio.size
-        seek = 0
-        with tqdm(
-            total=num_samples, unit="samples", disable=verbose is not False
-        ) as progress_bar:
-            while seek < num_samples:
-                chunk = audio[seek : seek + self.N_SAMPLES_IN_CHUNK]
-                input_features = self.processor(
-                    chunk, return_tensors="pt", sampling_rate=self.SAMPLE_RATE
-                ).input_features.to(self.model.device)
-                predicted_ids = self.model.generate(input_features)
-                all_predicted_ids.extend(predicted_ids)
-                text: str = self.processor.batch_decode(
-                    predicted_ids, skip_special_tokens=True
-                )[0]
-                if text.strip() != "":
-                    segments.append(
-                        {
-                            "start": seek / self.SAMPLE_RATE,
-                            "end": min(seek + self.N_SAMPLES_IN_CHUNK, num_samples)
-                            / self.SAMPLE_RATE,
-                            "text": text,
-                        }
-                    )
+        processor = AutoProcessor.from_pretrained(self.model_id)
+
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            generate_kwargs={"language": language, "task": task},
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            chunk_length_s=30,
+            torch_dtype=torch_dtype,
+            device=device,
+        )
 
-                progress_bar.update(
-                    min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek
-                )
-                seek += self.N_SAMPLES_IN_CHUNK
+        transcript = pipe(audio, return_timestamps=True)
+
+        segments = []
+        for chunk in transcript['chunks']:
+            start, end = chunk['timestamp']
+            text = chunk['text']
+            segments.append({
+                "start": start,
+                "end": end,
+                "text": text,
+                "translation": ""
+            })
 
         return {
-            "text": self.processor.batch_decode(
-                all_predicted_ids, skip_special_tokens=True
-            )[0],
+            "text": transcript['text'],
             "segments": segments,
         }
+
diff --git a/tests/transformers_whisper_test.py b/tests/transformers_whisper_test.py
index 932267556..28da84b7c 100644
--- a/tests/transformers_whisper_test.py
+++ b/tests/transformers_whisper_test.py
@@ -1,7 +1,7 @@
 import platform
 
 import pytest
 
-from buzz.transformers_whisper import load_model
+from buzz.transformers_whisper import TransformersWhisper
 from tests.audio import test_audio_path
 
@@ -11,7 +11,7 @@
 )
 class TestTransformersWhisper:
     def test_should_transcribe(self):
-        model = load_model("openai/whisper-tiny")
+        model = TransformersWhisper("openai/whisper-tiny")
         result = model.transcribe(
             audio=test_audio_path, language="fr", task="transcribe"
         )
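
Usage note (not part of the patch): after this change, callers construct TransformersWhisper directly with a Hugging Face model ID, and the model and processor are loaded inside transcribe() via the transformers ASR pipeline. A minimal sketch of the new call pattern, mirroring tests/transformers_whisper_test.py; the audio path is a placeholder and only the keys shown are guaranteed by the diff:

# Hypothetical usage of the reworked TransformersWhisper API; "speech.wav"
# is an illustrative path, not something referenced by this patch.
from buzz.transformers_whisper import TransformersWhisper

model = TransformersWhisper("openai/whisper-tiny")
result = model.transcribe(
    audio="speech.wav",  # a file path or a float32 NumPy array, per the Union type hint
    language="fr",
    task="transcribe",
)
print(result["text"])
for segment in result["segments"]:
    # Each segment carries start/end timestamps (seconds) and the chunk text.
    print(segment["start"], segment["end"], segment["text"])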