Fix CPU bug, overhaul model runner, and update to lightning >=2.0 (#176)
* Overhaul runner

* Update linting to only happen once

* Fix linting error

* Specify utf-8 encoding

* Specify utf-8 encoding only for default config

* Skip weights tests for now

* Update skipping API test

* Revert accidental max_epochs change

* msg -> reason for pytest.mark.skip

* Wout's suggestions and more tests

* Remove encoding

* Specify device type when weight loading

* Fix lint

* Capture init params and figure out device automagically

* Add runner tests

* Fix bug and limit saved models

* Support old weights too

* Remove every_n_train_steps from checkpoint

---------

Co-authored-by: melihyilmaz <[email protected]>
wfondrie and melihyilmaz authored May 10, 2023
1 parent effc955 commit 6299bd2
Showing 16 changed files with 614 additions and 419 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/black.yml → .github/workflows/lint.yml
@@ -1,6 +1,10 @@
name: Lint

on: [push, pull_request]
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
lint:
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
- uses: actions/checkout@v2
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,5 +1,7 @@
# Test stuff:
test_path/
lightning_logs/
envs/

# Byte-compiled / optimized / DLL files
__pycache__/
30 changes: 15 additions & 15 deletions casanovo/casanovo.py
@@ -18,12 +18,12 @@
import torch
import tqdm
import yaml
from pytorch_lightning.lite import LightningLite
from lightning.pytorch import seed_everything

from . import __version__
from . import utils
from .data import ms_io
from .denovo import model_runner
from .denovo import ModelRunner
from .config import Config

logger = logging.getLogger("casanovo")
@@ -52,11 +52,13 @@
required=True,
help="The file path with peak files for predicting peptide sequences or "
"training Casanovo.",
multiple=True,
)
@click.option(
"--peak_path_val",
help="The file path with peak files to be used as validation data during "
"training.",
multiple=True,
)
@click.option(
"--config",
@@ -127,7 +129,7 @@ def main(
# Read parameters from the config file.
config = Config(config)

LightningLite.seed_everything(seed=config["random_seed"], workers=True)
seed_everything(seed=config["random_seed"], workers=True)

# Download model weights if these were not specified (except when training).
if model is None and mode != "train":
@@ -159,18 +161,16 @@
logger.debug("%s = %s", str(key), str(value))

# Run Casanovo in the specified mode.
if mode == "denovo":
logger.info("Predict peptide sequences with Casanovo.")
writer = ms_io.MztabWriter(f"{output}.mztab")
writer.set_metadata(config, model=model, config_filename=config.file)
model_runner.predict(peak_path, model, config, writer)
writer.save()
elif mode == "eval":
logger.info("Evaluate a trained Casanovo model.")
model_runner.evaluate(peak_path, model, config)
elif mode == "train":
logger.info("Train the Casanovo model.")
model_runner.train(peak_path, peak_path_val, model, config)
with ModelRunner(config, model) as model_runner:
if mode == "denovo":
logger.info("Predict peptide sequences with Casanovo.")
model_runner.predict(peak_path, output)
elif mode == "eval":
logger.info("Evaluate a trained Casanovo model.")
model_runner.evaluate(peak_path)
elif mode == "train":
logger.info("Train the Casanovo model.")
model_runner.train(peak_path, peak_path_val)


def _get_model_weights() -> str:
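For orientation, here is a minimal sketch of how the new context-managed `ModelRunner` replaces the old module-level `model_runner.predict/evaluate/train` functions. The config file name, checkpoint name, and peak file below are hypothetical, and the comments describe what the runner is expected to do based on the diff above rather than the full implementation.

```python
from lightning.pytorch import seed_everything

from casanovo.config import Config
from casanovo.denovo import ModelRunner

config = Config("casanovo.yaml")  # hypothetical config file
seed_everything(seed=config["random_seed"], workers=True)

# Entering the context sets up the runner; exiting it presumably releases
# resources such as the output writer and temporary spectrum indices.
with ModelRunner(config, "casanovo_v3.ckpt") as runner:
    # Multiple peak files are allowed (click's multiple=True); the output
    # argument is the base name, with an .mztab file expected as before.
    runner.predict(("sample.mgf",), "sample_output")
```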
12 changes: 3 additions & 9 deletions casanovo/config.py
@@ -63,11 +63,11 @@ class Config:
max_epochs=int,
num_sanity_val_steps=int,
train_from_scratch=bool,
save_model=bool,
save_top_k=int,
model_save_folder_path=str,
save_weights_only=bool,
every_n_train_steps=int,
no_gpu=bool,
accelerator=str,
devices=int,
)

def __init__(self, config_file: Optional[str] = None):
@@ -86,13 +86,7 @@ def __init__(self, config_file: Optional[str] = None):
for key, val in self._config_types.items():
self.validate_param(key, val)

# Add extra configuration options and scale by the number of GPUs.
n_gpus = 0 if self["no_gpu"] else torch.cuda.device_count()
self._params["n_workers"] = utils.n_workers()
if n_gpus > 1:
self._params["train_batch_size"] = (
self["train_batch_size"] // n_gpus
)

def __getitem__(self, param: str) -> Union[int, bool, str, Tuple, Dict]:
"""Retrieve a parameter"""
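As a quick illustration of the slimmed-down Config (batch sizes are no longer rescaled by GPU count here), below is a minimal sketch of reading the validated parameters. It assumes, as in earlier releases, that omitting the file name falls back to the packaged config.yaml defaults.

```python
from casanovo.config import Config

config = Config()  # assumption: no argument -> packaged default config.yaml

# Each parameter is type-checked against Config._config_types on load and
# read back via item access.
accelerator = config["accelerator"]  # "auto" in the default config
devices = config["devices"]          # left empty in config.yaml -> pick automatically
save_top_k = config["save_top_k"]    # 5 in the default config
```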
17 changes: 11 additions & 6 deletions casanovo/config.yaml
@@ -111,13 +111,18 @@ max_epochs: 30
num_sanity_val_steps: 0
# Set to "False" to further train a pre-trained Casanovo model
train_from_scratch: True
# Save model checkpoints during training
save_model: True
# Save the top k model checkpoints during training. -1 saves all and
# leaving this field empty saves none.
save_top_k: 5
# Path to saved checkpoints
model_save_folder_path: ""
# Set to "False" to save the PyTorch model instance
save_weights_only: True
# Model validation and checkpointing frequency in training steps
every_n_train_steps: 50_000
# Disable usage of a GPU (including Apple MPS):
no_gpu: False
# The hardware accelerator to use. Must be one of:
# "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto"
accelerator: "auto"
# The devices to use. Can be set to a positive integer,
# or to -1 to indicate that all available devices should be used.
# If left empty, the appropriate number is selected automatically
# based on the chosen accelerator.
devices:
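These options replace the old `no_gpu` and `every_n_train_steps` settings. The sketch below shows how they might map onto a Lightning 2.0 `Trainer` and `ModelCheckpoint`; it is illustrative only, and the `monitor` key is an assumption based on the `valid_CELoss` metric logged by the model.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Values mirror the defaults above; in Casanovo they would come from Config.
checkpoint = ModelCheckpoint(
    dirpath=None,            # model_save_folder_path ("" -> default location)
    save_top_k=5,            # keep only the k best checkpoints
    save_weights_only=True,
    monitor="valid_CELoss",  # assumption: rank checkpoints by validation loss
)

trainer = Trainer(
    accelerator="auto",      # "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto"
    devices="auto",          # an int, -1 for all, or "auto" when left empty
    max_epochs=30,
    num_sanity_val_steps=0,
    callbacks=[checkpoint],
)
```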
1 change: 1 addition & 0 deletions casanovo/denovo/__init__.py
@@ -0,0 +1 @@
from .model_runner import ModelRunner
30 changes: 19 additions & 11 deletions casanovo/denovo/dataloaders.py
@@ -3,8 +3,8 @@
import os
from typing import List, Optional, Tuple

import lightning.pytorch as pl
import numpy as np
import pytorch_lightning as pl
import torch
from depthcharge.data import AnnotatedSpectrumIndex

@@ -23,8 +23,10 @@ class DeNovoDataModule(pl.LightningDataModule):
The spectrum index file corresponding to the validation data.
test_index : Optional[AnnotatedSpectrumIndex]
The spectrum index file corresponding to the testing data.
batch_size : int
The batch size to use for training and evaluating.
train_batch_size : int
The batch size to use for training.
eval_batch_size : int
The batch size to use for inference.
n_peaks : Optional[int]
The number of top-n most intense peaks to keep in each spectrum. `None`
retains all peaks.
@@ -52,7 +54,8 @@ def __init__(
train_index: Optional[AnnotatedSpectrumIndex] = None,
valid_index: Optional[AnnotatedSpectrumIndex] = None,
test_index: Optional[AnnotatedSpectrumIndex] = None,
batch_size: int = 128,
train_batch_size: int = 128,
eval_batch_size: int = 1028,
n_peaks: Optional[int] = 150,
min_mz: float = 50.0,
max_mz: float = 2500.0,
@@ -65,7 +68,8 @@
self.train_index = train_index
self.valid_index = valid_index
self.test_index = test_index
self.batch_size = batch_size
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.n_peaks = n_peaks
self.min_mz = min_mz
self.max_mz = max_mz
@@ -119,7 +123,9 @@ def setup(self, stage: str = None, annotated: bool = True) -> None:
self.test_dataset = make_dataset(self.test_index)

def _make_loader(
self, dataset: torch.utils.data.Dataset
self,
dataset: torch.utils.data.Dataset,
batch_size: int,
) -> torch.utils.data.DataLoader:
"""
Create a PyTorch DataLoader.
@@ -128,6 +134,8 @@ def _make_loader(
----------
dataset : torch.utils.data.Dataset
A PyTorch Dataset.
batch_size : int
The batch size to use.
Returns
-------
@@ -136,27 +144,27 @@
"""
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
batch_size=batch_size,
collate_fn=prepare_batch,
pin_memory=True,
num_workers=self.n_workers,
)

def train_dataloader(self) -> torch.utils.data.DataLoader:
"""Get the training DataLoader."""
return self._make_loader(self.train_dataset)
return self._make_loader(self.train_dataset, self.train_batch_size)

def val_dataloader(self) -> torch.utils.data.DataLoader:
"""Get the validation DataLoader."""
return self._make_loader(self.valid_dataset)
return self._make_loader(self.valid_dataset, self.eval_batch_size)

def test_dataloader(self) -> torch.utils.data.DataLoader:
"""Get the test DataLoader."""
return self._make_loader(self.test_dataset)
return self._make_loader(self.test_dataset, self.eval_batch_size)

def predict_dataloader(self) -> torch.utils.data.DataLoader:
"""Get the predict DataLoader."""
return self._make_loader(self.test_dataset)
return self._make_loader(self.test_dataset, self.eval_batch_size)


def prepare_batch(
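With training and evaluation batch sizes now separate, the data module might be driven roughly as follows. This is a sketch only: the depthcharge index files are assumed to already exist, and their construction from peak files is omitted.

```python
from depthcharge.data import AnnotatedSpectrumIndex

from casanovo.denovo.dataloaders import DeNovoDataModule

# Hypothetical pre-built spectrum indices.
train_index = AnnotatedSpectrumIndex("train.hdf5")
valid_index = AnnotatedSpectrumIndex("valid.hdf5")

datamodule = DeNovoDataModule(
    train_index=train_index,
    valid_index=valid_index,
    train_batch_size=128,   # used by train_dataloader()
    eval_batch_size=1028,   # used by the val/test/predict loaders
)
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
```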
33 changes: 16 additions & 17 deletions casanovo/denovo/model.py
@@ -7,9 +7,9 @@

import depthcharge.masses
import einops
import numpy as np
import pytorch_lightning as pl
import torch
import numpy as np
import lightning.pytorch as pl
from torch.utils.tensorboard import SummaryWriter
from depthcharge.components import ModelMixin, PeptideDecoder, SpectrumEncoder

@@ -114,6 +114,7 @@ def __init__(
**kwargs: Dict,
):
super().__init__()
self.save_hyperparameters()

# Build the model.
if custom_encoder is not None:
@@ -724,8 +725,8 @@ def training_step(
pred = pred[:, :-1, :].reshape(-1, self.decoder.vocab_size + 1)
loss = self.celoss(pred, truth.flatten())
self.log(
"CELoss",
{mode: loss.detach()},
f"{mode}_CELoss",
loss.detach(),
on_step=False,
on_epoch=True,
sync_dist=True,
@@ -766,12 +767,10 @@
log_args = dict(on_step=False, on_epoch=True, sync_dist=True)
self.log(
"Peptide precision at coverage=1",
{"valid": pep_precision},
pep_precision,
**log_args,
)
self.log(
"AA precision at coverage=1", {"valid": aa_precision}, **log_args
)
self.log("AA precision at coverage=1", aa_precision, **log_args)

return loss

@@ -824,7 +823,7 @@ def on_train_epoch_end(self) -> None:
"""
Log the training loss at the end of each epoch.
"""
train_loss = self.trainer.callback_metrics["CELoss"]["train"].detach()
train_loss = self.trainer.callback_metrics["train_CELoss"].detach()
metrics = {
"step": self.trainer.global_step,
"train": train_loss,
@@ -839,19 +838,21 @@ def on_validation_epoch_end(self) -> None:
callback_metrics = self.trainer.callback_metrics
metrics = {
"step": self.trainer.global_step,
"valid": callback_metrics["CELoss"]["valid"].detach(),
"valid": callback_metrics["valid_CELoss"].detach(),
"valid_aa_precision": callback_metrics[
"AA precision at coverage=1"
]["valid"].detach(),
].detach(),
"valid_pep_precision": callback_metrics[
"Peptide precision at coverage=1"
]["valid"].detach(),
].detach(),
}
self._history.append(metrics)
self._log_history()

def on_predict_epoch_end(
self, results: List[List[Tuple[np.ndarray, List[str], torch.Tensor]]]
def on_predict_batch_end(
self,
outputs: List[Tuple[np.ndarray, List[str], torch.Tensor]],
*args,
) -> None:
"""
Write the predicted peptide sequences and amino acid scores to the
@@ -867,9 +868,7 @@ def on_predict_epoch_end(
peptide,
peptide_score,
aa_scores,
) in itertools.chain.from_iterable(
itertools.chain.from_iterable(results)
):
) in outputs:
if len(peptide) == 0:
continue
self.out_writer.psms.append(
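One payoff of the new `save_hyperparameters()` call is on the loading side. A minimal sketch is below, assuming the model class is Casanovo's `Spec2Pep` and using a hypothetical checkpoint file; older checkpoints without stored hyperparameters would still need the init arguments supplied explicitly.

```python
import torch

from casanovo.denovo.model import Spec2Pep

# save_hyperparameters() stores the __init__ arguments in the checkpoint,
# so load_from_checkpoint() can rebuild the model without re-specifying them.
# map_location keeps weight loading working on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Spec2Pep.load_from_checkpoint("casanovo_v3.ckpt", map_location=device)
model.eval()
```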