Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add device= support to asteroid-infer #375

Merged
merged 2 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions asteroid/scripts/asteroid_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def upload():
)


def infer():
def infer(argv=None):
"""CLI function to run pretrained model inference on wav files."""
parser = argparse.ArgumentParser()
parser.add_argument("url_or_path", type=str, help="Path to the pretrained model.")
Expand Down Expand Up @@ -140,7 +140,20 @@ def infer():
parser.add_argument(
"-o", "--output-dir", default=None, type=str, help="Output directory to save files."
)
args = parser.parse_args()
parser.add_argument(
"-d",
"--device",
default=None,
type=str,
help="Device to run the model on, eg. 'cuda:0'."
"Defaults to 'cuda' if CUDA is available, else 'cpu'.",
)
args = parser.parse_args(argv)

if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
mpariente marked this conversation as resolved.
Show resolved Hide resolved
else:
device = args.device

model = BaseModel.from_pretrained(pretrained_model_conf_or_path=args.url_or_path)
if args.ola_window is not None:
Expand All @@ -152,8 +165,9 @@ def infer():
window=args.ola_window_type,
reorder_chunks=not args.ola_no_reorder,
)
file_list = _process_files_as_list(args.files)
model = model.to(device)

file_list = _process_files_as_list(args.files)
for f in file_list:
separate(
model,
Expand Down
30 changes: 29 additions & 1 deletion tests/cli_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asteroid.scripts import asteroid_versions
from asteroid.scripts import asteroid_versions, asteroid_cli


def test_asteroid_versions():
Expand All @@ -15,3 +15,31 @@ def test_print_versions():
def test_asteroid_versions_without_git(monkeypatch):
monkeypatch.setenv("PATH", "")
asteroid_versions.asteroid_versions()


def test_infer_device(monkeypatch):
"""Test that inference is performed on the PyTorch device given by '--device'.

We can't properly test this in environments with only CPU device available.
As an approximation we test that the '.to()' method of the model is called
with the device given by '--device'.
"""
# We can't use a real model to test this because calling .to() with a fake device
# on a real model will fail.
class FakeModel:
def to(self, device):
self.device = device

fake_model = FakeModel()

# Monkeypatch 'from_pretrained' to load our fake model.
from asteroid.models import BaseModel

monkeypatch.setattr(BaseModel, "from_pretrained", lambda *args, **kwargs: fake_model)

# Note that this will issue a warning about the missing file.
asteroid_cli.infer(
["--device", "cuda:42", "somemodel", "--files", "file_that_does_not_exist.wav"]
)

assert fake_model.device == "cuda:42"