Skip to content

Commit

Permalink
integration tests
Browse files: browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Sep 30, 2024
1 parent 723fbb8 commit f1d6dfc
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 17 deletions.
3 changes: 3 additions & 0 deletions casanovo/data/ms_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def save(self) -> None:
):
filename = psm.spectrum_id[0]
idx = psm.spectrum_id[1]
if Path(filename).suffix == ".mgf" and idx.isnumeric():
idx = f"index={idx}"

writer.writerow(
[
"PSM",
Expand Down
4 changes: 1 addition & 3 deletions casanovo/denovo/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ def __init__(
),
pa.string(),
),
CustomField(
"title", lambda x: f"index={x['params']['title']}", pa.string()
),
CustomField("title", lambda x: x["params"]["title"], pa.string()),
]
self.custom_field_test_mzml = [
CustomField("scans", lambda x: x["id"], pa.string()),
Expand Down
2 changes: 1 addition & 1 deletion casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ def on_predict_batch_end(
peptide_score=peptide_score,
charge=int(charge),
calc_mz=precursor_mz,
exp_mz=calc_mass,
exp_mz=calc_mass.item(),
aa_scores=aa_scores,
)
)
Expand Down
33 changes: 20 additions & 13 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.tokenizers.peptides import MskbPeptideTokenizer
Expand Down Expand Up @@ -160,9 +161,7 @@ def train(
self.loaders.val_dataloader(),
)

def log_metrics(
self, test_dataloader: torch.utils.data.DataLoader
) -> None:
def log_metrics(self, test_dataloader: DataLoader) -> None:
"""Log peptide precision and amino acid precision
Calculate and log peptide precision and amino acid precision
Expand All @@ -178,12 +177,19 @@ def log_metrics(
seq_true = []
pred_idx = 0

with test_index as t_ind:
for true_idx in range(t_ind.n_spectra):
seq_true.append(t_ind[true_idx][4])
if pred_idx < len(self.writer.psms) and self.writer.psms[
pred_idx
].spectrum_id == t_ind.get_spectrum_id(true_idx):
for batch in test_dataloader:
for peak_file, scan_id, curr_seq_true in zip(
batch["peak_file"],
batch["scan_id"],
self.model.tokenizer.detokenize(batch["seq"][0]),
):
spectrum_id_true = (peak_file, scan_id)
seq_true.append(curr_seq_true)
if (
pred_idx < len(self.writer.psms)
and self.writer.psms[pred_idx].spectrum_id
== spectrum_id_true
):
seq_pred.append(self.writer.psms[pred_idx].sequence)
pred_idx += 1
else:
Expand All @@ -193,7 +199,7 @@ def log_metrics(
*aa_match_batch(
seq_true,
seq_pred,
depthcharge.masses.PeptideMass().masses,
self.model.tokenizer.residues,
)
)

Expand Down Expand Up @@ -249,11 +255,12 @@ def predict(
test_paths = self._get_input_paths(peak_path, evaluate, "test")
self.writer.set_ms_run(test_paths)
self.initialize_data_module(test_paths=test_paths)
self.loaders.setup(stage="test", annotated=False)
self.trainer.predict(self.model, self.loaders.test_dataloader())
self.loaders.setup(stage="test", annotated=evaluate)
predict_dataloader = self.loaders.predict_dataloader()
self.trainer.predict(self.model, predict_dataloader)

if evaluate:
self.log_metrics(self.loaders.test_dataloader())
self.log_metrics(predict_dataloader)

def initialize_trainer(self, train: bool) -> None:
"""Initialize the lightning Trainer.
Expand Down

0 comments on commit f1d6dfc

Please sign in to comment.