Revert "Upgrade to Silero-Vad V5 (SYSTRAN#884)"
This reverts commit 8d400e9.
shinlw committed Sep 6, 2024
1 parent 9d38313 commit 7af0b45
Showing 2 changed files with 26 additions and 13 deletions.
Binary file modified faster_whisper/assets/silero_vad.onnx
Binary file not shown.
39 changes: 26 additions & 13 deletions faster_whisper/vad.py
@@ -1,6 +1,7 @@
 import bisect
 import functools
 import os
+import warnings
 
 from typing import List, NamedTuple, Optional

@@ -24,13 +25,17 @@ class VadOptions(NamedTuple):
         split aggressively just before max_speech_duration_s.
       min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
         before separating it
+      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
+        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
+        Values other than these may affect model performance!!
       speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
     """
 
     threshold: float = 0.5
     min_speech_duration_ms: int = 250
     max_speech_duration_s: float = float("inf")
     min_silence_duration_ms: int = 2000
+    window_size_samples: int = 1024
     speech_pad_ms: int = 400
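With the field restored, the window size is configurable again instead of being hard-coded to 512. A minimal sketch of constructing the options (defaults taken from the NamedTuple above; per the restored docstring, only 512, 1024, and 1536 are supported at 16 kHz):

```python
from faster_whisper.vad import VadOptions  # class patched in this diff

options = VadOptions(
    threshold=0.5,
    min_speech_duration_ms=250,
    min_silence_duration_ms=2000,
    window_size_samples=1024,  # 1024 samples = 64 ms per window at 16 kHz
    speech_pad_ms=400,
)
```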


@@ -56,8 +61,15 @@ def get_speech_timestamps(
     min_speech_duration_ms = vad_options.min_speech_duration_ms
     max_speech_duration_s = vad_options.max_speech_duration_s
     min_silence_duration_ms = vad_options.min_silence_duration_ms
-    window_size_samples = 512
+    window_size_samples = vad_options.window_size_samples
     speech_pad_ms = vad_options.speech_pad_ms
+
+    if window_size_samples not in [512, 1024, 1536]:
+        warnings.warn(
+            "Unusual window_size_samples! Supported window_size_samples:\n"
+            " - [512, 1024, 1536] for 16000 sampling_rate"
+        )
+
     sampling_rate = 16000
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
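The sample-count lines just above are plain arithmetic at the fixed 16 kHz rate; with the defaults from the diff:

```python
sampling_rate = 16000  # fixed in get_speech_timestamps

# 250 ms minimum speech -> 16000 * 250 / 1000 = 4000 samples
min_speech_samples = sampling_rate * 250 / 1000

# 400 ms of padding     -> 16000 * 400 / 1000 = 6400 samples
speech_pad_samples = sampling_rate * 400 / 1000
```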
@@ -72,14 +84,14 @@
     audio_length_samples = len(audio)
 
     model = get_vad_model()
-    state, context = model.get_initial_states(batch_size=1)
+    state = model.get_initial_state(batch_size=1)
 
     speech_probs = []
     for current_start_sample in range(0, audio_length_samples, window_size_samples):
         chunk = audio[current_start_sample : current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        speech_prob, state, context = model(chunk, state, context, sampling_rate)
+        speech_prob, state = model(chunk, state, sampling_rate)
         speech_probs.append(speech_prob)
 
     triggered = False
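The restored loop threads the recurrent state through successive fixed-size windows, zero-padding only the final short chunk. A self-contained sketch of that same pattern, assuming a model object with the reverted V4-style interface shown further down:

```python
import numpy as np

def collect_speech_probs(audio: np.ndarray, model, window_size_samples: int = 1024):
    """Feed audio through a stateful VAD model one window at a time."""
    state = model.get_initial_state(batch_size=1)
    speech_probs = []
    for start in range(0, len(audio), window_size_samples):
        chunk = audio[start : start + window_size_samples]
        if len(chunk) < window_size_samples:
            # Zero-pad the trailing chunk so every window has equal length.
            chunk = np.pad(chunk, (0, window_size_samples - len(chunk)))
        speech_prob, state = model(chunk, state, 16000)  # state carries over
        speech_probs.append(speech_prob)
    return speech_probs
```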
@@ -249,12 +261,12 @@ def __init__(self, path):
             sess_options=opts,
         )
 
-    def get_initial_states(self, batch_size: int):
-        state = np.zeros((2, batch_size, 128), dtype=np.float32)
-        context = np.zeros((batch_size, 64), dtype=np.float32)
-        return state, context
+    def get_initial_state(self, batch_size: int):
+        h = np.zeros((2, batch_size, 64), dtype=np.float32)
+        c = np.zeros((2, batch_size, 64), dtype=np.float32)
+        return h, c
 
-    def __call__(self, x, state, context, sr: int):
+    def __call__(self, x, state, sr: int):
         if len(x.shape) == 1:
             x = np.expand_dims(x, 0)
         if len(x.shape) > 2:
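The shape change here is the heart of the revert: V5 carried one merged state tensor plus a rolling 64-sample audio context, while V4 keeps separate LSTM hidden and cell tensors. Both layouts, read directly off the diff:

```python
import numpy as np

batch_size = 1

# V5 (removed by this revert): merged state plus audio context.
state_v5 = np.zeros((2, batch_size, 128), dtype=np.float32)
context_v5 = np.zeros((batch_size, 64), dtype=np.float32)

# V4 (restored): separate LSTM hidden (h) and cell (c) states.
h = np.zeros((2, batch_size, 64), dtype=np.float32)
c = np.zeros((2, batch_size, 64), dtype=np.float32)
```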
@@ -264,15 +276,16 @@ def __call__(self, x, state, context, sr: int):
         if sr / x.shape[1] > 31.25:
             raise ValueError("Input audio chunk is too short")
 
-        x = np.concatenate([context, x], axis=1)
+        h, c = state
 
         ort_inputs = {
             "input": x,
-            "state": state,
+            "h": h,
+            "c": c,
             "sr": np.array(sr, dtype="int64"),
         }
 
-        out, state = self.session.run(None, ort_inputs)
-        context = x[..., -64:]
+        out, h, c = self.session.run(None, ort_inputs)
+        state = (h, c)
 
-        return out, state, context
+        return out, state
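Taken together, the reverted __call__ feeds the ONNX session under the V4 input names h and c and re-packs the returned tensors as the next state. A short usage sketch, assuming get_vad_model() (the helper referenced in the diff above) returns an instance of this class and a 1024-sample window:

```python
import numpy as np
from faster_whisper.vad import get_vad_model  # helper used in the diff above

model = get_vad_model()
state = model.get_initial_state(batch_size=1)

h, c = state
assert h.shape == c.shape == (2, 1, 64)  # shapes restored by this revert

chunk = np.zeros(1024, dtype=np.float32)  # one 64 ms window of silence
speech_prob, state = model(chunk, state, 16000)
```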
