diff --git a/asteroid/scripts/asteroid_cli.py b/asteroid/scripts/asteroid_cli.py index bb6ee6c47..32ea2358c 100644 --- a/asteroid/scripts/asteroid_cli.py +++ b/asteroid/scripts/asteroid_cli.py @@ -1,5 +1,6 @@ import os import argparse +import torch import yaml import itertools import glob @@ -87,7 +88,7 @@ def upload(): ) -def infer(): +def infer(argv=None): """CLI function to run pretrained model inference on wav files.""" parser = argparse.ArgumentParser() parser.add_argument("url_or_path", type=str, help="Path to the pretrained model.") @@ -140,7 +141,20 @@ def infer(): parser.add_argument( "-o", "--output-dir", default=None, type=str, help="Output directory to save files." ) - args = parser.parse_args() + parser.add_argument( + "-d", + "--device", + default=None, + type=str, + help="Device to run the model on, eg. 'cuda:0'." + "Defaults to 'cuda' if CUDA is available, else 'cpu'.", + ) + args = parser.parse_args(argv) + + if args.device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = args.device model = BaseModel.from_pretrained(pretrained_model_conf_or_path=args.url_or_path) if args.ola_window is not None: @@ -152,8 +166,9 @@ def infer(): window=args.ola_window_type, reorder_chunks=not args.ola_no_reorder, ) - file_list = _process_files_as_list(args.files) + model = model.to(device) + file_list = _process_files_as_list(args.files) for f in file_list: separate( model, diff --git a/tests/cli_test.py b/tests/cli_test.py index 849069f08..3bfdeaaa9 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -1,4 +1,4 @@ -from asteroid.scripts import asteroid_versions +from asteroid.scripts import asteroid_versions, asteroid_cli def test_asteroid_versions(): @@ -15,3 +15,31 @@ def test_print_versions(): def test_asteroid_versions_without_git(monkeypatch): monkeypatch.setenv("PATH", "") asteroid_versions.asteroid_versions() + + +def test_infer_device(monkeypatch): + """Test that inference is performed on the PyTorch device given by '--device'. + + We can't properly test this in environments with only CPU device available. + As an approximation we test that the '.to()' method of the model is called + with the device given by '--device'. + """ + # We can't use a real model to test this because calling .to() with a fake device + # on a real model will fail. + class FakeModel: + def to(self, device): + self.device = device + + fake_model = FakeModel() + + # Monkeypatch 'from_pretrained' to load our fake model. + from asteroid.models import BaseModel + + monkeypatch.setattr(BaseModel, "from_pretrained", lambda *args, **kwargs: fake_model) + + # Note that this will issue a warning about the missing file. + asteroid_cli.infer( + ["--device", "cuda:42", "somemodel", "--files", "file_that_does_not_exist.wav"] + ) + + assert fake_model.device == "cuda:42"