Skip to content

Commit

Permalink
Refactor transcribers class (#508)
Browse files — browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Jun 26, 2023
1 parent fa08e53 commit 2dc0797
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 131 deletions.
3 changes: 2 additions & 1 deletion buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
Task,
TranscriptionOptions,
FileTranscriptionTask, RecordingTranscriber, LOADED_WHISPER_DLL,
FileTranscriptionTask, LOADED_WHISPER_DLL,
DEFAULT_WHISPER_TEMPERATURE, LANGUAGES)
from .recording_transcriber import RecordingTranscriber
from .file_transcriber_queue_worker import FileTranscriberQueueWorker
from .widgets.line_edit import LineEdit
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
Expand Down
139 changes: 139 additions & 0 deletions buzz/recording_transcriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import datetime
import logging
import threading
from typing import Optional

import numpy as np
import sounddevice
import whisper
from PyQt6.QtCore import QObject, pyqtSignal
from sounddevice import PortAudioError

from buzz import transformers_whisper
from buzz.model_loader import ModelType
from buzz.transcriber import TranscriptionOptions, WhisperCpp, whisper_cpp_params
from buzz.transformers_whisper import TransformersWhisper


class RecordingTranscriber(QObject):
    """Transcribes live microphone input in near-real-time.

    Audio captured by a ``sounddevice.InputStream`` callback is appended to an
    in-memory sample queue. :meth:`start` consumes the queue in fixed-size
    batches (5 seconds of audio), transcribes each batch with the configured
    model, and emits the recognized text through the ``transcription`` signal.
    Call :meth:`stop_recording` (from another thread) to end the loop.
    """

    # Qt signals: text for each transcribed batch, normal completion, and an
    # error message when audio capture fails.
    transcription = pyqtSignal(str)
    finished = pyqtSignal()
    error = pyqtSignal(str)
    is_running = False
    # NOTE(review): unused — the effective bound is self.max_queue_size,
    # set in __init__. Kept for backward compatibility.
    MAX_QUEUE_SIZE = 10

    def __init__(self, transcription_options: TranscriptionOptions,
                 input_device_index: Optional[int], sample_rate: int, model_path: str,
                 parent: Optional[QObject] = None) -> None:
        """
        :param transcription_options: model type, language, task, prompt, etc.
        :param input_device_index: sounddevice input device, or None for default
        :param sample_rate: capture sample rate in Hz
        :param model_path: path of the already-downloaded model
        :param parent: optional Qt parent object
        """
        super().__init__(parent)
        self.transcription_options = transcription_options
        self.current_stream = None
        self.input_device_index = input_device_index
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.n_batch_samples = 5 * self.sample_rate  # transcribe every 5 seconds of audio
        # Pause queueing if more than 3 batches behind.
        self.max_queue_size = 3 * self.n_batch_samples
        # Bug fix: np.ndarray([]) allocates an *uninitialized 0-d* array of
        # size 1, which injected one garbage sample at the head of the stream.
        # np.array([]) is a genuinely empty 1-d float32 buffer.
        self.queue = np.array([], dtype=np.float32)
        self.mutex = threading.Lock()

    def _load_model(self):
        """Load and return the transcription model for the configured model type."""
        model_type = self.transcription_options.model.model_type
        if model_type == ModelType.WHISPER:
            return whisper.load_model(self.model_path)
        if model_type == ModelType.WHISPER_CPP:
            return WhisperCpp(self.model_path)
        # ModelType.HUGGING_FACE
        return transformers_whisper.load_model(self.model_path)

    def _transcribe_batch(self, model, samples: np.ndarray, initial_prompt: str) -> str:
        """Transcribe one batch of samples with *model* and return the text."""
        model_type = self.transcription_options.model.model_type
        # WhisperCpp and Transformers backends require an explicit language.
        fallback_language = (self.transcription_options.language
                             if self.transcription_options.language is not None else 'en')
        if model_type == ModelType.WHISPER:
            assert isinstance(model, whisper.Whisper)
            result = model.transcribe(
                audio=samples, language=self.transcription_options.language,
                task=self.transcription_options.task.value,
                initial_prompt=initial_prompt,
                temperature=self.transcription_options.temperature)
        elif model_type == ModelType.WHISPER_CPP:
            assert isinstance(model, WhisperCpp)
            result = model.transcribe(
                audio=samples,
                params=whisper_cpp_params(
                    language=fallback_language,
                    task=self.transcription_options.task.value,
                    word_level_timings=False))
        else:
            assert isinstance(model, TransformersWhisper)
            result = model.transcribe(audio=samples,
                                      language=fallback_language,
                                      task=self.transcription_options.task.value)
        # Default to '' so a missing 'text' key cannot crash the loop
        # (previously result.get('text') could return None).
        return result.get('text', '')

    def start(self):
        """Open the input stream and run the transcription loop.

        Blocks until :meth:`stop_recording` flips ``is_running`` or a
        ``PortAudioError`` occurs; emits ``finished`` on a clean exit and
        ``error`` with the exception text on audio failure.
        """
        model = self._load_model()
        initial_prompt = self.transcription_options.initial_prompt

        logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s',
                      self.transcription_options, self.model_path, self.sample_rate, self.input_device_index)

        self.is_running = True
        try:
            with sounddevice.InputStream(samplerate=self.sample_rate,
                                         device=self.input_device_index, dtype="float32",
                                         channels=1, callback=self.stream_callback):
                while self.is_running:
                    # Hold the lock only while mutating the queue, exactly as
                    # before; `with` additionally guarantees release if the
                    # slicing raises, so the audio callback can never deadlock.
                    with self.mutex:
                        if self.queue.size >= self.n_batch_samples:
                            samples = self.queue[:self.n_batch_samples]
                            self.queue = self.queue[self.n_batch_samples:]
                        else:
                            samples = None

                    if samples is None:
                        # Not enough audio buffered yet; keep polling until the
                        # stream callback fills the queue.
                        continue

                    logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
                                  samples.size, self.queue.size, self.amplitude(samples))
                    time_started = datetime.datetime.now()

                    next_text = self._transcribe_batch(model, samples, initial_prompt)

                    # Update initial prompt between successive recording chunks
                    # so each batch gets the previous text as context.
                    # NOTE(review): assumes initial_prompt is a str, never None
                    # — confirm against TranscriptionOptions.
                    initial_prompt += next_text

                    logging.debug('Received next result, length = %s, time taken = %s',
                                  len(next_text), datetime.datetime.now() - time_started)
                    self.transcription.emit(next_text)
        except PortAudioError as exc:
            self.error.emit(str(exc))
            logging.exception('')
            return

        self.finished.emit()

    @staticmethod
    def get_device_sample_rate(device_id: Optional[int]) -> int:
        """Returns the sample rate to be used for recording. It uses the default sample rate
        provided by Whisper if the microphone supports it, or else it uses the device's default
        sample rate.
        """
        whisper_sample_rate = whisper.audio.SAMPLE_RATE
        try:
            sounddevice.check_input_settings(
                device=device_id, samplerate=whisper_sample_rate)
            return whisper_sample_rate
        except PortAudioError:
            device_info = sounddevice.query_devices(device=device_id)
            if isinstance(device_info, dict):
                return int(device_info.get('default_samplerate', whisper_sample_rate))
            return whisper_sample_rate

    def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
        """sounddevice callback: enqueue the next audio block.

        If the queue already holds max_queue_size samples the block is dropped,
        which keeps the backlog bounded when transcription falls behind.
        """
        chunk: np.ndarray = in_data.ravel()
        with self.mutex:
            if self.queue.size < self.max_queue_size:
                self.queue = np.append(self.queue, chunk)

    @staticmethod
    def amplitude(arr: np.ndarray):
        """Rough amplitude of a sample batch: mean of |max| and |min|. Debug-log only."""
        return (abs(max(arr)) + abs(min(arr))) / 2

    def stop_recording(self):
        """Signal the start() loop to exit after the current iteration."""
        self.is_running = False
128 changes: 0 additions & 128 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import os
import sys
import tempfile
import threading
from abc import abstractmethod
from dataclasses import dataclass, field
from multiprocessing.connection import Connection
Expand All @@ -19,18 +18,15 @@
import ffmpeg
import numpy as np
import openai
import sounddevice
import stable_whisper
import tqdm
import whisper
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
from sounddevice import PortAudioError
from whisper import tokenizer

from . import transformers_whisper
from .conn import pipe_stderr
from .model_loader import TranscriptionModel, ModelType
from .transformers_whisper import TransformersWhisper

# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
Expand Down Expand Up @@ -101,130 +97,6 @@ class Status(enum.Enum):
completed_at: Optional[datetime.datetime] = None


class RecordingTranscriber(QObject):
    """Transcribes live microphone input in near-real-time.

    Audio captured by a ``sounddevice.InputStream`` callback is appended to an
    in-memory sample queue; :meth:`start` consumes it in 5-second batches,
    transcribes each batch with the configured model, and emits the text via
    the ``transcription`` signal. (This copy is being removed from
    transcriber.py by this commit in favor of recording_transcriber.py.)
    """

    # Qt signals: text per transcribed batch, completion, and capture errors.
    transcription = pyqtSignal(str)
    finished = pyqtSignal()
    error = pyqtSignal(str)
    is_running = False
    # NOTE(review): appears unused — the effective bound is self.max_queue_size.
    MAX_QUEUE_SIZE = 10

    def __init__(self, transcription_options: TranscriptionOptions,
                 input_device_index: Optional[int], sample_rate: int, model_path: str,
                 parent: Optional[QObject] = None) -> None:
        super().__init__(parent)
        self.transcription_options = transcription_options
        self.current_stream = None
        self.input_device_index = input_device_index
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.n_batch_samples = 5 * self.sample_rate  # every 5 seconds
        # pause queueing if more than 3 batches behind
        self.max_queue_size = 3 * self.n_batch_samples
        # NOTE(review): np.ndarray([]) allocates an *uninitialized 0-d* array of
        # size 1 — this likely injects one garbage sample at stream start;
        # np.array([], dtype=np.float32) would be an empty buffer. Verify.
        self.queue = np.ndarray([], dtype=np.float32)
        self.mutex = threading.Lock()

    def start(self):
        """Open the input stream and run the capture/transcription loop until
        stop_recording() clears is_running; emits finished or error."""
        model_path = self.model_path

        # Select the model backend by configured type.
        if self.transcription_options.model.model_type == ModelType.WHISPER:
            model = whisper.load_model(model_path)
        elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
            model = WhisperCpp(model_path)
        else:  # ModelType.HUGGING_FACE
            model = transformers_whisper.load_model(model_path)

        initial_prompt = self.transcription_options.initial_prompt

        logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s',
                      self.transcription_options, model_path, self.sample_rate, self.input_device_index)

        self.is_running = True
        try:
            with sounddevice.InputStream(samplerate=self.sample_rate,
                                         device=self.input_device_index, dtype="float32",
                                         channels=1, callback=self.stream_callback):
                while self.is_running:
                    # Lock is held only while slicing the queue; transcription
                    # runs after release so the audio callback is not blocked.
                    self.mutex.acquire()
                    if self.queue.size >= self.n_batch_samples:
                        samples = self.queue[:self.n_batch_samples]
                        self.queue = self.queue[self.n_batch_samples:]
                        self.mutex.release()

                        logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
                                      samples.size, self.queue.size, self.amplitude(samples))
                        time_started = datetime.datetime.now()

                        # Dispatch to the matching backend API.
                        if self.transcription_options.model.model_type == ModelType.WHISPER:
                            assert isinstance(model, whisper.Whisper)
                            result = model.transcribe(
                                audio=samples, language=self.transcription_options.language,
                                task=self.transcription_options.task.value,
                                initial_prompt=initial_prompt,
                                temperature=self.transcription_options.temperature)
                        elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
                            assert isinstance(model, WhisperCpp)
                            result = model.transcribe(
                                audio=samples,
                                params=whisper_cpp_params(
                                    language=self.transcription_options.language
                                    if self.transcription_options.language is not None else 'en',
                                    task=self.transcription_options.task.value, word_level_timings=False))
                        else:
                            assert isinstance(model, TransformersWhisper)
                            result = model.transcribe(audio=samples,
                                                      language=self.transcription_options.language
                                                      if self.transcription_options.language is not None else 'en',
                                                      task=self.transcription_options.task.value)

                        next_text: str = result.get('text')

                        # Update initial prompt between successive recording chunks
                        initial_prompt += next_text

                        logging.debug('Received next result, length = %s, time taken = %s',
                                      len(next_text), datetime.datetime.now() - time_started)
                        self.transcription.emit(next_text)
                    else:
                        # Not enough audio buffered yet — release and poll again.
                        self.mutex.release()
        except PortAudioError as exc:
            self.error.emit(str(exc))
            logging.exception('')
            return

        self.finished.emit()

    @staticmethod
    def get_device_sample_rate(device_id: Optional[int]) -> int:
        """Returns the sample rate to be used for recording. It uses the default sample rate
        provided by Whisper if the microphone supports it, or else it uses the device's default
        sample rate.
        """
        whisper_sample_rate = whisper.audio.SAMPLE_RATE
        try:
            sounddevice.check_input_settings(
                device=device_id, samplerate=whisper_sample_rate)
            return whisper_sample_rate
        except PortAudioError:
            device_info = sounddevice.query_devices(device=device_id)
            if isinstance(device_info, dict):
                return int(device_info.get('default_samplerate', whisper_sample_rate))
            return whisper_sample_rate

    def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
        # Try to enqueue the next block. If the queue is already full, drop the block.
        chunk: np.ndarray = in_data.ravel()
        with self.mutex:
            if self.queue.size < self.max_queue_size:
                self.queue = np.append(self.queue, chunk)

    @staticmethod
    def amplitude(arr: np.ndarray):
        # Rough amplitude estimate for debug logging: mean of |max| and |min|.
        return (abs(max(arr)) + abs(min(arr))) / 2

    def stop_recording(self):
        # Signal the start() loop to exit after its current iteration.
        self.is_running = False


class OutputFormat(enum.Enum):
TXT = 'txt'
SRT = 'srt'
Expand Down
4 changes: 2 additions & 2 deletions tests/transcriber_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from pytestqt.qtbot import QtBot

from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber,
Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
WhisperFileTranscriber,
get_default_output_file_path, to_timestamp,
whisper_cpp_params, write_output, TranscriptionOptions)
from buzz.recording_transcriber import RecordingTranscriber
from tests.mock_sounddevice import MockInputStream
from tests.model_loader import get_model_path

Expand Down

0 comments on commit 2dc0797

Please sign in to comment.