From 5710530b1761423837b4b8ef66f77a34a9f809cd Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 00:20:08 -0700 Subject: [PATCH 01/18] Overhaul runner --- .github/workflows/tests.yml | 2 +- .gitignore | 2 + casanovo/casanovo.py | 31 +- casanovo/config.py | 9 +- casanovo/config.yaml | 12 +- casanovo/denovo/__init__.py | 1 + casanovo/denovo/dataloaders.py | 26 +- casanovo/denovo/model.py | 32 +- casanovo/denovo/model_runner.py | 684 ++++++++++++++++---------------- pyproject.toml | 4 +- tests/conftest.py | 23 ++ tests/test_integration.py | 81 +++- tests/unit_tests/test_config.py | 4 +- 13 files changed, 498 insertions(+), 413 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1d7fe2f7..ea5a1eb8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,7 +14,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest] + os: [ubuntu-latest, windows-latest, macos-latest] steps: - uses: actions/checkout@v2 diff --git a/.gitignore b/.gitignore index 32202470..aa8178a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Test stuff: test_path/ +lightning_logs/ +envs/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index 72b02057..98bc986c 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -18,12 +18,12 @@ import torch import tqdm import yaml -from pytorch_lightning.lite import LightningLite +from lightning.pytorch import seed_everything from . import __version__ from . import utils from .data import ms_io -from .denovo import model_runner +from .denovo import ModelRunner from .config import Config logger = logging.getLogger("casanovo") @@ -52,11 +52,13 @@ required=True, help="The file path with peak files for predicting peptide sequences or " "training Casanovo.", + multiple=True, ) @click.option( "--peak_path_val", help="The file path with peak files to be used as validation data during " "training.", + multiple=True, ) @click.option( "--config", @@ -127,7 +129,7 @@ def main( # Read parameters from the config file. config = Config(config) - LightningLite.seed_everything(seed=config["random_seed"], workers=True) + seed_everything(seed=config["random_seed"], workers=True) # Download model weights if these were not specified (except when training). if model is None and mode != "train": @@ -159,18 +161,17 @@ def main( logger.debug("%s = %s", str(key), str(value)) # Run Casanovo in the specified mode. 
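A note on the two options above: `multiple=True` makes click collect repeated `--peak_path`/`--peak_path_val` flags into a tuple, which is why the runner introduced in this patch accepts iterables of paths everywhere. A minimal sketch of that behavior (a hypothetical standalone script, not part of the patch):

    import click

    @click.command()
    @click.option("--peak_path", multiple=True)
    def cli(peak_path):
        # peak_path is always a tuple, e.g. ("a.mgf", "b.mgf"),
        # even when the option is given just once.
        click.echo(f"{len(peak_path)} peak file pattern(s)")

    if __name__ == "__main__":
        cli()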
-    if mode == "denovo":
-        logger.info("Predict peptide sequences with Casanovo.")
-        writer = ms_io.MztabWriter(f"{output}.mztab")
-        writer.set_metadata(config, model=model, config_filename=config.file)
-        model_runner.predict(peak_path, model, config, writer)
-        writer.save()
-    elif mode == "eval":
-        logger.info("Evaluate a trained Casanovo model.")
-        model_runner.evaluate(peak_path, model, config)
-    elif mode == "train":
-        logger.info("Train the Casanovo model.")
-        model_runner.train(peak_path, peak_path_val, model, config)
+    with ModelRunner(config, model) as model_runner:
+        if mode == "denovo":
+            logger.info("Predict peptide sequences with Casanovo.")
+            model_runner.predict(peak_path, output)
+            model_runner.writer.save()
+        elif mode == "eval":
+            logger.info("Evaluate a trained Casanovo model.")
+            model_runner.evaluate(peak_path)
+        elif mode == "train":
+            logger.info("Train the Casanovo model.")
+            model_runner.train(peak_path, peak_path_val)


 def _get_model_weights() -> str:
diff --git a/casanovo/config.py b/casanovo/config.py
index 4dc93c26..bfff3685 100644
--- a/casanovo/config.py
+++ b/casanovo/config.py
@@ -67,7 +67,8 @@ class Config:
         model_save_folder_path=str,
         save_weights_only=bool,
         every_n_train_steps=int,
-        no_gpu=bool,
+        accelerator=str,
+        devices=int,
     )

     def __init__(self, config_file: Optional[str] = None):
@@ -86,13 +87,7 @@ def __init__(self, config_file: Optional[str] = None):
         for key, val in self._config_types.items():
             self.validate_param(key, val)

-        # Add extra configuration options and scale by the number of GPUs.
-        n_gpus = 0 if self["no_gpu"] else torch.cuda.device_count()
         self._params["n_workers"] = utils.n_workers()
-        if n_gpus > 1:
-            self._params["train_batch_size"] = (
-                self["train_batch_size"] // n_gpus
-            )

     def __getitem__(self, param: str) -> Union[int, bool, str, Tuple, Dict]:
         """Retrieve a parameter"""
diff --git a/casanovo/config.yaml b/casanovo/config.yaml
index 0e5c6b95..4f84d1e8 100644
--- a/casanovo/config.yaml
+++ b/casanovo/config.yaml
@@ -106,7 +106,7 @@ top_match: 1
 # Object for logging training progress
 logger:
 # Max number of training epochs
-max_epochs: 30
+max_epochs: 1000
 # Number of validation steps to run before training begins
 num_sanity_val_steps: 0
 # Set to "False" to further train a pre-trained Casanovo model
@@ -119,5 +119,11 @@ model_save_folder_path: ""
 save_weights_only: True
 # Model validation and checkpointing frequency in training steps
 every_n_train_steps: 50_000
-# Disable usage of a GPU (including Apple MPS):
-no_gpu: False
+# The hardware accelerator to use. Must be one of:
+# “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, or “auto”
+accelerator: "auto"
+# The devices to use. Can be set to a positive number int,
+# or the value -1 to indicate all available devices should be used,
+# If left empty, the appropriate number will be automatically
+# selected based on the chosen accelerator.
+devices: diff --git a/casanovo/denovo/__init__.py b/casanovo/denovo/__init__.py index e69de29b..da194f1b 100644 --- a/casanovo/denovo/__init__.py +++ b/casanovo/denovo/__init__.py @@ -0,0 +1 @@ +from .model_runner import ModelRunner diff --git a/casanovo/denovo/dataloaders.py b/casanovo/denovo/dataloaders.py index 2ee2f8f5..15d21b07 100644 --- a/casanovo/denovo/dataloaders.py +++ b/casanovo/denovo/dataloaders.py @@ -3,9 +3,9 @@ import os from typing import List, Optional, Tuple -import numpy as np -import pytorch_lightning as pl import torch +import numpy as np +import lightning.pytorch as pl from depthcharge.data import AnnotatedSpectrumIndex from ..data.datasets import AnnotatedSpectrumDataset, SpectrumDataset @@ -52,7 +52,8 @@ def __init__( train_index: Optional[AnnotatedSpectrumIndex] = None, valid_index: Optional[AnnotatedSpectrumIndex] = None, test_index: Optional[AnnotatedSpectrumIndex] = None, - batch_size: int = 128, + train_batch_size: int = 128, + eval_batch_size: int = 1028, n_peaks: Optional[int] = 150, min_mz: float = 50.0, max_mz: float = 2500.0, @@ -65,7 +66,8 @@ def __init__( self.train_index = train_index self.valid_index = valid_index self.test_index = test_index - self.batch_size = batch_size + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size self.n_peaks = n_peaks self.min_mz = min_mz self.max_mz = max_mz @@ -119,7 +121,9 @@ def setup(self, stage: str = None, annotated: bool = True) -> None: self.test_dataset = make_dataset(self.test_index) def _make_loader( - self, dataset: torch.utils.data.Dataset + self, + dataset: torch.utils.data.Dataset, + batch_size: int, ) -> torch.utils.data.DataLoader: """ Create a PyTorch DataLoader. @@ -128,6 +132,8 @@ def _make_loader( ---------- dataset : torch.utils.data.Dataset A PyTorch Dataset. + batch_size : int + The batch size to use. 
Returns ------- @@ -136,7 +142,7 @@ def _make_loader( """ return torch.utils.data.DataLoader( dataset, - batch_size=self.batch_size, + batch_size=batch_size, collate_fn=prepare_batch, pin_memory=True, num_workers=self.n_workers, @@ -144,19 +150,19 @@ def _make_loader( def train_dataloader(self) -> torch.utils.data.DataLoader: """Get the training DataLoader.""" - return self._make_loader(self.train_dataset) + return self._make_loader(self.train_dataset, self.train_batch_size) def val_dataloader(self) -> torch.utils.data.DataLoader: """Get the validation DataLoader.""" - return self._make_loader(self.valid_dataset) + return self._make_loader(self.valid_dataset, self.eval_batch_size) def test_dataloader(self) -> torch.utils.data.DataLoader: """Get the test DataLoader.""" - return self._make_loader(self.test_dataset) + return self._make_loader(self.test_dataset, self.eval_batch_size) def predict_dataloader(self) -> torch.utils.data.DataLoader: """Get the predict DataLoader.""" - return self._make_loader(self.test_dataset) + return self._make_loader(self.test_dataset, self.eval_batch_size) def prepare_batch( diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 8b3ae3e0..a387929b 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -7,9 +7,9 @@ import depthcharge.masses import einops -import numpy as np -import pytorch_lightning as pl import torch +import numpy as np +import lightning.pytorch as pl from torch.utils.tensorboard import SummaryWriter from depthcharge.components import ModelMixin, PeptideDecoder, SpectrumEncoder @@ -724,8 +724,8 @@ def training_step( pred = pred[:, :-1, :].reshape(-1, self.decoder.vocab_size + 1) loss = self.celoss(pred, truth.flatten()) self.log( - "CELoss", - {mode: loss.detach()}, + f"{mode}_CELoss", + loss.detach(), on_step=False, on_epoch=True, sync_dist=True, @@ -766,12 +766,10 @@ def validation_step( log_args = dict(on_step=False, on_epoch=True, sync_dist=True) self.log( "Peptide precision at coverage=1", - {"valid": pep_precision}, + pep_precision, **log_args, ) - self.log( - "AA precision at coverage=1", {"valid": aa_precision}, **log_args - ) + self.log("AA precision at coverage=1", aa_precision, **log_args) return loss @@ -824,7 +822,7 @@ def on_train_epoch_end(self) -> None: """ Log the training loss at the end of each epoch. 
""" - train_loss = self.trainer.callback_metrics["CELoss"]["train"].detach() + train_loss = self.trainer.callback_metrics["train_CELoss"].detach() metrics = { "step": self.trainer.global_step, "train": train_loss, @@ -839,19 +837,21 @@ def on_validation_epoch_end(self) -> None: callback_metrics = self.trainer.callback_metrics metrics = { "step": self.trainer.global_step, - "valid": callback_metrics["CELoss"]["valid"].detach(), + "valid": callback_metrics["valid_CELoss"].detach(), "valid_aa_precision": callback_metrics[ "AA precision at coverage=1" - ]["valid"].detach(), + ].detach(), "valid_pep_precision": callback_metrics[ "Peptide precision at coverage=1" - ]["valid"].detach(), + ].detach(), } self._history.append(metrics) self._log_history() - def on_predict_epoch_end( - self, results: List[List[Tuple[np.ndarray, List[str], torch.Tensor]]] + def on_predict_batch_end( + self, + outputs: List[Tuple[np.ndarray, List[str], torch.Tensor]], + *args, ) -> None: """ Write the predicted peptide sequences and amino acid scores to the @@ -867,9 +867,7 @@ def on_predict_epoch_end( peptide, peptide_score, aa_scores, - ) in itertools.chain.from_iterable( - itertools.chain.from_iterable(results) - ): + ) in outputs: if len(peptide) == 0: continue self.out_writer.psms.append( diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index fb5deeba..09bc6ca4 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -6,15 +6,17 @@ import os import tempfile import uuid +from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union import numpy as np -import pytorch_lightning as pl import torch +import lightning.pytorch as pl +from lightning.pytorch.strategies import DDPStrategy from depthcharge.data import AnnotatedSpectrumIndex, SpectrumIndex -from pytorch_lightning.strategies import DDPStrategy from .. import utils +from ..config import Config from ..data import ms_io from ..denovo.dataloaders import DeNovoDataModule from ..denovo.model import Spec2Pep @@ -23,307 +25,350 @@ logger = logging.getLogger("casanovo") -def predict( - peak_path: str, - model_filename: str, - config: Dict[str, Any], - out_writer: ms_io.MztabWriter, -) -> None: - """ - Predict peptide sequences with a trained Casanovo model. +class ModelRunner: + """A class to run Casanovo models. Parameters ---------- - peak_path : str - The path with peak files for predicting peptide sequences. - model_filename : str - The file name of the model weights (.ckpt file). - config : Dict[str, Any] - The configuration options. - out_writer : ms_io.MztabWriter - The mzTab writer to export the prediction results. - """ - _execute_existing(peak_path, model_filename, config, False, out_writer) - - -def evaluate( - peak_path: str, model_filename: str, config: Dict[str, Any] -) -> None: + config : Config object + The casanovo configuration. + model_filename : str, optional + The model filename is required for eval and de novo modes, + but not for training a model from scratch. """ - Evaluate peptide sequence predictions from a trained Casanovo model. - - Parameters - ---------- - peak_path : str - The path with peak files for predicting peptide sequences. - model_filename : str - The file name of the model weights (.ckpt file). - config : Dict[str, Any] - The configuration options. 
- """ - _execute_existing(peak_path, model_filename, config, True) + def __init__( + self, + config: Config, + model_filename: Optional[str] = None, + ) -> None: + """Initialize a ModelRunner""" + self.config = config + self.model_filename = model_filename + + # Initialized later: + self.tmp_dir = None + self.trainer = None + self.model = None + self.loaders = None + self.writer = None + + # Configure checkpoints. + if config.save_model: + self.callbacks = [ + pl.callbacks.ModelCheckpoint( + dirpath=config.model_save_folder_path, + save_top_k=-1, + save_weights_only=config.save_weights_only, + every_n_train_steps=config.every_n_train_steps, + ) + ] + else: + self.callbacks = None + + def __enter__(self): + """Enter the context manager""" + self.tmp_dir = tempfile.TemporaryDirectory() + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Cleanup on exit""" + self.tmp_dir.cleanup() + self.tmp_dir = None + + def train( + self, + train_peak_path: Iterable[str], + valid_peak_path: Iterable[str], + ) -> None: + """Train the Casanovo model. + + Parameters + ---------- + train_peak_path : iterable of str + The path to the MS data files for training. + valid_peak_path : iterable of str + The path to the MS data files for validation. + + Returns + ------- + self + """ + self.initialize_trainer(train=True) + self.initialize_model(train=True) + + train_index = self._get_index(train_peak_path, True, "training") + valid_index = self._get_index(valid_peak_path, True, "validation") + self.initialize_data_module(train_index, valid_index) + self.loaders.setup() + + self.trainer.fit( + self.model, + self.loaders.train_dataloader(), + self.loaders.val_dataloader(), + ) -def _execute_existing( - peak_path: str, - model_filename: str, - config: Dict[str, Any], - annotated: bool, - out_writer: Optional[ms_io.MztabWriter] = None, -) -> None: - """ - Predict peptide sequences with a trained Casanovo model with/without - evaluation. + def evaluate(self, peak_path: Iterable[str]) -> None: + """Evaluate peptide sequence preditions from a trained Casanovo model. + + Parameters + ---------- + peak_path : iterable of str + The path with MS data files for predicting peptide sequences. + + Returns + ------- + self + """ + self.initialize_trainer(train=False) + self.initialize_model(train=False) + + test_index = self._get_index(peak_path, True, "evaluation") + self.initialize_data_module(test_index=test_index) + self.loaders.setup(stage="test", annotated=True) + + self.trainer.validate(self.model, self.loaders.test_dataloader()) + + def predict(self, peak_path: Iterable[str], output: str) -> None: + """Predict peptide sequences with a trained Casanovo model. + + Parameters + ---------- + peak_path : iterable of str + The path with the MS data files for predicting peptide sequences. + output : str + Where should the output be saved? + + Returns + ------- + self + """ + self.writer = ms_io.MztabWriter(f"{output}.mztab") + self.writer.set_metadata( + self.config, + model=str(self.model_filename), + config_filename=self.config.file, + ) - Parameters - ---------- - peak_path : str - The path with peak files for predicting peptide sequences. - model_filename : str - The file name of the model weights (.ckpt file). - config : Dict[str, Any] - The configuration options. - annotated : bool - Whether the input peak files are annotated (execute in evaluation mode) - or not (execute in prediction mode only). - out_writer : Optional[ms_io.MztabWriter] - The mzTab writer to export the prediction results. 
- """ - # Load the trained model. - if not os.path.isfile(model_filename): - logger.error( - "Could not find the trained model weights at file %s", - model_filename, + self.initialize_trainer(train=False) + self.initialize_model(train=False) + + test_index = self._get_index(peak_path, False, "") + self.writer.set_ms_run(test_index.ms_files) + self.initialize_data_module(test_index=test_index) + self.loaders.setup(stage="test", annotated=False) + + self.trainer.predict(self.model, self.loaders.test_dataloader()) + + def initialize_trainer(self, train: bool) -> None: + """Initialize the lightning Trainer. + + Parameters + ---------- + train : bool + Determines whether to set the trainer up for model training + or evaluation / inference. + """ + trainer_cfg = dict( + accelerator=self.config.accelerator, + devices=1, + logger=self.config.logger, ) - raise FileNotFoundError("Could not find the trained model weights") - model = Spec2Pep().load_from_checkpoint( - model_filename, - dim_model=config["dim_model"], - n_head=config["n_head"], - dim_feedforward=config["dim_feedforward"], - n_layers=config["n_layers"], - dropout=config["dropout"], - dim_intensity=config["dim_intensity"], - custom_encoder=config["custom_encoder"], - max_length=config["max_length"], - residues=config["residues"], - max_charge=config["max_charge"], - precursor_mass_tol=config["precursor_mass_tol"], - isotope_error_range=config["isotope_error_range"], - min_peptide_len=config["min_peptide_len"], - n_beams=config["n_beams"], - top_match=config["top_match"], - n_log=config["n_log"], - out_writer=out_writer, - ) - # Read the MS/MS spectra for which to predict peptide sequences. - if annotated: - peak_ext = (".mgf", ".h5", ".hdf5") - else: - peak_ext = (".mgf", ".mzml", ".mzxml", ".h5", ".hdf5") - if len(peak_filenames := _get_peak_filenames(peak_path, peak_ext)) == 0: - logger.error("Could not find peak files from %s", peak_path) - raise FileNotFoundError("Could not find peak files") - else: - out_writer.set_ms_run(peak_filenames) - peak_is_index = any( - [os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in peak_filenames] - ) - if peak_is_index and len(peak_filenames) > 1: - logger.error("Multiple HDF5 spectrum indexes specified") - raise ValueError("Multiple HDF5 spectrum indexes specified") - tmp_dir = tempfile.TemporaryDirectory() - if peak_is_index: - idx_filename, peak_filenames = peak_filenames[0], None - else: - idx_filename = os.path.join(tmp_dir.name, f"{uuid.uuid4().hex}.hdf5") - SpectrumIdx = AnnotatedSpectrumIndex if annotated else SpectrumIndex - valid_charge = np.arange(1, config["max_charge"] + 1) - index = SpectrumIdx( - idx_filename, peak_filenames, valid_charge=valid_charge - ) - # Initialize the data loader. - loaders = DeNovoDataModule( - test_index=index, - n_peaks=config["n_peaks"], - min_mz=config["min_mz"], - max_mz=config["max_mz"], - min_intensity=config["min_intensity"], - remove_precursor_tol=config["remove_precursor_tol"], - n_workers=config["n_workers"], - batch_size=config["predict_batch_size"], - ) - loaders.setup(stage="test", annotated=annotated) - - # Create the Trainer object. - trainer = pl.Trainer( - accelerator="auto", - auto_select_gpus=True, - devices=_get_devices(config["no_gpu"]), - logger=config["logger"], - max_epochs=config["max_epochs"], - num_sanity_val_steps=config["num_sanity_val_steps"], - strategy=_get_strategy(), - ) - # Run the model with/without validation. 
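The new `initialize_trainer` above builds the Trainer from plain configuration values instead of the auto-detection helpers that the removed code relied on. For inference, the assembled object boils down to roughly the following (a sketch; the `logger` value and the training-only keys come from the config defaults shown earlier):

    import lightning.pytorch as pl

    # Inference: a single device, no checkpointing.
    trainer = pl.Trainer(accelerator="auto", devices=1, logger=None)

    # Training layers checkpointing and validation cadence on top:
    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        logger=None,
        max_epochs=30,
        num_sanity_val_steps=0,
        val_check_interval=50_000,
    )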
- run_trainer = trainer.validate if annotated else trainer.predict - run_trainer(model, loaders.test_dataloader()) - # Clean up temporary files. - tmp_dir.cleanup() - - -def train( - peak_path: str, - peak_path_val: str, - model_filename: str, - config: Dict[str, Any], -) -> None: - """ - Train a Casanovo model. - The model can be trained from scratch or by continuing training an existing - model. + if self.train: + if self.config.devices is None: + devices = "auto" + else: + devices = self.config.devices + + additional_cfg = dict( + devices=devices, + callbacks=self.callbacks, + enable_checkpointing=self.config.save_model, + max_epochs=self.config.max_epochs, + num_sanity_val_steps=self.config.num_sanity_val_steps, + strategy=self._get_strategy(), + val_check_interval=self.config.every_n_train_steps, + ) + trainer_cfg.update(additional_cfg) + + self.trainer = pl.Trainer(**trainer_cfg) + + def initialize_model(self, train: bool) -> None: + """Initialize the Casanovo model. + + Parameters + ---------- + train : bool + Determines whether to set the model up for model training + or evaluation / inference. + """ + model_params = dict( + dim_model=self.config.dim_model, + n_head=self.config.n_head, + dim_feedforward=self.config.dim_feedforward, + n_layers=self.config.n_layers, + dropout=self.config.dropout, + dim_intensity=self.config.dim_intensity, + custom_encoder=self.config.custom_encoder, + max_length=self.config.max_length, + residues=self.config.residues, + max_charge=self.config.max_charge, + precursor_mass_tol=self.config.precursor_mass_tol, + isotope_error_range=self.config.isotope_error_range, + n_beams=self.config.n_beams, + top_match=self.config.top_match, + n_log=self.config.n_log, + tb_summarywriter=self.config.tb_summarywriter, + warmup_iters=self.config.warmup_iters, + max_iters=self.config.max_iters, + lr=self.config.learning_rate, + weight_decay=self.config.weight_decay, + out_writer=self.writer, + ) - Parameters - ---------- - peak_path : str - The path with peak files to be used as training data. - peak_path_val : str - The path with peak files to be used as validation data. - model_filename : str - The file name of the model weights (.ckpt file). - config : Dict[str, Any] - The configuration options. - """ - # Read the MS/MS spectra to use for training and validation. 
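Together, the `train`, `evaluate`, and `predict` methods defined earlier replace the three module-level functions, and the initializers wire up the Trainer and model on demand. Typical use from Python looks like this (the checkpoint and MGF names are placeholders):

    from casanovo.config import Config
    from casanovo.denovo import ModelRunner

    config = Config()  # the default configuration
    with ModelRunner(config, model_filename="casanovo.ckpt") as runner:
        runner.predict(["spectra.mgf"], output="results")
        runner.writer.save()
    # Leaving the context always removes the temporary HDF5 spectrum index.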
- ext = (".mgf", ".h5", ".hdf5") - if len(train_filenames := _get_peak_filenames(peak_path, ext)) == 0: - logger.error("Could not find training peak files from %s", peak_path) - raise FileNotFoundError("Could not find training peak files") - train_is_index = any( - [os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in train_filenames] - ) - if train_is_index and len(train_filenames) > 1: - logger.error("Multiple training HDF5 spectrum indexes specified") - raise ValueError("Multiple training HDF5 spectrum indexes specified") - if ( - peak_path_val is None - or len(val_filenames := _get_peak_filenames(peak_path_val, ext)) == 0 - ): - logger.error( - "Could not find validation peak files from %s", peak_path_val + from_scratch = ( + self.config.train_from_scratch, + self.model_filename is None, ) - raise FileNotFoundError("Could not find validation peak files") - val_is_index = any( - [os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in val_filenames] - ) - if val_is_index and len(val_filenames) > 1: - logger.error("Multiple validation HDF5 spectrum indexes specified") - raise ValueError("Multiple validation HDF5 spectrum indexes specified") - tmp_dir = tempfile.TemporaryDirectory() - if train_is_index: - train_idx_fn, train_filenames = train_filenames[0], None - else: - train_idx_fn = os.path.join(tmp_dir.name, f"{uuid.uuid4().hex}.hdf5") - valid_charge = np.arange(1, config["max_charge"] + 1) - train_index = AnnotatedSpectrumIndex( - train_idx_fn, train_filenames, valid_charge=valid_charge - ) - if val_is_index: - val_idx_fn, val_filenames = val_filenames[0], None - else: - val_idx_fn = os.path.join(tmp_dir.name, f"{uuid.uuid4().hex}.hdf5") - val_index = AnnotatedSpectrumIndex( - val_idx_fn, val_filenames, valid_charge=valid_charge - ) - # Initialize the data loaders. - dataloader_params = dict( - batch_size=config["train_batch_size"], - n_peaks=config["n_peaks"], - min_mz=config["min_mz"], - max_mz=config["max_mz"], - min_intensity=config["min_intensity"], - remove_precursor_tol=config["remove_precursor_tol"], - n_workers=config["n_workers"], - ) - train_loader = DeNovoDataModule( - train_index=train_index, **dataloader_params - ) - train_loader.setup() - val_loader = DeNovoDataModule(valid_index=val_index, **dataloader_params) - val_loader.setup() - # Initialize the model. 
- model_params = dict( - dim_model=config["dim_model"], - n_head=config["n_head"], - dim_feedforward=config["dim_feedforward"], - n_layers=config["n_layers"], - dropout=config["dropout"], - dim_intensity=config["dim_intensity"], - custom_encoder=config["custom_encoder"], - max_length=config["max_length"], - residues=config["residues"], - max_charge=config["max_charge"], - precursor_mass_tol=config["precursor_mass_tol"], - isotope_error_range=config["isotope_error_range"], - n_beams=config["n_beams"], - top_match=config["top_match"], - n_log=config["n_log"], - tb_summarywriter=config["tb_summarywriter"], - warmup_iters=config["warmup_iters"], - max_iters=config["max_iters"], - lr=config["learning_rate"], - weight_decay=config["weight_decay"], - ) - if config["train_from_scratch"]: - model = Spec2Pep(**model_params) - else: - if not os.path.isfile(model_filename): + if train and any(from_scratch): + self.model = Spec2Pep(**model_params) + return + elif self.model_filename is None: + logger.error("A model file must be proided") + raise ValueError("A model file must be provided") + + if not self.model_filename.exists(): logger.error( - "Could not find the model weights at file %s to continue " - "training", + "Could not find the model weights at file %s", model_filename, ) - raise FileNotFoundError( - "Could not find the model weights to continue training" - ) - model = Spec2Pep().load_from_checkpoint(model_filename, **model_params) - # Create the Trainer object and (optionally) a checkpoint callback to - # periodically save the model. - if config["save_model"]: - callbacks = [ - pl.callbacks.ModelCheckpoint( - dirpath=config["model_save_folder_path"], - save_top_k=-1, - save_weights_only=config["save_weights_only"], - every_n_train_steps=config["every_n_train_steps"], - ) - ] - else: - callbacks = None - - trainer = pl.Trainer( - accelerator="auto", - auto_select_gpus=True, - callbacks=callbacks, - devices=_get_devices(config["no_gpu"]), - enable_checkpointing=config["save_model"], - logger=config["logger"], - max_epochs=config["max_epochs"], - num_sanity_val_steps=config["num_sanity_val_steps"], - strategy=_get_strategy(), - val_check_interval=config["every_n_train_steps"], - ) - # Train the model. - trainer.fit( - model, train_loader.train_dataloader(), val_loader.val_dataloader() - ) - # Clean up temporary files. - tmp_dir.cleanup() + raise FileNotFoundError("Could not find the model weights file") + + self.model = Spec2Pep().load_from_checkpoint( + self.model_filename, + **model_params, + ) + + def initialize_data_module( + self, + train_index: Optional[AnnotatedSpectrumIndex] = None, + valid_index: Optional[AnnotatedSpectrumIndex] = None, + test_index: ( + Optional[Union[AnnotatedSpectrumIndex, SpectrumIndex]] + ) = None, + ) -> None: + """Initialize the data module + + Parameters + ---------- + train_index : AnnotatedSpectrumIndex, optional + A spectrum index for model training. + valid_index : AnnotatedSpectrumIndex, optional + A spectrum index for validation. + test_index : AnnotatedSpectrumIndex or SpectrumIndex, optional + A spectrum index for evaluation or inference. 
+ """ + try: + n_devices = self.trainer.num_devices + train_bs = self.config.train_batch_size // n_devices + eval_bs = self.config.predict_batch_size // n_devices + except AttributeError as err: + raise RuntimeError("Please use `initialize_trainer()` first.") + + self.loaders = DeNovoDataModule( + train_index=train_index, + valid_index=valid_index, + test_index=test_index, + min_mz=self.config.min_mz, + max_mz=self.config.max_mz, + min_intensity=self.config.min_intensity, + remove_precursor_tol=self.config.remove_precursor_tol, + n_workers=self.config.n_workers, + train_batch_size=train_bs, + eval_batch_size=eval_bs, + ) + + def _get_index( + self, + peak_path: str, + annotated: bool, + msg: str = "", + ) -> Union[SpectrumIndex, AnnotatedSpectrumIndex]: + """Get the spectrum index. + + If the file is a SpectrumIndex, only one is allowed. Otherwise multiple + may be specified. + + Parameters + ---------- + peak_path : str + The peak file/directory to check. + annotated : bool + Are the spectra expected to be annotated? + msg : str, optional + A string to insert into the error message. + + Returns + ------- + SpectrumIndex or AnnotatedSpectrumIndex + The spectrum index for training, evaluation, or inference. + """ + ext = (".mgf", ".h5", ".hdf5") + if not annotated: + ext += (".mzml", ".mzxml") + + if msg and msg[-1] != " ": + msg += " " + + filenames = _get_peak_filenames(peak_path, ext) + if not filenames: + not_found_err = f"Cound not find {msg}peak files" + logger.error(not_found_err + " from %s", peak_path) + raise FileNotFoundError(not_found_err) + + is_index = any([Path(f).suffix in (".h5", ".hdf5") for f in filenames]) + if is_index: + if len(filenames) > 1: + h5_err = f"Multiple {msg}HDF5 spectrum indexes specified" + logger.error(h5_err) + raise ValueError(h5_err) + + index_fname, filenames = filenames[0], None + else: + index_fname = Path(self.tmp_dir.name) / f"{uuid.uuid4().hex}.hdf5" + + Index = AnnotatedSpectrumIndex if annotated else SpectrumIndex + valid_charge = np.arange(1, self.config.max_charge + 1) + return Index(index_fname, filenames, valid_charge=valid_charge) + + def _get_strategy(self) -> Optional[DDPStrategy]: + """Get the strategy for the Trainer. + + The DDP strategy works best when multiple GPUs are used. It can work + for CPU-only, but definitely fails using MPS (the Apple Silicon chip) + due to Gloo. + + Returns + ------- + Optional[DDPStrategy] + The strategy parameter for the Trainer. + + """ + if self.config.accelerator in ("cpu", "mps"): + return "auto" + + if self.config.devices == 1: + return "auto" + + if torch.cuda.device_count() > 1: + return DDPStrategy(find_unused_parameters=False, static_graph=True) + + return "auto" def _get_peak_filenames( - path: str, supported_ext: Iterable[str] = (".mgf",) + paths: Iterable[str], supported_ext: Iterable[str] ) -> List[str]: """ Get all matching peak file names from the path pattern. @@ -333,65 +378,22 @@ def _get_peak_filenames( Parameters ---------- - path : str - The path pattern. + paths : Iterable[str] + The path pattern(s). supported_ext : Iterable[str] - Extensions of supported peak file formats. Default: MGF. + Extensions of supported peak file formats. Returns ------- List[str] The peak file names matching the path pattern. 
""" - path = os.path.expanduser(path) - path = os.path.expandvars(path) - return [ - os.path.abspath(fn) - for fn in glob.glob(path, recursive=True) - if os.path.splitext(fn.lower())[1] in supported_ext - ] - - -def _get_strategy() -> Optional[DDPStrategy]: - """ - Get the strategy for the Trainer. - - The DDP strategy works best when multiple GPUs are used. It can work for - CPU-only, but definitely fails using MPS (the Apple Silicon chip) due to - Gloo. - - Returns - ------- - Optional[DDPStrategy] - The strategy parameter for the Trainer. - """ - if torch.cuda.device_count() > 1: - return DDPStrategy(find_unused_parameters=False, static_graph=True) - - return None - - -def _get_devices(no_gpu: bool) -> Union[int, str]: - """ - Get the number of GPUs/CPUs for the Trainer to use. - - Parameters - ---------- - no_gpu : bool - If true, disable all GPU usage. - - Returns - ------- - Union[int, str] - The number of GPUs/CPUs to use, or "auto" to let PyTorch Lightning - determine the appropriate number of devices. - """ - if not no_gpu and any( - operator.attrgetter(device + ".is_available")(torch)() - for device in ("cuda",) - ): - return -1 - elif not (n_workers := utils.n_workers()): - return "auto" - else: - return n_workers + found_files = set() + for path in paths: + path = os.path.expanduser(path) + path = os.path.expandvars(path) + for fname in glob.glob(path, recursive=True): + if Path(fname).suffix.lower() in supported_ext: + found_files.add(fname) + + return sorted(list(found_files)) diff --git a/pyproject.toml b/pyproject.toml index 1b14d1e2..5efb2953 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,13 +27,13 @@ dependencies = [ "pandas", "psutil", "PyGithub", - "pytorch-lightning>=1.7,<2.0", + "lightning>=2.0", "PyYAML", "requests", "scikit-learn", "spectrum_utils", "tensorboard", - "torch>=1.9", + "torch>=2.0", "tqdm", ] dynamic = ["version"] diff --git a/tests/conftest.py b/tests/conftest.py index 574aed8f..d987cf13 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import numpy as np import psims import pytest +import yaml from pyteomics.mass import calculate_mass @@ -180,3 +181,25 @@ def _create_mzml(peptides, mzml_file, random_state=42): ) return mzml_file + + +@pytest.fixture +def tiny_config(tmp_path): + """A config file for a tiny model.""" + cfg = { + "n_head": 2, + "dim_feedfoward": 10, + "n_layers": 1, + "warmup_iters": 1, + "max_iters": 10, + "max_epochs": 10, + "every_n_train_steps": 1, + "model_save_folder_path": str(tmp_path), + "accelerator": "cpu", + } + + cfg_file = tmp_path / "config.yml" + with cfg_file.open("w+") as out_file: + yaml.dump(cfg, out_file) + + return cfg_file diff --git a/tests/test_integration.py b/tests/test_integration.py index 7d63f5c1..01c1b120 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,26 +1,76 @@ +import functools import pyteomics.mztab +from click.testing import CliRunner from casanovo import casanovo -def test_denovo(mgf_small, mzml_small, tmp_path, monkeypatch): +def test_train_and_run( + mgf_small, mzml_small, tiny_config, tmp_path, monkeypatch +): # We can use this to explicitly test different versions. monkeypatch.setattr(casanovo, "__version__", "3.0.1") - # Predict on small files (MGF and mzML) and verify that the output mzTab - # file exists. 
- output_filename = tmp_path / "test.mztab" - casanovo.main( - [ - "--mode", - "denovo", - "--peak_path", - str(mgf_small).replace(".mgf", ".m*"), - "--output", - str(output_filename), - ], - standalone_mode=False, + # Run a command: + run = functools.partial( + CliRunner().invoke, casanovo.main, catch_exceptions=False ) + + # Train a tiny model: + train_args = [ + "--mode", + "train", + "--peak_path", + mgf_small, + "--peak_path_val", + mgf_small, + "--config", + tiny_config, + "--output", + str(tmp_path / "train"), + ] + + result = run(train_args) + model_file = tmp_path / "epoch=9-step=10.ckpt" + assert result.exit_code == 0 + assert model_file.exists() + + # Try evaluating: + eval_args = [ + "--mode", + "eval", + "--peak_path", + mgf_small, + "--model", + model_file, + "--config", + tiny_config, + "--output", + str(tmp_path / "eval"), + ] + + result = run(eval_args) + assert result.exit_code == 0 + + # Finally try predicting: + output_filename = tmp_path / "test.mztab" + predict_args = [ + "--mode", + "denovo", + "--peak_path", + mgf_small, + "--peak_path", + mzml_small, + "--model", + model_file, + "--config", + tiny_config, + "--output", + str(output_filename), + ] + + result = run(predict_args) + assert result.exit_code == 0 assert output_filename.is_file() mztab = pyteomics.mztab.MzTab(str(output_filename)) @@ -29,7 +79,8 @@ def test_denovo(mgf_small, mzml_small, tmp_path, monkeypatch): assert f"ms_run[{i}]-location" in mztab.metadata assert mztab.metadata[f"ms_run[{i}]-location"].endswith(filename) - # Verify that the spectrum predictions are correct and indexed according to + # Verify that the spectrum predictions are correct + # and indexed according to # the peak input file type. psms = mztab.spectrum_match_table assert psms.loc[1, "sequence"] == "LESLLEK" diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 8282e367..8da26f8c 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -9,7 +9,7 @@ def test_default(): config = Config() assert config.random_seed == 454 assert config["random_seed"] == 454 - assert not config.no_gpu + assert config.accelerator == "auto" assert config.file == "default" @@ -22,6 +22,6 @@ def test_override(tmp_path): config = Config(yml) assert config.random_seed == 42 assert config["random_seed"] == 42 - assert not config.no_gpu + assert config.accelerator == "auto" assert config.top_match == 3 assert config.file == str(yml) From c5a891b348ed672063d5f83549f0116791c450f2 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 00:31:36 -0700 Subject: [PATCH 02/18] Update linting to only happen once --- .github/workflows/{black.yml => lint.yml} | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) rename .github/workflows/{black.yml => lint.yml} (87%) diff --git a/.github/workflows/black.yml b/.github/workflows/lint.yml similarity index 87% rename from .github/workflows/black.yml rename to .github/workflows/lint.yml index cec52e2b..ce576f53 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/lint.yml @@ -1,6 +1,10 @@ name: Lint -on: [push, pull_request] +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] jobs: lint: From c3841392bf41f21efd89818e32664035012b74ae Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 00:37:04 -0700 Subject: [PATCH 03/18] Fix linting error --- casanovo/denovo/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/casanovo/denovo/model_runner.py 
b/casanovo/denovo/model_runner.py index 09bc6ca4..e002335d 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -241,7 +241,7 @@ def initialize_model(self, train: bool) -> None: if not self.model_filename.exists(): logger.error( "Could not find the model weights at file %s", - model_filename, + self.model_filename, ) raise FileNotFoundError("Could not find the model weights file") From 2cd4cd6c72152c12748b8a4eb312442c9349885a Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 00:51:29 -0700 Subject: [PATCH 04/18] Specify utf-8 encoding --- casanovo/config.py | 2 +- tests/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/casanovo/config.py b/casanovo/config.py index bfff3685..58fdd235 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -74,7 +74,7 @@ class Config: def __init__(self, config_file: Optional[str] = None): """Initialize a Config object.""" self.file = str(config_file) if config_file is not None else "default" - with self._default_config.open() as f_in: + with self._default_config.open(encoding="utf-8") as f_in: self._params = yaml.safe_load(f_in) if config_file is None: diff --git a/tests/conftest.py b/tests/conftest.py index d987cf13..e1a15023 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,7 +199,7 @@ def tiny_config(tmp_path): } cfg_file = tmp_path / "config.yml" - with cfg_file.open("w+") as out_file: + with cfg_file.open("w+", encoding="utf-8") as out_file: yaml.dump(cfg, out_file) return cfg_file From 1003c1679f703490fada98dc848e76c477251649 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 00:52:25 -0700 Subject: [PATCH 05/18] Specify utf-8 encoding only for default config --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index e1a15023..d987cf13 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,7 +199,7 @@ def tiny_config(tmp_path): } cfg_file = tmp_path / "config.yml" - with cfg_file.open("w+", encoding="utf-8") as out_file: + with cfg_file.open("w+") as out_file: yaml.dump(cfg, out_file) return cfg_file From 7b0c323a397f9515d693d29572a342a8d9f19bc1 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 00:59:56 -0700 Subject: [PATCH 06/18] Skip weights tests for now --- tests/unit_tests/test_unit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index bc0509bd..3ad09fca 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -69,6 +69,7 @@ def test_split_version(): assert version == ("3", "0", "1") +@pytest.skip(msg="Hit rate limit during CI/CD") def test_get_model_weights(monkeypatch): """ Test that model weights can be downloaded from GitHub or used from the From 4ab43e140708a1acb1bfb3af4b7003e806a7fb47 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 20:29:22 -0700 Subject: [PATCH 07/18] Update skipping API test --- tests/unit_tests/test_unit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index 3ad09fca..fb906be6 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -69,7 +69,7 @@ def test_split_version(): assert version == ("3", "0", "1") -@pytest.skip(msg="Hit rate limit during CI/CD") +@pytest.mark.skip(msg="Hit rate limit during CI/CD") def test_get_model_weights(monkeypatch): """ Test that model weights can be 
downloaded from GitHub or used from the From 2858a909a21d2fd4300680d8e23cf5baf0ae59ad Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 20:34:59 -0700 Subject: [PATCH 08/18] Revert accidental max_epochs change --- casanovo/config.yaml | 2 +- tests/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/casanovo/config.yaml b/casanovo/config.yaml index 4f84d1e8..565c480b 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -106,7 +106,7 @@ top_match: 1 # Object for logging training progress logger: # Max number of training epochs -max_epochs: 1000 +max_epochs: 30 # Number of validation steps to run before training begins num_sanity_val_steps: 0 # Set to "False" to further train a pre-trained Casanovo model diff --git a/tests/conftest.py b/tests/conftest.py index d987cf13..c0d09703 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -192,7 +192,7 @@ def tiny_config(tmp_path): "n_layers": 1, "warmup_iters": 1, "max_iters": 10, - "max_epochs": 10, + "max_epochs": 30, "every_n_train_steps": 1, "model_save_folder_path": str(tmp_path), "accelerator": "cpu", From 593b2f67c9950fa2f69859e93cea371353509887 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Sat, 22 Apr 2023 20:40:15 -0700 Subject: [PATCH 09/18] msg -> reason for pytest.mark.skip --- tests/unit_tests/test_unit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index fb906be6..8fd0689b 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -69,7 +69,7 @@ def test_split_version(): assert version == ("3", "0", "1") -@pytest.mark.skip(msg="Hit rate limit during CI/CD") +@pytest.mark.skip(reason="Hit rate limit during CI/CD") def test_get_model_weights(monkeypatch): """ Test that model weights can be downloaded from GitHub or used from the From 1c117066a2276d9880d702892eb88887e264a6d0 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Fri, 28 Apr 2023 09:26:49 -0700 Subject: [PATCH 10/18] Wout's suggestions and more tests --- casanovo/casanovo.py | 1 - casanovo/config.yaml | 2 +- casanovo/denovo/dataloaders.py | 10 +++++--- casanovo/denovo/model_runner.py | 20 +++++++-------- tests/test_integration.py | 3 ++- tests/unit_tests/test_runner.py | 43 +++++++++++++++++++++++++++++++++ 6 files changed, 62 insertions(+), 17 deletions(-) create mode 100644 tests/unit_tests/test_runner.py diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index 98bc986c..f07d7b43 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -165,7 +165,6 @@ def main( if mode == "denovo": logger.info("Predict peptide sequences with Casanovo.") model_runner.predict(peak_path, output) - model_runner.writer.save() elif mode == "eval": logger.info("Evaluate a trained Casanovo model.") model_runner.evaluate(peak_path) diff --git a/casanovo/config.yaml b/casanovo/config.yaml index 565c480b..44eaa8f6 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -120,7 +120,7 @@ save_weights_only: True # Model validation and checkpointing frequency in training steps every_n_train_steps: 50_000 # The hardware accelerator to use. Must be one of: -# “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, or “auto” +# "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto" accelerator: "auto" # The devices to use. 
Can be set to a positive number int, # or the value -1 to indicate all available devices should be used, diff --git a/casanovo/denovo/dataloaders.py b/casanovo/denovo/dataloaders.py index 15d21b07..7ab78355 100644 --- a/casanovo/denovo/dataloaders.py +++ b/casanovo/denovo/dataloaders.py @@ -3,9 +3,9 @@ import os from typing import List, Optional, Tuple -import torch -import numpy as np import lightning.pytorch as pl +import numpy as np +import torch from depthcharge.data import AnnotatedSpectrumIndex from ..data.datasets import AnnotatedSpectrumDataset, SpectrumDataset @@ -23,8 +23,10 @@ class DeNovoDataModule(pl.LightningDataModule): The spectrum index file corresponding to the validation data. test_index : Optional[AnnotatedSpectrumIndex] The spectrum index file corresponding to the testing data. - batch_size : int - The batch size to use for training and evaluating. + train_batch_size : int + The batch size to use for training. + eval_batch_size : int + The batch size to use for inference. n_peaks : Optional[int] The number of top-n most intense peaks to keep in each spectrum. `None` retains all peaks. diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index e002335d..dc866e48 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -9,11 +9,11 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union +import lightning.pytorch as pl import numpy as np import torch -import lightning.pytorch as pl -from lightning.pytorch.strategies import DDPStrategy from depthcharge.data import AnnotatedSpectrumIndex, SpectrumIndex +from lightning.pytorch.strategies import DDPStrategy from .. import utils from ..config import Config @@ -75,6 +75,8 @@ def __exit__(self, exc_type, exc_value, traceback): """Cleanup on exit""" self.tmp_dir.cleanup() self.tmp_dir = None + if self.writer is not None: + self.writer.save() def train( self, @@ -175,7 +177,7 @@ def initialize_trainer(self, train: bool) -> None: logger=self.config.logger, ) - if self.train: + if train: if self.config.devices is None: devices = "auto" else: @@ -235,10 +237,10 @@ def initialize_model(self, train: bool) -> None: self.model = Spec2Pep(**model_params) return elif self.model_filename is None: - logger.error("A model file must be proided") + logger.error("A model file must be provided") raise ValueError("A model file must be provided") - if not self.model_filename.exists(): + if not Path(self.model_filename).exists(): logger.error( "Could not find the model weights at file %s", self.model_filename, @@ -318,19 +320,17 @@ def _get_index( if not annotated: ext += (".mzml", ".mzxml") - if msg and msg[-1] != " ": - msg += " " - + msg = msg.strip() filenames = _get_peak_filenames(peak_path, ext) if not filenames: - not_found_err = f"Cound not find {msg}peak files" + not_found_err = f"Cound not find {msg} peak files" logger.error(not_found_err + " from %s", peak_path) raise FileNotFoundError(not_found_err) is_index = any([Path(f).suffix in (".h5", ".hdf5") for f in filenames]) if is_index: if len(filenames) > 1: - h5_err = f"Multiple {msg}HDF5 spectrum indexes specified" + h5_err = f"Multiple {msg} HDF5 spectrum indexes specified" logger.error(h5_err) raise ValueError(h5_err) diff --git a/tests/test_integration.py b/tests/test_integration.py index 01c1b120..12002dc2 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,7 +1,8 @@ import functools -import pyteomics.mztab +import pyteomics.mztab from click.testing import CliRunner + from 
casanovo import casanovo diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py new file mode 100644 index 00000000..bb66a09b --- /dev/null +++ b/tests/unit_tests/test_runner.py @@ -0,0 +1,43 @@ +"""Unit tests specifically for the model_runner module.""" +from pathlib import Path + +import pytest + +from casanovo.config import Config +from casanovo.denovo.model_runner import ModelRunner + + +def test_initialize_model(tmp_path): + """Test that""" + config = Config() + config.train_from_scratch = False + ModelRunner(config=config).initialize_model(train=True) + + with pytest.raises(ValueError): + ModelRunner(config=config).initialize_model(train=False) + + with pytest.raises(FileNotFoundError): + runner = ModelRunner(config=config, model_filename="blah") + runner.initialize_model(train=True) + + with pytest.raises(FileNotFoundError): + runner = ModelRunner(config=config, model_filename="blah") + runner.initialize_model(train=False) + + # This should work now: + config.train_from_scratch = True + runner = ModelRunner(config=config, model_filename="blah") + runner.initialize_model(train=True) + + # But this should still fail: + with pytest.raises(FileNotFoundError): + runner = ModelRunner(config=config, model_filename="blah") + runner.initialize_model(train=False) + + # If the model initialization throws and EOFError, then the Spec2Pep model + # has tried to load the weights: + weights = tmp_path / "blah" + weights.touch() + with pytest.raises(EOFError): + runner = ModelRunner(config=config, model_filename=str(weights)) + runner.initialize_model(train=False) From bc48f3006d156afa259d3cfd35c014742e93f017 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Fri, 28 Apr 2023 09:28:29 -0700 Subject: [PATCH 11/18] Remove encoding --- casanovo/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/casanovo/config.py b/casanovo/config.py index 58fdd235..bfff3685 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -74,7 +74,7 @@ class Config: def __init__(self, config_file: Optional[str] = None): """Initialize a Config object.""" self.file = str(config_file) if config_file is not None else "default" - with self._default_config.open(encoding="utf-8") as f_in: + with self._default_config.open() as f_in: self._params = yaml.safe_load(f_in) if config_file is None: From 842b6946caf4657659da27cdf472adfa42950092 Mon Sep 17 00:00:00 2001 From: melihyilmaz Date: Sun, 7 May 2023 18:49:19 -0700 Subject: [PATCH 12/18] Specify device type when weight loading --- casanovo/denovo/model_runner.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index dc866e48..856394dc 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -246,9 +246,23 @@ def initialize_model(self, train: bool) -> None: self.model_filename, ) raise FileNotFoundError("Could not find the model weights file") - + + accelerator_class = str(type(self.trainer.accelerator)) + if "CUDA" in accelerator_class: + map_location_device = "cuda" + elif "TPU" in accelerator_class: + map_location_device = "xla" + elif "HPU" in accelerator_class: + map_location_device = "hpu" + elif "IPU" in accelerator_class: + map_location_device = "ipu" + #FIXME: Handle the case for mps separately? 
+ else: + map_location_device = "cpu" + self.model = Spec2Pep().load_from_checkpoint( self.model_filename, + map_location=torch.device(map_location_device), **model_params, ) From af6abfbed7cb43c002b243f49d939e98ddeae333 Mon Sep 17 00:00:00 2001 From: melihyilmaz Date: Sun, 7 May 2023 18:53:20 -0700 Subject: [PATCH 13/18] Fix lint --- casanovo/denovo/model_runner.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 856394dc..f7b207d3 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -246,20 +246,20 @@ def initialize_model(self, train: bool) -> None: self.model_filename, ) raise FileNotFoundError("Could not find the model weights file") - + accelerator_class = str(type(self.trainer.accelerator)) if "CUDA" in accelerator_class: map_location_device = "cuda" elif "TPU" in accelerator_class: map_location_device = "xla" elif "HPU" in accelerator_class: - map_location_device = "hpu" + map_location_device = "hpu" elif "IPU" in accelerator_class: - map_location_device = "ipu" - #FIXME: Handle the case for mps separately? + map_location_device = "ipu" + # FIXME: Handle the case for mps separately. else: map_location_device = "cpu" - + self.model = Spec2Pep().load_from_checkpoint( self.model_filename, map_location=torch.device(map_location_device), From 57171a7957e1a8d42e296350c8e7f4869c64aa51 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Tue, 9 May 2023 23:00:39 -0700 Subject: [PATCH 14/18] Capture init params and figure out device automagically --- casanovo/config.py | 1 - casanovo/config.yaml | 2 -- casanovo/denovo/model.py | 1 + casanovo/denovo/model_runner.py | 21 +++------------------ 4 files changed, 4 insertions(+), 21 deletions(-) diff --git a/casanovo/config.py b/casanovo/config.py index bfff3685..a21842c7 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -65,7 +65,6 @@ class Config: train_from_scratch=bool, save_model=bool, model_save_folder_path=str, - save_weights_only=bool, every_n_train_steps=int, accelerator=str, devices=int, diff --git a/casanovo/config.yaml b/casanovo/config.yaml index 44eaa8f6..952d8c78 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -115,8 +115,6 @@ train_from_scratch: True save_model: True # Path to saved checkpoints model_save_folder_path: "" -# Set to "False" to save the PyTorch model instance -save_weights_only: True # Model validation and checkpointing frequency in training steps every_n_train_steps: 50_000 # The hardware accelerator to use. Must be one of: diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index a387929b..1105a9e7 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -114,6 +114,7 @@ def __init__( **kwargs: Dict, ): super().__init__() + self.save_hyperparameters() # Build the model. 
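The one-line `self.save_hyperparameters()` addition above is the heart of this patch: Lightning records the captured `__init__` arguments inside every checkpoint, so `load_from_checkpoint` can rebuild the model without being handed the configuration again. A generic sketch of the mechanism (a toy module, not Casanovo code):

    import lightning.pytorch as pl
    import torch

    class Tiny(pl.LightningModule):
        def __init__(self, dim_model: int = 512, n_head: int = 8):
            super().__init__()
            self.save_hyperparameters()  # stores dim_model and n_head
            self.layer = torch.nn.Linear(dim_model, n_head)

    # After trainer.save_checkpoint("tiny.ckpt"), the module can be
    # rebuilt without re-specifying its arguments:
    #     model = Tiny.load_from_checkpoint("tiny.ckpt")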
if custom_encoder is not None: diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index f7b207d3..6e34302b 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -59,7 +59,6 @@ def __init__( pl.callbacks.ModelCheckpoint( dirpath=config.model_save_folder_path, save_top_k=-1, - save_weights_only=config.save_weights_only, every_n_train_steps=config.every_n_train_steps, ) ] @@ -247,23 +246,9 @@ def initialize_model(self, train: bool) -> None: ) raise FileNotFoundError("Could not find the model weights file") - accelerator_class = str(type(self.trainer.accelerator)) - if "CUDA" in accelerator_class: - map_location_device = "cuda" - elif "TPU" in accelerator_class: - map_location_device = "xla" - elif "HPU" in accelerator_class: - map_location_device = "hpu" - elif "IPU" in accelerator_class: - map_location_device = "ipu" - # FIXME: Handle the case for mps separately. - else: - map_location_device = "cpu" - - self.model = Spec2Pep().load_from_checkpoint( + self.model = Spec2Pep.load_from_checkpoint( self.model_filename, - map_location=torch.device(map_location_device), - **model_params, + map_location=torch.empty(1).device, # Use the default device. ) def initialize_data_module( @@ -289,7 +274,7 @@ def initialize_data_module( n_devices = self.trainer.num_devices train_bs = self.config.train_batch_size // n_devices eval_bs = self.config.predict_batch_size // n_devices - except AttributeError as err: + except AttributeError: raise RuntimeError("Please use `initialize_trainer()` first.") self.loaders = DeNovoDataModule( From 9cd2770d108131a5ad83a6010a235e6d36f276d9 Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Tue, 9 May 2023 23:45:16 -0700 Subject: [PATCH 15/18] Add runner tests --- tests/unit_tests/test_runner.py | 38 ++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py index bb66a09b..5de61369 100644 --- a/tests/unit_tests/test_runner.py +++ b/tests/unit_tests/test_runner.py @@ -1,7 +1,10 @@ """Unit tests specifically for the model_runner module.""" -from pathlib import Path +from typing import Union, Any, Dict +import lightning.pytorch as pl import pytest +import torch +from lightning.pytorch.accelerators import Accelerator from casanovo.config import Config from casanovo.denovo.model_runner import ModelRunner @@ -41,3 +44,36 @@ def test_initialize_model(tmp_path): with pytest.raises(EOFError): runner = ModelRunner(config=config, model_filename=str(weights)) runner.initialize_model(train=False) + + +def test_save_and_load_weights(tmp_path, mgf_small, tiny_config): + """Test saving aloading weights""" + config = Config(tiny_config) + config.max_epochs = 1 + config.n_layers = 1 + ckpt = tmp_path / "test.ckpt" + + with ModelRunner(config=config) as runner: + runner.train([mgf_small], [mgf_small]) + runner.trainer.save_checkpoint(ckpt) + + # Try changing model arch: + config.n_layers = 50 # lol + with torch.device("meta"): + # Now load the weights into a new model + # The device should be meta for all the weights. + runner = ModelRunner(config=config, model_filename=ckpt) + runner.initialize_model(train=False) + + obs_layers = runner.model.encoder.transformer_encoder.num_layers + assert obs_layers == 1 # Mach the original arch. 
+    assert next(runner.model.parameters()).device == torch.device("meta")
+
+    # If the Trainer correctly moves the weights to the accelerator,
+    # then it should fail if the weights are on the "meta" device.
+    with torch.device("meta"):
+        with ModelRunner(config=config, model_filename=ckpt) as runner:
+            with pytest.raises(NotImplementedError) as err:
+                runner.evaluate([mgf_small])
+
+    assert "meta tensor; no data!" in str(err.value)

From 096aeb971bfd63d4e8e1dbb6f2fd40e408a6d007 Mon Sep 17 00:00:00 2001
From: William Fondrie
Date: Wed, 10 May 2023 00:14:23 -0700
Subject: [PATCH 16/18] Fix bug and limit saved models

---
 casanovo/config.py              |  2 +-
 casanovo/config.yaml            |  5 +++--
 casanovo/denovo/model_runner.py | 10 ++++++----
 tests/conftest.py               |  4 ++--
 4 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/casanovo/config.py b/casanovo/config.py
index a21842c7..fbbf2e16 100644
--- a/casanovo/config.py
+++ b/casanovo/config.py
@@ -63,7 +63,7 @@ class Config:
         max_epochs=int,
         num_sanity_val_steps=int,
         train_from_scratch=bool,
-        save_model=bool,
+        save_top_k=int,
         model_save_folder_path=str,
         every_n_train_steps=int,
         accelerator=str,
diff --git a/casanovo/config.yaml b/casanovo/config.yaml
index 952d8c78..7b8379ab 100644
--- a/casanovo/config.yaml
+++ b/casanovo/config.yaml
@@ -111,8 +111,9 @@ max_epochs: 30
 num_sanity_val_steps: 0
 # Set to "False" to further train a pre-trained Casanovo model
 train_from_scratch: True
-# Save model checkpoints during training
-save_model: True
+# Save the top k model checkpoints during training. -1 saves all and
+# leaving this field empty saves none.
+save_top_k: 5
 # Path to saved checkpoints
 model_save_folder_path: ""
 # Model validation and checkpointing frequency in training steps
diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py
index 6e34302b..4d2caeb0 100644
--- a/casanovo/denovo/model_runner.py
+++ b/casanovo/denovo/model_runner.py
@@ -54,11 +54,13 @@ def __init__(
         self.writer = None

         # Configure checkpoints.
-        if config.save_model:
+        if config.save_top_k is not None:
             self.callbacks = [
                 pl.callbacks.ModelCheckpoint(
                     dirpath=config.model_save_folder_path,
-                    save_top_k=-1,
+                    monitor="valid_CELoss",
+                    mode="min",
+                    save_top_k=config.save_top_k,
                     every_n_train_steps=config.every_n_train_steps,
                 )
             ]
@@ -153,12 +155,12 @@ def predict(self, peak_path: Iterable[str], output: str) -> None:
         self.initialize_trainer(train=False)
         self.initialize_model(train=False)
+        self.model.out_writer = self.writer

         test_index = self._get_index(peak_path, False, "")
         self.writer.set_ms_run(test_index.ms_files)
         self.initialize_data_module(test_index=test_index)
         self.loaders.setup(stage="test", annotated=False)
-
         self.trainer.predict(self.model, self.loaders.test_dataloader())

     def initialize_trainer(self, train: bool) -> None:
@@ -185,7 +187,7 @@ def initialize_trainer(self, train: bool) -> None:
         additional_cfg = dict(
             devices=devices,
             callbacks=self.callbacks,
-            enable_checkpointing=self.config.save_model,
+            enable_checkpointing=self.config.save_top_k is not None,
             max_epochs=self.config.max_epochs,
             num_sanity_val_steps=self.config.num_sanity_val_steps,
             strategy=self._get_strategy(),
diff --git a/tests/conftest.py b/tests/conftest.py
index c0d09703..c137c5f6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -191,8 +191,8 @@ def tiny_config(tmp_path):
         "dim_feedfoward": 10,
         "n_layers": 1,
         "warmup_iters": 1,
-        "max_iters": 10,
-        "max_epochs": 30,
+        "max_iters": 1,
+        "max_epochs": 10,
         "every_n_train_steps": 1,
         "model_save_folder_path": str(tmp_path),
         "accelerator": "cpu",

From fdd76386b3941963f52d444d9e9528a5357a708c Mon Sep 17 00:00:00 2001
From: William Fondrie
Date: Wed, 10 May 2023 14:14:42 -0700
Subject: [PATCH 17/18] Support old weights too

---
 casanovo/denovo/model_runner.py | 18 ++++++++++++++----
 tests/unit_tests/test_runner.py | 23 +++++++++++++++++++----
 2 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py
index 4d2caeb0..5f9c9e98 100644
--- a/casanovo/denovo/model_runner.py
+++ b/casanovo/denovo/model_runner.py
@@ -248,10 +248,20 @@ def initialize_model(self, train: bool) -> None:
             )
             raise FileNotFoundError("Could not find the model weights file")

-        self.model = Spec2Pep.load_from_checkpoint(
-            self.model_filename,
-            map_location=torch.empty(1).device,  # Use the default device.
-        )
+        # First try loading model details from the weights file,
+        # otherwise use the provided configuration.
+        device = torch.empty(1).device  # Use the default device.
+        try:
+            self.model = Spec2Pep.load_from_checkpoint(
+                self.model_filename,
+                map_location=device,
+            )
+        except RuntimeError:
+            self.model = Spec2Pep.load_from_checkpoint(
+                self.model_filename,
+                map_location=device,
+                **model_params,
+            )

     def initialize_data_module(
         self,
diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py
index 5de61369..6ef2b250 100644
--- a/tests/unit_tests/test_runner.py
+++ b/tests/unit_tests/test_runner.py
@@ -58,22 +58,37 @@ def test_save_and_load_weights(tmp_path, mgf_small, tiny_config):
         runner.trainer.save_checkpoint(ckpt)

     # Try changing model arch:
-    config.n_layers = 50  # lol
+    other_config = Config(tiny_config)
+    other_config.n_layers = 50  # lol
     with torch.device("meta"):
         # Now load the weights into a new model
         # The device should be meta for all the weights.
-        runner = ModelRunner(config=config, model_filename=ckpt)
+        runner = ModelRunner(config=other_config, model_filename=ckpt)
         runner.initialize_model(train=False)

     obs_layers = runner.model.encoder.transformer_encoder.num_layers
-    assert obs_layers == 1  # Mach the original arch.
+    assert obs_layers == 1  # Match the original arch.
     assert next(runner.model.parameters()).device == torch.device("meta")

     # If the Trainer correctly moves the weights to the accelerator,
     # then it should fail if the weights are on the "meta" device.
     with torch.device("meta"):
-        with ModelRunner(config=config, model_filename=ckpt) as runner:
+        with ModelRunner(other_config, model_filename=ckpt) as runner:
             with pytest.raises(NotImplementedError) as err:
                 runner.evaluate([mgf_small])

     assert "meta tensor; no data!" in str(err.value)
+
+    # Try without arch:
+    ckpt_data = torch.load(ckpt)
+    del ckpt_data["hyper_parameters"]
+    torch.save(ckpt_data, ckpt)
+
+    # Shouldn't work:
+    with ModelRunner(other_config, model_filename=ckpt) as runner:
+        with pytest.raises(RuntimeError):
+            runner.evaluate([mgf_small])
+
+    # Should work:
+    with ModelRunner(config=config, model_filename=ckpt) as runner:
+        runner.evaluate([mgf_small])

From 8208b936bcc5c7bc7c610bae2c115b2a53eb222a Mon Sep 17 00:00:00 2001
From: William Fondrie
Date: Wed, 10 May 2023 14:31:27 -0700
Subject: [PATCH 18/18] Remove every_n_train_steps from checkpoint

---
 casanovo/denovo/model_runner.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py
index 5f9c9e98..2c22bd62 100644
--- a/casanovo/denovo/model_runner.py
+++ b/casanovo/denovo/model_runner.py
@@ -61,7 +61,6 @@ def __init__(
                     monitor="valid_CELoss",
                     mode="min",
                     save_top_k=config.save_top_k,
-                    every_n_train_steps=config.every_n_train_steps,
                 )
             ]
         else:
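
Taken together, patches 14 and 17 converge on one loading pattern: record the model's __init__ arguments with save_hyperparameters() at construction time, restore weights onto whatever the current default device is, and fall back to a caller-supplied architecture for checkpoints that predate the hyperparameter capture. The sketch below illustrates that pattern under stated assumptions: TinyModel, its hidden argument, and load_model are hypothetical stand-ins for Spec2Pep and the runner's config, not part of Casanovo.

    import lightning.pytorch as pl
    import torch


    class TinyModel(pl.LightningModule):
        """A hypothetical stand-in for Spec2Pep."""

        def __init__(self, hidden: int = 16):
            super().__init__()
            # Stash the __init__ arguments under "hyper_parameters" in every
            # checkpoint, so load_from_checkpoint() can rebuild the model
            # without being told the architecture again.
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(hidden, hidden)


    def load_model(ckpt_path, **model_params):
        # torch.empty(1).device reports the current default device, so the
        # restored weights land wherever an enclosing torch.device(...)
        # context (or the Trainer) says they should.
        device = torch.empty(1).device
        try:
            # Newer checkpoints carry their own hyperparameters.
            return TinyModel.load_from_checkpoint(ckpt_path, map_location=device)
        except RuntimeError:
            # Older checkpoints do not; fall back to the architecture the
            # caller provides (in Casanovo, the values from config.yaml).
            return TinyModel.load_from_checkpoint(
                ckpt_path, map_location=device, **model_params
            )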
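
The new runner tests make these checks cheap by leaning on PyTorch's "meta" device, under which tensors have shape and dtype but no storage. A minimal sketch of the trick, reusing the hypothetical load_model above:

    # No weight memory is allocated here, so architecture checks are free.
    with torch.device("meta"):
        model = load_model("test.ckpt")  # hypothetical checkpoint path

    assert next(model.parameters()).device == torch.device("meta")

    # Meta tensors hold no data, so any real computation raises
    # NotImplementedError, which is exactly how the tests prove that the
    # Trainer, not the loader, performs the final move to the accelerator.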