Skip to content

Commit

Permalink
predownload vad weights
Browse files Browse the repository at this point in the history
  • Loading branch information
villesau committed Oct 4, 2024
1 parent 8067a0c commit 367dd57
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
1 change: 1 addition & 0 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class Predictor(BasePredictor):
def setup(self):
os.environ['TORCH_HOME'] = './weights'
"""Load the model into memory to make running multiple predictions efficient"""
self.model = whisper_timestamped.load_model("weights/whisper/large-v3.pt", device="cuda")

Expand Down
27 changes: 27 additions & 0 deletions scripts/download_vad_weights
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python

import os
import torch

# Set the TORCH_HOME environment variable
os.environ['TORCH_HOME'] = './weights'

# Define and create the cache directory
CACHE_DIR = os.path.join(os.environ['TORCH_HOME'], 'hub')
os.makedirs(CACHE_DIR, exist_ok=True)

# Set the torch hub directory
torch.hub.set_dir(CACHE_DIR)

# Download Silero VAD model
model, utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad",
model="silero_vad",
force_reload=True,
onnx=False,
source="github"
)

print(f"Model weights downloaded to: {CACHE_DIR}")
print("\nTo use with whisper-timestamped, set TORCH_HOME:")
print(f"export TORCH_HOME={os.path.abspath('./weights')}")

0 comments on commit 367dd57

Please sign in to comment.