diff --git a/src/whisper_ctranslate2/transcribe.py b/src/whisper_ctranslate2/transcribe.py index 2225166..b8a4c24 100644 --- a/src/whisper_ctranslate2/transcribe.py +++ b/src/whisper_ctranslate2/transcribe.py @@ -122,13 +122,10 @@ def __init__( download_root=cache_directory, local_files_only=local_files_only, ) + + self.batch_size = batch_size if batched: - if batch_size: - self.batched_model = BatchedInferencePipeline(model=self.model) - else: - self.batched_model = BatchedInferencePipeline( - model=self.model, batch_size=batch_size - ) + self.batched_model = BatchedInferencePipeline(model=self.model) else: self.batched_model = None @@ -150,6 +147,10 @@ def inference( model = self.model vad = options.vad_filter + batch_size = ( + {"batch_size": self.batch_size} if self.batch_size is not None else {} + ) + segments, info = model.transcribe( audio=audio, language=language, @@ -177,6 +178,7 @@ def inference( hallucination_silence_threshold=options.hallucination_silence_threshold, vad_filter=vad, vad_parameters=vad_parameters, + **batch_size, ) language_name = LANGUAGES[info.language].title()