diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index 701ff0eb..12b3d42e 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -51,21 +51,13 @@ click.rich_click.SHOW_ARGUMENTS = True -class _SharedParams(click.RichCommand): - """Options shared between most Casanovo commands""" +class _SharedFileIOParams(click.RichCommand): + """File IO options shared between most Casanovo commands""" def __init__(self, *args, **kwargs) -> None: """Define shared options.""" super().__init__(*args, **kwargs) self.params += [ - click.Option( - ("-m", "--model"), - help=""" - Either the model weights (.ckpt file) or a URL pointing to the - model weights file. If not provided, Casanovo will try to - download the latest release automatically. - """, - ), click.Option( ("-d", "--output_dir"), help="The destination directory for output files.", @@ -77,30 +69,44 @@ def __init__(self, *args, **kwargs) -> None: type=click.Path(dir_okay=False), ), click.Option( - ("-c", "--config"), - help=""" - The YAML configuration file overriding the default options. - """, - type=click.Path(exists=True, dir_okay=False), + ("-f", "--force_overwrite"), + help="Whether to overwrite output files.", + is_flag=True, + show_default=True, + default=False, ), click.Option( ("-v", "--verbosity"), - help=""" - Set the verbosity of console logging messages. Log files are - always set to 'debug'. - """, + help=( + "Set the verbosity of console logging messages." + " Log files are always set to 'debug'." + ), type=click.Choice( ["debug", "info", "warning", "error"], case_sensitive=False, ), default="info", ), + ] + + +class _SharedParams(_SharedFileIOParams): + """Options shared between main Casanovo commands""" + + def __init__(self, *args, **kwargs) -> None: + """Define shared options.""" + super().__init__(*args, **kwargs) + self.params += [ click.Option( - ("-f", "--force_overwrite"), - help="Whether to overwrite output files.", - is_flag=True, - show_default=True, - default=False, + ("-m", "--model"), + help="""Either the model weights (.ckpt file) or a URL pointing to + the model weights file. If not provided, Casanovo will try to + download the latest release automatically.""", + ), + click.Option( + ("-c", "--config"), + help="The YAML configuration file overriding the default options.", + type=click.Path(exists=True, dir_okay=False), ), ] @@ -335,38 +341,16 @@ def version() -> None: sys.stdout.write("\n".join(versions) + "\n") -@main.command() -@click.option( - "-d", - "--output_dir", - help="The destination directory for log and config file.", - type=click.Path(dir_okay=True), - required=False, -) -@click.option( - "-o", - "--output_root", - help="The root name for log and config file.", - type=click.Path(dir_okay=False), - required=False, -) -@click.option( - "-f", - "--force_overwrite", - help="Whether to overwrite output files.", - is_flag=True, - show_default=True, - default=False, -) +@main.command(cls=_SharedFileIOParams) def configure( - output_dir: str, output_root: str, force_overwrite: bool + output_dir: str, output_root: str, verbosity: str, force_overwrite: bool ) -> None: """Generate a Casanovo configuration file to customize. The casanovo configuration file is in the YAML format. """ output_path, _ = _setup_output( - output_dir, output_root, force_overwrite, "info" + output_dir, output_root, force_overwrite, verbosity ) config_fname = output_root if output_root is not None else "casanovo" config_fname = Path(config_fname).with_suffix(".yaml") @@ -375,7 +359,7 @@ def configure( config_path = str(output_path / config_fname) Config.copy_default(config_path) - logger.info(f"Wrote {config_path}\n") + logger.info(f"Wrote {config_path}") def setup_logging( diff --git a/tests/test_integration.py b/tests/test_integration.py index 84a1313a..d673af0f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,5 +1,6 @@ import functools import subprocess +import yaml from pathlib import Path import pyteomics.mztab @@ -215,7 +216,7 @@ def test_train_and_run( assert output_filename.is_file() -def test_auxilliary_cli(tmp_path, monkeypatch): +def test_auxilliary_cli(tmp_path, mgf_small, monkeypatch): """Test the secondary CLI commands""" run = functools.partial( CliRunner().invoke, casanovo.main, catch_exceptions=False @@ -231,5 +232,30 @@ def test_auxilliary_cli(tmp_path, monkeypatch): with pytest.raises(FileExistsError): run(["configure", "-o", "test.yaml"]) + with open("casanovo.yaml") as f_in: + config = yaml.safe_load(f_in) + + config["max_epochs"] = 1 + config["n_layers"] = 1 + + with open("small.yaml", "w") as f_out: + yaml.dump(config, f_out) + + train_args = [ + "train", + "--validation_peak_path", + str(mgf_small), + "--config", + "small.yaml", + "--output_dir", + str(tmp_path), + "--output_root", + "train", + str(mgf_small), + ] + + result = run(train_args) + assert result.exit_code == 0 + res = run("version") assert res.output diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index 00617457..a59608fd 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -573,7 +573,6 @@ def test_calc_match_score(): def test_digest_fasta_cleave(tiny_fasta_file, residues_dict): - # No missed cleavages expected_normal = [ "ATSIPAR", @@ -1092,7 +1091,6 @@ def test_get_candidates(tiny_fasta_file, residues_dict): def test_get_candidates_isotope_error(tiny_fasta_file, residues_dict): - # Tide isotope error windows for 496.2, 2+: # 0: [980.481617, 1000.289326] # 1: [979.491114, 999.278813]