Upgrade to Silero-Vad V5 (SYSTRAN#884)
* Fix window_size_samples to 512

* Update SileroVADModel

* Replace ONNX file with V5 version
hoonlight authored Jul 1, 2024
1 parent bced5f0 commit 8d400e9
Showing 2 changed files with 13 additions and 26 deletions.
Binary file modified faster_whisper/assets/silero_vad.onnx
Binary file not shown.
39 changes: 13 additions & 26 deletions faster_whisper/vad.py
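
V5 changes the model's recurrent interface: the V4 LSTM state pair (separate "h" and "c" tensors, each of shape (2, batch, 64)) becomes a single (2, batch, 128) "state" tensor plus a rolling 64-sample audio context that is prepended to each 512-sample input chunk. A minimal sketch of the two state layouts, with shapes taken from the diff below (the variable names are illustrative only):

import numpy as np

batch_size = 1

# V4 (old): LSTM hidden and cell states, passed as the "h" and "c" ONNX inputs.
h = np.zeros((2, batch_size, 64), dtype=np.float32)
c = np.zeros((2, batch_size, 64), dtype=np.float32)

# V5 (new): one combined "state" input, plus 64 samples of trailing audio
# kept between calls and concatenated in front of the next chunk.
state = np.zeros((2, batch_size, 128), dtype=np.float32)
context = np.zeros((batch_size, 64), dtype=np.float32)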
@@ -1,7 +1,6 @@
 import bisect
 import functools
 import os
-import warnings
 
 from typing import List, NamedTuple, Optional
 
@@ -25,17 +24,13 @@ class VadOptions(NamedTuple):
         split aggressively just before max_speech_duration_s.
       min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
         before separating it
-      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
-        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
-        Values other than these may affect model performance!!
       speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
     """
 
     threshold: float = 0.5
     min_speech_duration_ms: int = 250
     max_speech_duration_s: float = float("inf")
     min_silence_duration_ms: int = 2000
-    window_size_samples: int = 1024
     speech_pad_ms: int = 400
 
 
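Because window_size_samples is no longer a field, existing code that passed it to VadOptions will raise a TypeError and must drop the argument. A minimal sketch of constructing the V5-era options with the remaining fields (the values shown are just the defaults from the diff above, not recommendations):

options = VadOptions(
    threshold=0.5,
    min_speech_duration_ms=250,
    max_speech_duration_s=float("inf"),
    min_silence_duration_ms=2000,
    speech_pad_ms=400,
)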
@@ -61,15 +56,8 @@ def get_speech_timestamps(
     min_speech_duration_ms = vad_options.min_speech_duration_ms
     max_speech_duration_s = vad_options.max_speech_duration_s
     min_silence_duration_ms = vad_options.min_silence_duration_ms
-    window_size_samples = vad_options.window_size_samples
+    window_size_samples = 512
     speech_pad_ms = vad_options.speech_pad_ms
 
-    if window_size_samples not in [512, 1024, 1536]:
-        warnings.warn(
-            "Unusual window_size_samples! Supported window_size_samples:\n"
-            " - [512, 1024, 1536] for 16000 sampling_rate"
-        )
-
     sampling_rate = 16000
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
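
Hardcoding window_size_samples to 512 reflects that the V5 model expects exactly 512-sample windows at the fixed 16000 Hz sampling rate, i.e. one speech probability per 32 ms of audio. A quick check of the arithmetic (the resulting 31.25 windows-per-second figure is also the bound used by the short-chunk guard in SileroVADModel.__call__ further down):

sampling_rate = 16000
window_size_samples = 512

window_ms = 1000 * window_size_samples / sampling_rate    # 32.0 ms per window
windows_per_second = sampling_rate / window_size_samples  # 31.25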
@@ -84,14 +72,14 @@
     audio_length_samples = len(audio)
 
     model = get_vad_model()
-    state = model.get_initial_state(batch_size=1)
+    state, context = model.get_initial_states(batch_size=1)
 
     speech_probs = []
     for current_start_sample in range(0, audio_length_samples, window_size_samples):
         chunk = audio[current_start_sample : current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        speech_prob, state = model(chunk, state, sampling_rate)
+        speech_prob, state, context = model(chunk, state, context, sampling_rate)
         speech_probs.append(speech_prob)
 
     triggered = False
@@ -261,12 +249,12 @@ def __init__(self, path):
             sess_options=opts,
         )
 
-    def get_initial_state(self, batch_size: int):
-        h = np.zeros((2, batch_size, 64), dtype=np.float32)
-        c = np.zeros((2, batch_size, 64), dtype=np.float32)
-        return h, c
+    def get_initial_states(self, batch_size: int):
+        state = np.zeros((2, batch_size, 128), dtype=np.float32)
+        context = np.zeros((batch_size, 64), dtype=np.float32)
+        return state, context
 
-    def __call__(self, x, state, sr: int):
+    def __call__(self, x, state, context, sr: int):
         if len(x.shape) == 1:
             x = np.expand_dims(x, 0)
         if len(x.shape) > 2:
@@ -276,16 +264,15 @@ def __call__(self, x, state, sr: int):
         if sr / x.shape[1] > 31.25:
             raise ValueError("Input audio chunk is too short")
 
-        h, c = state
+        x = np.concatenate([context, x], axis=1)
 
         ort_inputs = {
             "input": x,
-            "h": h,
-            "c": c,
+            "state": state,
             "sr": np.array(sr, dtype="int64"),
         }
 
-        out, h, c = self.session.run(None, ort_inputs)
-        state = (h, c)
+        out, state = self.session.run(None, ort_inputs)
+        context = x[..., -64:]
 
-        return out, state
+        return out, state, context
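
Putting the pieces together, a hedged end-to-end sketch of the V5 call pattern, mirroring the loop in get_speech_timestamps above (get_vad_model, get_initial_states, and the three-value call signature come from this diff; the silent audio buffer is a placeholder):

import numpy as np

sampling_rate = 16000
window_size_samples = 512

audio = np.zeros(3 * sampling_rate, dtype=np.float32)  # placeholder: 3 s of silence

model = get_vad_model()
state, context = model.get_initial_states(batch_size=1)

speech_probs = []
for start in range(0, len(audio), window_size_samples):
    chunk = audio[start : start + window_size_samples]
    if len(chunk) < window_size_samples:
        # Zero-pad the final partial window to a full 512 samples.
        chunk = np.pad(chunk, (0, window_size_samples - len(chunk)))
    # V5 threads both the RNN state and the 64-sample audio context through the loop.
    prob, state, context = model(chunk, state, context, sampling_rate)
    speech_probs.append(prob)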
