From f3696ca89a03a31fdc47e93d61aa250dda9f130a Mon Sep 17 00:00:00 2001 From: Wout Bittremieux Date: Wed, 2 Nov 2022 19:59:18 -0700 Subject: [PATCH] Automatically download model weights (#68) * Download model weights from GitHub release * Include dependencies * Update model usage documentation * Reformat with black * Download weights to the OS-specific app dir * Don't download weights if already in cache dir * Update model file instructions * Remove release notes from the README We have this information on the Releases page now. * Remove explicit model specification from example commands * Harmonize default parameters and config values As per discussion on Slack (https://noblelab.slack.com/archives/C01MXN4NWMP/p1659803053573279). * No need to specify config file by default This simplifies the examples that most users will want to use. * Simplify version matching regex * Remove depthcharge related tests The transformer tests only deal with depthcharge functionality and just seem copied from its repository. * Make sure that package data is included I.e. the config YAML file. * Remove obsolote (ppx) tests * Update integration test * Add MacOS support and support for Apple's MPS chips * Fail test but print version * Added n_worker fn and tests * Create split_version fn and add unit tests * Fix debugging unit test * Explicitly set version * Monkeypatch loaded version * Add device selector, so that on CPU-only runs the devices > 0 * Add windows patch * Fix typo * Revert * Use main process for data loading on Windows * Fix typo * Fix unit test * Fix devices for when num_workers == 0 * Fix devices for when num_workers == 0 * Minor README updates * Import reordering * Minor code and docstring reformatting * Test model weights retrieval * Fix getting the number of devices * Disable excessive Tensorboard deprecation warnings * Don't use worker threads on MacOS It crashes the DataLoader: https://github.com/pytorch/pytorch/issues/70344 * Warnings need to be ignored before import * Additional weights tests - Non-matching version - GitHub rate limit exceeded * Disable tests on MacOS * Include Python 3.10 as supported version Co-authored-by: William Fondrie --- .github/workflows/tests.yml | 2 +- README.md | 58 +++++++++---- casanovo/casanovo.py | 139 ++++++++++++++++++++++++++++++-- casanovo/config.yaml | 4 +- casanovo/denovo/dataloaders.py | 2 +- casanovo/denovo/model.py | 4 +- casanovo/denovo/model_runner.py | 55 +++++++++++-- casanovo/utils.py | 59 ++++++++++++++ setup.cfg | 6 ++ tests/conftest.py | 21 ++--- tests/test_integration.py | 22 +++++ tests/test_unit.py | 101 +++++++++++++++++++++-- 12 files changed, 423 insertions(+), 50 deletions(-) create mode 100644 casanovo/utils.py create mode 100644 tests/test_integration.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ea5a1eb8..1d7fe2f7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,7 +14,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu-latest, windows-latest] steps: - uses: actions/checkout@v2 diff --git a/README.md b/README.md index 192b857f..b939d97d 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,6 @@ If you use Casanovo in your work, please cite the following publication: - Yilmaz, M., Fondrie, W. E., Bittremieux, W., Oh, S. & Noble, W. S. *De novo* mass spectrometry peptide sequencing with a transformer model. in *Proceedings of the 39th International Conference on Machine Learning - ICML '22* vol. 162 25514–25522 (PMLR, 2022). [https://proceedings.mlr.press/v162/yilmaz22a.html](https://proceedings.mlr.press/v162/yilmaz22a.html) -Data and pre-trained model weights are available [on Zenodo](https://zenodo.org/record/6791263). - ## Documentation #### https://casanovo.readthedocs.io/en/latest/ @@ -18,7 +16,7 @@ Data and pre-trained model weights are available [on Zenodo](https://zenodo.org/ We recommend to run Casanovo in a dedicated **Anaconda** environment. This helps keep your environment for Casanovo and its dependencies separate from your other Python environments. -**This is especially helpful because Casanovo works within a specific range of Python versions (3.8 ≥ Python version > 3.10).** +**This is especially helpful because Casanovo works within a specific range of Python versions (3.8 ≥ Python version ≥ 3.10).** - Check out the [Windows](https://docs.anaconda.com/anaconda/install/windows/#), [MacOS](https://docs.anaconda.com/anaconda/install/mac-os/), and [Linux](https://docs.anaconda.com/anaconda/install/linux/) installation instructions. @@ -49,9 +47,9 @@ The base environment most likely will not work. ### Installation -Install Casanovo as a Python package from this repository (requires 3.8 ≥ [Python version] > 3.10 , dependencies will be installed automatically as needed): +Install Casanovo as a Python package from this repository (requires 3.8 ≥ [Python version] ≥ 3.10 , dependencies will be installed automatically as needed): -``` +``` sh pip install casanovo ``` @@ -60,12 +58,22 @@ Once installed, Casanovo can be used with a simple command line interface. All auxiliary data, model, and training-related parameters can be specified in a user created `.yaml` configuration file. See [`casanovo/config.yaml`](https://github.com/Noble-Lab/casanovo/blob/main/casanovo/config.yaml) for the default configuration that was used to obtain the reported results. + +### Model weights + +When running Casanovo in `denovo` or `eval` mode, Casanovo needs compatible pretrained model weights to make predictions. +Model weights can be found on the [Releases page](https://github.com/Noble-Lab/casanovo/releases) under the "Assets" for each release (file extension: .ckpt). +The model file can then be specified using the `--model` command-line parameter when executing Casanovo. +To assist users, if no model file is specified Casanovo will try to download and use a compatible model file automatically. + +Not all releases might have a model file included on the [Releases page](https://github.com/Noble-Lab/casanovo/releases), in which case model weights for alternative releases with the same major version number can be used. + ### Example commands - To run _de novo_ sequencing: ``` -casanovo --mode=denovo --model=path/to/pretrained.ckpt --peak_path=path/to/predict/spectra.mgf --config=path/to/config.yaml --output=path/to/output +casanovo --mode=denovo --peak_path=path/to/predict/spectra.mgf --output=path/to/output ``` Casanovo can predict peptide sequences for MS/MS data in mzML, mzXML, and MGF files. @@ -74,7 +82,7 @@ This will write peptide predictions for the given MS/MS spectra to the specified - To evaluate _de novo_ sequencing performance based on known spectrum annotations: ``` -casanovo --mode=eval --model=path/to/pretrained.ckpt --peak_path=path/to/test/annotated_spectra.mgf --config=path/to/config.yaml +casanovo --mode=eval --peak_path=path/to/test/annotated_spectra.mgf ``` To evaluate the peptide predictions, ground truth peptide labels need to be provided as an annotated MGF file. @@ -82,7 +90,7 @@ To evaluate the peptide predictions, ground truth peptide labels need to be prov - To train a model from scratch: ``` -casanovo --mode=train --peak_path=path/to/train/annotated_spectra.mgf --peak_path_val=path/to/validation/annotated_spectra.mgf --config=path/to/config.yaml +casanovo --mode=train --peak_path=path/to/train/annotated_spectra.mgf --peak_path_val=path/to/validation/annotated_spectra.mgf ``` Training and validation MS/MS data need to be provided as annotated MGF files. @@ -95,16 +103,13 @@ We will demonstrate how to use Casanovo using a small walkthrough example on a s The example MGF file is available at [`sample_data/sample_preprocessed_spectra.mgf`](https://github.com/Noble-Lab/casanovo/blob/main/sample_data/sample_preprocessed_spectra.mgf`). 1. Install Casanovo (see above for details). -2. Download the `casanovo_pretrained_model_weights.zip` from [Zenodo](https://zenodo.org/record/6791263). Place these models in a location that you can easily access and know the path of. - - We will be `using pretrained_excl_mouse.ckpt` for this job. -3. Copy the example `config.yaml` file into a location you can easily access. -4. Ensure you are in the proper anaconda environment by typing `conda activate casanovo_env`. (If you named your environment differently, type in that name instead.) -5. Run this command: +2. Ensure you are in the proper anaconda environment by typing `conda activate casanovo_env`. (If you named your environment differently, type in that name instead.) +3. Run this command: ``` -casanovo --mode=denovo --model=[PATH_TO]/pretrained_excl_mouse.ckpt --peak_path=[PATH_TO]/sample_preprocessed_spectra.mgf --config=[PATH_TO]/config.yaml +casanovo --mode=denovo --peak_path=[PATH_TO]/sample_preprocessed_spectra.mgf ``` -Make sure you use the proper filepath to the `pretrained_excl_mouse.ckpt` file. - - Note: If you want to get the output mzTab file in different location than the working directory, specify an alternative output location using the `--output` parameter. + +Note: If you want to store the output mzTab file in a different location than the current working directory, specify an alternative output location using the `--output` parameter. This job will take very little time to run (< 1 minute). @@ -127,8 +132,29 @@ Run the following command in your command prompt to see all possible command-lin casanovo --help ``` +Additionally, you can use a configuration file to fully customize Casanovo. +You can find the `config.yaml` configuration file that is used by default [here](https://github.com/Noble-Lab/casanovo/blob/main/casanovo/config.yaml). + **I get a "CUDA out of memory" error when trying to run Casanovo. Help!** This means that there was not enough (free) memory available on your GPU to run Casanovo, which is especially likely to happen when you are using a smaller, consumer-grade GPU. We recommend trying to decrease the `train_batch_size` or `predict_batch_size` options in the [config file](https://github.com/Noble-Lab/casanovo/blob/main/casanovo/config.yaml) (depending on whether the error occurred during `train` or `denovo` mode) to reduce the number of spectra that are processed simultaneously. Additionally, we recommend shutting down any other processes that may be running on the GPU, so that Casanovo can exclusively use the GPU. + +**How do I solve a "PermissionError: GitHub API rate limit exceeded" error when trying to run Casanovo?** + +When running Casanovo in `denovo` or `eval` mode, Casanovo needs compatible pretrained model weights to make predictions. +If no model weights file is specified using the `--model` command-line parameter, Casanovo will automatically try to download the latest compatible model file from GitHub and save it to its cache for subsequent use. +However, the GitHub API is limited to maximum 60 requests per hour per IP address. +Consequently, if Casanovo has been executed multiple times already, it might temporarily not be able to communicate with GitHub. +You can avoid this error by explicitly specifying the model file using the `--model` parameter. + +**I see "NotImplementedError: The operator 'aten::index.Tensor'..." when using a Mac with an Apple Silicon chip.** + +Casanovo can leverage Apple's Metal Performance Shaders (MPS) on newer Mac computers, which requires that the `PYTORCH_ENABLE_MPS_FALLBACK` is set to `1`: + +``` sh +export PYTORCH_ENABLE_MPS_FALLBACK=1 +``` + +This will need to be set with each new shell session, or you can add it to your `.bashrc` / `.zshrc` to set this environment variable by default. diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index 1479aafd..0039a806 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -1,16 +1,27 @@ """The command line entry point for Casanovo.""" import datetime +import functools import logging import os +import re +import shutil import sys +import warnings +from typing import Optional, Tuple +warnings.filterwarnings("ignore", category=DeprecationWarning) + +import appdirs import click -import psutil +import github import pytorch_lightning as pl +import requests import torch +import tqdm import yaml from . import __version__ +from . import utils from .data import ms_io from .denovo import model_runner @@ -61,11 +72,11 @@ ) def main( mode: str, - model: str, + model: Optional[str], peak_path: str, - peak_path_val: str, - config: str, - output: str, + peak_path_val: Optional[str], + config: Optional[str], + output: Optional[str], ): """ \b @@ -105,10 +116,12 @@ def main( root.addHandler(file_handler) # Disable dependency non-critical log messages. logging.getLogger("depthcharge").setLevel(logging.INFO) + logging.getLogger("github").setLevel(logging.WARNING) logging.getLogger("h5py").setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING) logging.getLogger("pytorch_lightning").setLevel(logging.WARNING) logging.getLogger("torch").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) # Read parameters from the config file. if config is None: @@ -163,13 +176,30 @@ def main( } # Add extra configuration options and scale by the number of GPUs. n_gpus = torch.cuda.device_count() - config["n_workers"] = len(psutil.Process().cpu_affinity()) + config["n_workers"] = utils.n_workers() if n_gpus > 1: - config["n_workers"] = config["n_workers"] // n_gpus config["train_batch_size"] = config["train_batch_size"] // n_gpus pl.utilities.seed.seed_everything(seed=config["random_seed"], workers=True) + # Download model weights if these were not specified (except when training). + if model is None and mode != "train": + try: + model = _get_model_weights() + except github.RateLimitExceededException: + logger.error( + "GitHub API rate limit exceeded while trying to download the " + "model weights. Please download compatible model weights " + "manually from the official Casanovo code website " + "(https://github.com/Noble-Lab/casanovo) and specify these " + "explicitly using the `--model` parameter when running " + "Casanovo." + ) + raise PermissionError( + "GitHub API rate limit exceeded while trying to download the " + "model weights" + ) from None + # Log the active configuration. logger.info("Casanovo version %s", str(__version__)) logger.debug("mode = %s", mode) @@ -198,5 +228,100 @@ def main( model_runner.train(peak_path, peak_path_val, model, config) +def _get_model_weights() -> str: + """ + Use cached model weights or download them from GitHub. + + If no weights file (extension: .ckpt) is available in the cache directory, + it will be downloaded from a release asset on GitHub. + Model weights are retrieved by matching release version. If no model weights + for an identical release (major, minor, patch), alternative releases with + matching (i) major and minor, or (ii) major versions will be used. + If no matching release can be found, no model weights will be downloaded. + + Note that the GitHub API is limited to 60 requests from the same IP per + hour. + + Returns + ------- + str + The name of the model weights file. + """ + cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False) + os.makedirs(cache_dir, exist_ok=True) + version = utils.split_version(__version__) + version_match: Tuple[Optional[str], Optional[str], int] = None, None, 0 + # Try to find suitable model weights in the local cache. + for filename in os.listdir(cache_dir): + root, ext = os.path.splitext(filename) + if ext == ".ckpt": + file_version = tuple( + g for g in re.match(r".*_v(\d+)_(\d+)_(\d+)", root).groups() + ) + match = sum([i == j for i, j in zip(version, file_version)]) + if match > version_match[2]: + version_match = os.path.join(cache_dir, filename), None, match + # Provide the cached model weights if found. + if version_match[2] > 0: + logger.info( + "Model weights file %s retrieved from local cache", + version_match[0], + ) + return version_match[0] + # Otherwise try to find compatible model weights on GitHub. + else: + repo = github.Github().get_repo("Noble-Lab/casanovo") + # Find the best matching release with model weights provided as asset. + for release in repo.get_releases(): + rel_version = tuple( + g + for g in re.match( + r"v(\d+)\.(\d+)\.(\d+)", release.tag_name + ).groups() + ) + match = sum([i == j for i, j in zip(version, rel_version)]) + if match > version_match[2]: + for release_asset in release.get_assets(): + fn, ext = os.path.splitext(release_asset.name) + if ext == ".ckpt": + version_match = ( + os.path.join( + cache_dir, + f"{fn}_v{'_'.join(map(str, rel_version))}{ext}", + ), + release_asset.browser_download_url, + match, + ) + break + # Download the model weights if a matching release was found. + if version_match[2] > 0: + filename, url, _ = version_match + logger.info( + "Downloading model weights file %s from %s", filename, url + ) + r = requests.get(url, stream=True, allow_redirects=True) + r.raise_for_status() + file_size = int(r.headers.get("Content-Length", 0)) + desc = "(Unknown total file size)" if file_size == 0 else "" + r.raw.read = functools.partial(r.raw.read, decode_content=True) + with tqdm.tqdm.wrapattr( + r.raw, "read", total=file_size, desc=desc + ) as r_raw, open(filename, "wb") as f: + shutil.copyfileobj(r_raw, f) + return filename + else: + logger.error( + "No matching model weights for release v%s found, please " + "specify your model weights explicitly using the `--model` " + "parameter", + __version__, + ) + raise ValueError( + f"No matching model weights for release v{__version__} found, " + f"please specify your model weights explicitly using the " + f"`--model` parameter" + ) + + if __name__ == "__main__": main() diff --git a/casanovo/config.yaml b/casanovo/config.yaml index a1d432dc..96652f02 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -8,7 +8,7 @@ random_seed: 454 # Spectrum processing options. n_peaks: 150 -min_mz: 50.52564895 # 1.0005079 * 50.5 +min_mz: 50.0 max_mz: 2500.0 min_intensity: 0.01 remove_precursor_tol: 2.0 # Da @@ -21,7 +21,7 @@ dim_model: 512 n_head: 8 dim_feedforward: 1024 n_layers: 9 -dropout: 0 +dropout: 0.0 dim_intensity: custom_encoder: max_length: 100 diff --git a/casanovo/denovo/dataloaders.py b/casanovo/denovo/dataloaders.py index 28e481ca..2ee2f8f5 100644 --- a/casanovo/denovo/dataloaders.py +++ b/casanovo/denovo/dataloaders.py @@ -54,7 +54,7 @@ def __init__( test_index: Optional[AnnotatedSpectrumIndex] = None, batch_size: int = 128, n_peaks: Optional[int] = 150, - min_mz: float = 140.0, + min_mz: float = 50.0, max_mz: float = 2500.0, min_intensity: float = 0.01, remove_precursor_tol: float = 2.0, diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index e2a6f09b..934169f7 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -79,10 +79,10 @@ class Spec2Pep(pl.LightningModule, ModelMixin): def __init__( self, - dim_model: int = 128, + dim_model: int = 512, n_head: int = 8, dim_feedforward: int = 1024, - n_layers: int = 1, + n_layers: int = 9, dropout: float = 0.0, dim_intensity: Optional[int] = None, custom_encoder: Optional[SpectrumEncoder] = None, diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 1c3d7d27..fd80df1a 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -2,16 +2,19 @@ model.""" import glob import logging +import operator import os import tempfile import uuid -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Union import numpy as np import pytorch_lightning as pl +import torch from depthcharge.data import AnnotatedSpectrumIndex, SpectrumIndex from pytorch_lightning.strategies import DDPStrategy +from .. import utils from ..data import ms_io from ..denovo.dataloaders import DeNovoDataModule from ..denovo.model import Spec2Pep @@ -146,15 +149,16 @@ def _execute_existing( batch_size=config["predict_batch_size"], ) loaders.setup(stage="test", annotated=annotated) + # Create the Trainer object. trainer = pl.Trainer( accelerator="auto", auto_select_gpus=True, - devices=-1, + devices=_get_devices(), logger=config["logger"], max_epochs=config["max_epochs"], num_sanity_val_steps=config["num_sanity_val_steps"], - strategy=DDPStrategy(find_unused_parameters=False, static_graph=True), + strategy=_get_strategy(), ) # Run the model with/without validation. run_trainer = trainer.validate if annotated else trainer.predict @@ -290,15 +294,16 @@ def train( ] else: callbacks = None + trainer = pl.Trainer( accelerator="auto", auto_select_gpus=True, callbacks=callbacks, - devices=-1, + devices=_get_devices(), logger=config["logger"], max_epochs=config["max_epochs"], num_sanity_val_steps=config["num_sanity_val_steps"], - strategy=DDPStrategy(find_unused_parameters=False, static_graph=True), + strategy=_get_strategy(), ) # Train the model. trainer.fit( @@ -336,3 +341,43 @@ def _get_peak_filenames( for fn in glob.glob(path, recursive=True) if os.path.splitext(fn.lower())[1] in supported_ext ] + + +def _get_strategy() -> Optional[DDPStrategy]: + """ + Get the strategy for the Trainer. + + The DDP strategy works best when multiple GPUs are used. It can work for + CPU-only, but definitely fails using MPS (the Apple Silicon chip) due to + Gloo. + + Returns + ------- + Optional[DDPStrategy] + The strategy parameter for the Trainer. + """ + if torch.cuda.device_count() > 1: + return DDPStrategy(find_unused_parameters=False, static_graph=True) + + return None + + +def _get_devices() -> Union[int, str]: + """ + Get the number of GPUs/CPUs for the Trainer to use. + + Returns + ------- + Union[int, str] + The number of GPUs/CPUs to use, or "auto" to let PyTorch Lightning + determine the appropriate number of devices. + """ + if any( + operator.attrgetter(device + ".is_available")(torch)() + for device in ["cuda", "backends.mps"] + ): + return -1 + elif not (n_workers := utils.n_workers()): + return "auto" + else: + return n_workers diff --git a/casanovo/utils.py b/casanovo/utils.py new file mode 100644 index 00000000..cca67747 --- /dev/null +++ b/casanovo/utils.py @@ -0,0 +1,59 @@ +"""Small utility functions""" +import os +import platform +import re +from typing import Tuple + +import psutil +import torch + + +def n_workers() -> int: + """ + Get the number of workers to use for data loading. + + This is the maximum number of CPUs allowed for the process, scaled for the + number of GPUs being used. + + On Windows and MacOS, we only use the main process. See: + https://discuss.pytorch.org/t/errors-when-using-num-workers-0-in-dataloader/97564/4 + https://github.com/pytorch/pytorch/issues/70344 + + Returns + ------- + int + The number of workers. + """ + # Windows or MacOS: no multiprocessing. + if platform.system() in ["Windows", "Darwin"]: + return 0 + # Linux: scale the number of workers by the number of GPUs (if present). + try: + n_cpu = len(psutil.Process().cpu_affinity()) + except AttributeError: + n_cpu = os.cpu_count() + return ( + n_cpu // n_gpu if (n_gpu := torch.cuda.device_count()) > 1 else n_cpu + ) + + +def split_version(version: str) -> Tuple[str, str, str]: + """ + Split the version into its semantic versioning components. + + Parameters + ---------- + version : str + The version number. + + Returns + ------- + major : str + The major release. + minor : str + The minor release. + patch : str + The patch release. + """ + version_regex = re.compile(r"(\d+)\.(\d+)\.*(\d*)(?:.dev\d+.+)?") + return tuple(g for g in version_regex.match(version).groups()) diff --git a/setup.cfg b/setup.cfg index af8a2bab..f457c42b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,17 +19,23 @@ classifiers = [options] packages = find: +include_package_data = True python_requires = >=3.8 install_requires = + appdirs click depthcharge-ms>=0.0.1 numpy pandas psutil + PyGithub pytorch-lightning>=1.7 + PyYAML + requests scikit-learn spectrum_utils torch>=1.9 + tqdm [options.extras_require] docs = diff --git a/tests/conftest.py b/tests/conftest.py index 2521f51e..3f78eebc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,7 @@ -"""Setup tests for casanovo""" +"""Fixtures used for testing.""" import pytest - -from pyteomics.mass import calculate_mass import numpy as np +from pyteomics.mass import calculate_mass @pytest.fixture @@ -14,7 +13,9 @@ def mgf_small(tmp_path): def _create_mgf(peptides, mgf_file, random_state=42): - """Create a fake MGF file from one or more peptides. + """ + Create a fake MGF file from one or more peptides. + Parameters ---------- peptides : str or list of str @@ -22,10 +23,11 @@ def _create_mgf(peptides, mgf_file, random_state=42): mgf_file : Path The MGF file to create. random_state : int or numpy.random.Generator, optional - The random seed. The charge states are chosen to be - 2 or 3 randomly. + The random seed. The charge states are chosen to be 2 or 3 randomly. + Returns ------- + mgf_file : Path """ rng = np.random.default_rng(random_state) entries = [_create_mgf_entry(p, rng.choice([2, 3])) for p in peptides] @@ -36,13 +38,16 @@ def _create_mgf(peptides, mgf_file, random_state=42): def _create_mgf_entry(peptide, charge=2): - """Create a MassIVE-KB style MGF entry for a single PSM. + """ + Create a MassIVE-KB style MGF entry for a single PSM. + Parameters ---------- peptide : str A peptide sequence. charge : int, optional The peptide charge state. + Returns ------- str @@ -56,12 +61,10 @@ def _create_mgf_entry(peptide, charge=2): frags.append( str(calculate_mass(b_pep, charge=zstate, ion_type="b")) ) - y_pep = peptide[idx:] frags.append( str(calculate_mass(y_pep, charge=zstate, ion_type="y")) ) - frag_string = " 1\n".join(frags) + " 1" mgf = [ diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 00000000..22eac84d --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,22 @@ +import casanovo +from casanovo import casanovo + + +def test_denovo(mgf_small, tmp_path, monkeypatch): + # We can use this to explicitly test different versions. + monkeypatch.setattr(casanovo, "__version__", "3.0.1") + + # Predict on a small MGF file and verify that the output file exists. + output_filename = tmp_path / "test.mztab" + casanovo.main( + [ + "--mode", + "denovo", + "--peak_path", + str(mgf_small), + "--output", + str(output_filename), + ], + standalone_mode=False, + ) + assert output_filename.is_file() diff --git a/tests/test_unit.py b/tests/test_unit.py index c7c979cd..660fe231 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -1,18 +1,105 @@ -"""Test that setuptools-scm is working correctly""" -import casanovo +import os +import platform +import tempfile + +import github +import pytest + +from casanovo import casanovo +from casanovo import utils from casanovo.denovo.model import Spec2Pep def test_version(): - """Check that the version is not None""" + """Check that the version is not None.""" assert casanovo.__version__ is not None +def test_n_workers(monkeypatch): + """Check that n_workers is correct without a GPU.""" + monkeypatch.setattr("torch.cuda.is_available", lambda: False) + cpu_fun = lambda x: ["foo"] * 31 + + with monkeypatch.context() as mnk: + mnk.setattr("psutil.Process.cpu_affinity", cpu_fun, raising=False) + expected = 0 if platform.system() in ["Windows", "Darwin"] else 31 + assert utils.n_workers() == expected + + with monkeypatch.context() as mnk: + mnk.delattr("psutil.Process.cpu_affinity", raising=False) + mnk.setattr("os.cpu_count", lambda: 41) + expected = 0 if platform.system() in ["Windows", "Darwin"] else 41 + assert utils.n_workers() == expected + + with monkeypatch.context() as mnk: + mnk.setattr("torch.cuda.device_count", lambda: 4) + mnk.setattr("psutil.Process.cpu_affinity", cpu_fun, raising=False) + expected = 0 if platform.system() in ["Windows", "Darwin"] else 7 + assert utils.n_workers() == expected + + with monkeypatch.context() as mnk: + mnk.delattr("psutil.Process.cpu_affinity", raising=False) + mnk.delattr("os.cpu_count") + if platform.system() not in ["Windows", "Darwin"]: + with pytest.raises(AttributeError): + utils.n_workers() + else: + assert utils.n_workers() == 0 + + +def test_split_version(): + """Test that splitting the version number works as expected.""" + version = utils.split_version("2.0.1") + assert version == ("2", "0", "1") + + version = utils.split_version("0.1.dev1+g39f8c53") + assert version == ("0", "1", "") + + version = utils.split_version("3.0.1.dev10282blah") + assert version == ("3", "0", "1") + + +def test_get_model_weights(monkeypatch): + """ + Test that model weights can be downloaded from GitHub or used from the + cache. + """ + # Model weights for fully matching version, minor matching version, major + # matching version. + for version in ["3.0.0", "3.0.999", "3.999.999"]: + with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: + mnk.setattr(casanovo, "__version__", version) + mnk.setattr( + "appdirs.user_cache_dir", lambda n, a, opinion: tmp_dir + ) + + filename = os.path.join(tmp_dir, "casanovo_massivekb_v3_0_0.ckpt") + assert not os.path.isfile(filename) + assert casanovo._get_model_weights() == filename + assert os.path.isfile(filename) + assert casanovo._get_model_weights() == filename + + # Impossible to find model weights for non-matching version. + with monkeypatch.context() as mnk: + mnk.setattr(casanovo, "__version__", "999.999.999") + with pytest.raises(ValueError): + casanovo._get_model_weights() + + # Test GitHub API rate limit. + def request(self, *args, **kwargs): + raise github.RateLimitExceededException( + 403, "API rate limit exceeded", None + ) + + with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: + mnk.setattr("appdirs.user_cache_dir", lambda n, a, opinion: tmp_dir) + mnk.setattr("github.Requester.Requester.requestJsonAndCheck", request) + with pytest.raises(github.RateLimitExceededException): + casanovo._get_model_weights() + + def test_tensorboard(): - """Check that the version is not None""" - model = Spec2Pep( - tb_summarywriter="test_path", - ) + model = Spec2Pep(tb_summarywriter="test_path") assert model.tb_summarywriter is not None model = Spec2Pep()