From f5d6d31c82e58013380829b92c377147778f5a69 Mon Sep 17 00:00:00 2001 From: Jonas Haag Date: Thu, 3 Dec 2020 15:55:13 +0100 Subject: [PATCH 1/2] Add --device support to asteroid-infer --- asteroid/scripts/asteroid_cli.py | 20 +++++++++++++++++--- tests/cli_test.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/asteroid/scripts/asteroid_cli.py b/asteroid/scripts/asteroid_cli.py index bb6ee6c47..eb1310e6c 100644 --- a/asteroid/scripts/asteroid_cli.py +++ b/asteroid/scripts/asteroid_cli.py @@ -87,7 +87,7 @@ def upload(): ) -def infer(): +def infer(argv=None): """CLI function to run pretrained model inference on wav files.""" parser = argparse.ArgumentParser() parser.add_argument("url_or_path", type=str, help="Path to the pretrained model.") @@ -140,7 +140,20 @@ def infer(): parser.add_argument( "-o", "--output-dir", default=None, type=str, help="Output directory to save files." ) - args = parser.parse_args() + parser.add_argument( + "-d", + "--device", + default=None, + type=str, + help="Device to run the model on, eg. 'cuda:0'." + "Defaults to 'cuda' if CUDA is available, else 'cpu'.", + ) + args = parser.parse_args(argv) + + if args.device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = args.device model = BaseModel.from_pretrained(pretrained_model_conf_or_path=args.url_or_path) if args.ola_window is not None: @@ -152,8 +165,9 @@ def infer(): window=args.ola_window_type, reorder_chunks=not args.ola_no_reorder, ) - file_list = _process_files_as_list(args.files) + model = model.to(device) + file_list = _process_files_as_list(args.files) for f in file_list: separate( model, diff --git a/tests/cli_test.py b/tests/cli_test.py index 849069f08..3bfdeaaa9 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -1,4 +1,4 @@ -from asteroid.scripts import asteroid_versions +from asteroid.scripts import asteroid_versions, asteroid_cli def test_asteroid_versions(): @@ -15,3 +15,31 @@ def test_print_versions(): def test_asteroid_versions_without_git(monkeypatch): monkeypatch.setenv("PATH", "") asteroid_versions.asteroid_versions() + + +def test_infer_device(monkeypatch): + """Test that inference is performed on the PyTorch device given by '--device'. + + We can't properly test this in environments with only CPU device available. + As an approximation we test that the '.to()' method of the model is called + with the device given by '--device'. + """ + # We can't use a real model to test this because calling .to() with a fake device + # on a real model will fail. + class FakeModel: + def to(self, device): + self.device = device + + fake_model = FakeModel() + + # Monkeypatch 'from_pretrained' to load our fake model. + from asteroid.models import BaseModel + + monkeypatch.setattr(BaseModel, "from_pretrained", lambda *args, **kwargs: fake_model) + + # Note that this will issue a warning about the missing file. + asteroid_cli.infer( + ["--device", "cuda:42", "somemodel", "--files", "file_that_does_not_exist.wav"] + ) + + assert fake_model.device == "cuda:42" From 6f117daa46a1d9e13a7e9301837dcce3893ddc75 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 5 Jan 2021 21:21:41 +0100 Subject: [PATCH 2/2] Import torch before device check --- asteroid/scripts/asteroid_cli.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asteroid/scripts/asteroid_cli.py b/asteroid/scripts/asteroid_cli.py index eb1310e6c..32ea2358c 100644 --- a/asteroid/scripts/asteroid_cli.py +++ b/asteroid/scripts/asteroid_cli.py @@ -1,5 +1,6 @@ import os import argparse +import torch import yaml import itertools import glob