Skip to content

Commit

Permalink
Switching to pipeline for HF whisper
Browse files Browse the repository at this point in the history
  • Loading branch information
raivisdejus committed Jun 24, 2024
1 parent cf340bc commit 6c06673
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 79 deletions.
2 changes: 1 addition & 1 deletion buzz/transcriber/recording_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def start(self):
logging.debug("Will use whisper API on %s, %s",
custom_openai_base_url, self.whisper_api_model)
else: # ModelType.HUGGING_FACE
model = transformers_whisper.load_model(model_path)
model = TransformersWhisper(model_path)

initial_prompt = self.transcription_options.initial_prompt

Expand Down
8 changes: 5 additions & 3 deletions buzz/transcriber/whisper_file_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import tqdm
from PyQt6.QtCore import QObject

from buzz import transformers_whisper
from buzz.conn import pipe_stderr
from buzz.model_loader import ModelType
from buzz.transformers_whisper import TransformersWhisper
from buzz.transcriber.file_transcriber import FileTranscriber
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment

Expand Down Expand Up @@ -87,7 +87,10 @@ def transcribe_whisper(
) -> None:
with pipe_stderr(stderr_conn):
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
# TODO Find a way to emmit real progress
sys.stderr.write("0%\n")
segments = cls.transcribe_hugging_face(task)
sys.stderr.write("100%\n")
elif (
task.transcription_options.model.model_type == ModelType.FASTER_WHISPER
):
Expand All @@ -105,7 +108,7 @@ def transcribe_whisper(

@classmethod
def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
model = transformers_whisper.load_model(task.model_path)
model = TransformersWhisper(task.model_path)
language = (
task.transcription_options.language
if task.transcription_options.language is not None
Expand All @@ -115,7 +118,6 @@ def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
audio=task.file_path,
language=language,
task=task.transcription_options.task.value,
verbose=False,
)
return [
Segment(
Expand Down
109 changes: 36 additions & 73 deletions buzz/transformers_whisper.py
Original file line number Diff line number Diff line change
@@ -1,98 +1,61 @@
import sys
import logging
from typing import Optional, Union

import numpy as np
from tqdm import tqdm

import whisper
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

def cuda_is_viable(min_vram_gb=10):
if not torch.cuda.is_available():
return False

total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 # Convert bytes to GB
if total_memory < min_vram_gb:
return False

return True


def load_model(model_name_or_path: str):
processor = WhisperProcessor.from_pretrained(model_name_or_path)
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)

if cuda_is_viable():
logging.debug("CUDA is available and has enough VRAM, moving model to GPU.")
model.to("cuda")

return TransformersWhisper(processor, model)
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


class TransformersWhisper:
def __init__(
self, processor: WhisperProcessor, model: WhisperForConditionalGeneration
self, model_id: str
):
self.processor = processor
self.model = model
self.SAMPLE_RATE = whisper.audio.SAMPLE_RATE
self.N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES
self.model_id = model_id

# Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and
# timestamps are available. See: https://github.com/huggingface/transformers/issues/19887,
# https://github.com/huggingface/transformers/pull/20620.
def transcribe(
self,
audio: Union[str, np.ndarray],
language: str,
task: str,
verbose: Optional[bool] = None,
):
if isinstance(audio, str):
audio = whisper.load_audio(audio, sr=self.SAMPLE_RATE)

self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
task=task, language=language
device = 0 if torch.cuda.is_available() else None
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(
self.model_id, torch_dtype=torch_dtype, use_safetensors=True
)

segments = []
all_predicted_ids = []
model.generation_config.language = language
model.to(device)

num_samples = audio.size
seek = 0
with tqdm(
total=num_samples, unit="samples", disable=verbose is not False
) as progress_bar:
while seek < num_samples:
chunk = audio[seek : seek + self.N_SAMPLES_IN_CHUNK]
input_features = self.processor(
chunk, return_tensors="pt", sampling_rate=self.SAMPLE_RATE
).input_features.to(self.model.device)
predicted_ids = self.model.generate(input_features)
all_predicted_ids.extend(predicted_ids)
text: str = self.processor.batch_decode(
predicted_ids, skip_special_tokens=True
)[0]
if text.strip() != "":
segments.append(
{
"start": seek / self.SAMPLE_RATE,
"end": min(seek + self.N_SAMPLES_IN_CHUNK, num_samples)
/ self.SAMPLE_RATE,
"text": text,
}
)
processor = AutoProcessor.from_pretrained(self.model_id)

pipe = pipeline(
"automatic-speech-recognition",
generate_kwargs={"language": language, "task": task},
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
chunk_length_s=30,
torch_dtype=torch_dtype,
device=device,
)

progress_bar.update(
min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek
)
seek += self.N_SAMPLES_IN_CHUNK
transcript = pipe(audio, return_timestamps=True)

segments = []
for chunk in transcript['chunks']:
start, end = chunk['timestamp']
text = chunk['text']
segments.append({
"start": start,
"end": end,
"text": text,
"translation": ""
})

return {
"text": self.processor.batch_decode(
all_predicted_ids, skip_special_tokens=True
)[0],
"text": transcript['text'],
"segments": segments,
}

4 changes: 2 additions & 2 deletions tests/transformers_whisper_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import platform
import pytest

from buzz.transformers_whisper import load_model
from buzz.transformers_whisper import TransformersWhisper
from tests.audio import test_audio_path


Expand All @@ -11,7 +11,7 @@
)
class TestTransformersWhisper:
def test_should_transcribe(self):
model = load_model("openai/whisper-tiny")
model = TransformersWhisper("openai/whisper-tiny")
result = model.transcribe(
audio=test_audio_path, language="fr", task="transcribe"
)
Expand Down

0 comments on commit 6c06673

Please sign in to comment.