diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d6422fa..e464d7a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Changed - Update PyTorch Lightning global seed setting. +- Use beam search decoding rather than greedy decoding to predict the peptides. ### Fixed diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index 8c07cc6e..66d9016e 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -156,6 +156,7 @@ def main( weight_decay=float, train_batch_size=int, predict_batch_size=int, + n_beams=int, max_epochs=int, num_sanity_val_steps=int, train_from_scratch=bool, diff --git a/casanovo/config.yaml b/casanovo/config.yaml index 96652f02..71332033 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -65,6 +65,7 @@ weight_decay: 1e-5 # Training/inference options. train_batch_size: 32 predict_batch_size: 1024 +n_beams: 5 logger: max_epochs: 30 diff --git a/casanovo/denovo/evaluate.py b/casanovo/denovo/evaluate.py index acfab8eb..e4910f73 100644 --- a/casanovo/denovo/evaluate.py +++ b/casanovo/denovo/evaluate.py @@ -1,6 +1,6 @@ """Methods to evaluate peptide-spectrum predictions.""" import re -from typing import Dict, List, Tuple +from typing import Dict, Iterable, List, Tuple import numpy as np from spectrum_utils.utils import mass_diff @@ -182,8 +182,8 @@ def aa_match( def aa_match_batch( - peptides1: List[str], - peptides2: List[str], + peptides1: Iterable, + peptides2: Iterable, aa_dict: Dict[str, float], cum_mass_threshold: float = 0.5, ind_mass_threshold: float = 0.1, @@ -194,10 +194,10 @@ def aa_match_batch( Parameters ---------- - peptides1 : List[str] - The first list of (untokenized) peptide sequences to be compared. - peptides2 : List[str] - The second list of (untokenized) peptide sequences to be compared. + peptides1 : Iterable + The first list of peptide sequences to be compared. 
+ peptides2 : Iterable + The second list of peptide sequences to be compared. aa_dict : Dict[str, float] Mapping of amino acid tokens to their mass values. cum_mass_threshold : float @@ -221,13 +221,16 @@ def aa_match_batch( """ aa_matches_batch, n_aa1, n_aa2 = [], 0, 0 for peptide1, peptide2 in zip(peptides1, peptides2): - tokens1 = re.split(r"(?<=.)(?=[A-Z])", peptide1) - tokens2 = re.split(r"(?<=.)(?=[A-Z])", peptide2) - n_aa1, n_aa2 = n_aa1 + len(tokens1), n_aa2 + len(tokens2) + # Split peptides into individual AAs if necessary. + if isinstance(peptide1, str): + peptide1 = re.split(r"(?<=.)(?=[A-Z])", peptide1) + if isinstance(peptide2, str): + peptide2 = re.split(r"(?<=.)(?=[A-Z])", peptide2) + n_aa1, n_aa2 = n_aa1 + len(peptide1), n_aa2 + len(peptide2) aa_matches_batch.append( aa_match( - tokens1, - tokens2, + peptide1, + peptide2, aa_dict, cum_mass_threshold, ind_mass_threshold, diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 934169f7..23a12968 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -1,9 +1,11 @@ """A de novo peptide sequencing model.""" +import heapq import logging -import re -from typing import Any, Dict, List, Optional, Tuple, Union +import operator +from typing import Any, Dict, List, Optional, Set, Tuple, Union import depthcharge.masses +import einops import numpy as np import pytorch_lightning as pl import torch @@ -13,7 +15,6 @@ from . 
import evaluate from ..data import ms_io - logger = logging.getLogger("casanovo") @@ -61,18 +62,22 @@ class Spec2Pep(pl.LightningModule, ModelMixin): isotope_error_range : Tuple[int, int] Take into account the error introduced by choosing a non-monoisotopic peak for fragmentation by not penalizing predicted precursor m/z's that - fit the specified isotope error: `abs(calc_mz - (precursor_mz - isotope * 1.00335 / precursor_charge)) < precursor_mass_tol` + fit the specified isotope error: + `abs(calc_mz - (precursor_mz - isotope * 1.00335 / precursor_charge)) + < precursor_mass_tol` + n_beams: int + Number of beams used during beam search decoding. n_log : int The number of epochs to wait between logging messages. tb_summarywriter: Optional[str] - Folder path to record performance metrics during training. If ``None``, don't - use a ``SummaryWriter``. + Folder path to record performance metrics during training. If ``None``, + don't use a ``SummaryWriter``. warmup_iters: int The number of warm up iterations for the learning rate scheduler. max_iters: int The total number of iterations for the learning rate scheduler. - out_filename: Optional[str] - The output file name for the prediction results. + out_writer: Optional[str] + The output writer for the prediction results. **kwargs : Dict Additional keyword arguments passed to the Adam optimizer. 
""" @@ -91,6 +96,7 @@ def __init__( max_charge: int = 5, precursor_mass_tol: float = 50, isotope_error_range: Tuple[int, int] = (0, 1), + n_beams: int = 5, n_log: int = 10, tb_summarywriter: Optional[ torch.utils.tensorboard.SummaryWriter @@ -135,6 +141,7 @@ def __init__( self.residues = residues self.precursor_mass_tol = precursor_mass_tol self.isotope_error_range = isotope_error_range + self.n_beams = n_beams self.peptide_mass_calculator = depthcharge.masses.PeptideMass( self.residues ) @@ -153,7 +160,7 @@ def __init__( def forward( self, spectra: torch.Tensor, precursors: torch.Tensor - ) -> Tuple[List[str], torch.Tensor]: + ) -> Tuple[List[List[str]], torch.Tensor]: """ Predict peptide sequences for a batch of MS/MS spectra. @@ -171,22 +178,26 @@ def forward( Returns ------- - peptides : List[str] + peptides : List[List[str]] The predicted peptide sequences for each spectrum. aa_scores : torch.Tensor of shape (n_spectra, length, n_amino_acids) The individual amino acid scores for each prediction. """ - aa_scores, tokens = self.greedy_decode( + aa_scores, tokens = self.beam_search_decode( spectra.to(self.encoder.device), precursors.to(self.decoder.device), ) + return [self.decoder.detokenize(t) for t in tokens], aa_scores - def greedy_decode( + def beam_search_decode( self, spectra: torch.Tensor, precursors: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Greedy decoding of the spectrum predictions. + Beam search decoding of the spectrum predictions. + + Return the highest scoring peptide, within the precursor m/z tolerance + whenever possible. Parameters ---------- @@ -204,35 +215,491 @@ def greedy_decode( ------- scores : torch.Tensor of shape (n_spectra, max_length, n_amino_acids) The individual amino acid scores for each prediction. - tokens : torch.Tensor of shape (n_spectra, max_length, n_amino_acids) + tokens : torch.Tensor of shape (n_spectra, max_length) The predicted tokens for each spectrum. 
""" memories, mem_masks = self.encoder(spectra) - # Initialize the scores. - scores = torch.zeros( - spectra.shape[0], self.max_length + 1, self.decoder.vocab_size + 1 - ).type_as(spectra) - # Start with the first amino acid predictions. - scores[:, :1, :], _ = self.decoder( - None, precursors, memories, mem_masks + + # Sizes. + batch = spectra.shape[0] # B + length = self.max_length + 1 # L + vocab = self.decoder.vocab_size + 1 # V + beam = self.n_beams # S + + # Initialize scores and tokens. + scores = torch.full( + size=(batch, length, vocab, beam), fill_value=torch.nan ) - tokens = torch.argmax(scores, axis=2) - # Keep predicting until a stop token is predicted or max_length is - # reached. - # The stop token does not count towards max_length. - for i in range(2, self.max_length + 2): + scores = scores.type_as(spectra) + tokens = torch.zeros(batch, length, beam, dtype=torch.int64) + tokens = tokens.to(self.encoder.device) + # Keep track whether terminated beams have fitting precursor m/z. + beam_fits_prec_tol = torch.zeros(batch * beam, dtype=torch.bool) + + # Create cache for decoded beams. + ( + cache_scores, + cache_tokens, + cache_next_idx, + cache_pred_seq, + cache_pred_score, + ) = self._create_beamsearch_cache(scores, tokens) + + # Get the first prediction. + pred, _ = self.decoder(None, precursors, memories, mem_masks) + tokens[:, 0, :] = torch.topk(pred[:, 0, :], beam, dim=1)[1] + scores[:, :1, :, :] = einops.repeat(pred, "B L V -> B L V S", S=beam) + + # Make all tensors the right shape for decoding. + precursors = einops.repeat(precursors, "B L -> (B S) L", S=beam) + mem_masks = einops.repeat(mem_masks, "B L -> (B S) L", S=beam) + memories = einops.repeat(memories, "B L V -> (B S) L V", S=beam) + scores = einops.rearrange(scores, "B L V S -> (B S) L V") + tokens = einops.rearrange(tokens, "B L S -> (B S) L") + + # The main decoding loop. 
+ for i in range(1, self.max_length + 1): + # Terminate beams exceeding precursor m/z tolerance and track all + # terminated beams. + finished_beams_idx, tokens = self._terminate_finished_beams( + tokens, precursors, beam_fits_prec_tol, i + ) + # Cache terminated beams, group and order by fitting precursor m/z + # and confidence score. + self._cache_finished_beams( + finished_beams_idx, + cache_next_idx, + cache_pred_seq, + cache_pred_score, + cache_tokens, + cache_scores, + tokens, + scores, + beam_fits_prec_tol, + i, + ) + # Reset precursor tolerance status of all beams. + beam_fits_prec_tol = torch.zeros(batch * beam, dtype=torch.bool) + + # Stop decoding when all current beams are terminated. decoded = (tokens == self.stop_token).any(axis=1) if decoded.all(): break - scores[~decoded, :i, :], _ = self.decoder( - tokens[~decoded, : (i - 1)], + # Update the scores. + scores[~decoded, : i + 1, :], _ = self.decoder( + tokens[~decoded, :i], precursors[~decoded, :], memories[~decoded, :, :], mem_masks[~decoded, :], ) - tokens = torch.argmax(scores, axis=2) + # Find top-k beams with highest scores and continue decoding those. + scores, tokens = self._get_topk_beams(scores, tokens, batch, i) + + # Return the peptide with the highest confidence score, within the + # precursor m/z tolerance if possible. + output_tokens, output_scores = self._get_top_peptide( + cache_pred_score, cache_tokens, cache_scores, batch + ) + return self.softmax(output_scores), output_tokens + + def _create_beamsearch_cache( + self, scores: torch.Tensor, tokens: torch.Tensor + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + Dict[int, int], + Dict[int, Set[str]], + Dict[int, List[List[Tuple[float, int]]]], + ]: + """ + Create cache tensor and dictionary to store and group terminated beams. - return self.softmax(scores), tokens + Parameters + ---------- + scores : torch.Tensor of shape + (n_spectra, max_length, n_amino_acids, n_beams) + Output scores of the model. 
+ tokens : torch.Tensor of size (n_spectra, max_length, n_beams) + Output token of the model corresponding to amino acid sequences. + + Returns + ------- + cache_scores : torch.Tensor of shape + (n_spectra * n_beams, max_length, n_amino_acids) + The score for each amino acid in cached peptides. + cache_tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) + The token for each amino acid in cached peptides. + cache_next_idx : Dict[int, int] + Next available tensor index to cache peptides for each spectrum. + cache_pred_seq : Dict[int, Set[torch.Tensor]] + Set of decoded peptide tokens for each spectrum. + cache_pred_score : Dict[int, List[List[Tuple[float, int]]]] + Confidence score for each decoded peptide, separated as + precursor m/z fitting vs not, for each spectrum. + """ + batch, beam = scores.shape[0], scores.shape[-1] + + # Cache terminated beams and their scores. + cache_scores = einops.rearrange(scores.clone(), "B L V S -> (B S) L V") + cache_tokens = einops.rearrange(tokens.clone(), "B L S -> (B S) L") + + # Keep pointer to free rows in the cache and already cached predictions. + cache_next_idx = {i: i * beam for i in range(batch)} + # Keep already decoded peptides to avoid duplicates in cache. + cache_pred_seq = {i: set() for i in range(batch)} + # Store peptide scores to replace lower score peptides in cache with + # higher score peptides during decoding. + cache_pred_score = {i: [[], []] for i in range(batch)} + + return ( + cache_scores, + cache_tokens, + cache_next_idx, + cache_pred_seq, + cache_pred_score, + ) + + def _terminate_finished_beams( + self, + tokens: torch.Tensor, + precursors: torch.Tensor, + is_beam_prec_fit: torch.Tensor, + idx: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Terminate beams exceeding the precursor m/z tolerance. + + Track all terminated beams. 
+ + Parameters + ---------- + tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) + Output token of the model corresponding to amino acid sequences. + precursors : torch.Tensor of size (n_spectra * n_beams, 3) + The measured precursor mass (axis 0), precursor charge (axis 1), and + precursor m/z (axis 2) of each MS/MS spectrum. + is_beam_prec_fit: torch.Tensor of shape (n_spectra * n_beams) + Boolean tensor indicating if current beams are within precursor m/z + tolerance. + idx : int + Index to be considered in the current decoding step. + + Returns + ------- + finished_beams_idx : torch.Tensor + Indices of all finished beams on tokens tensor. + tokens : torch.Tensor of size (n_spectra * n_beams, max_length) + Output token of the model corresponding to amino acid sequences. + """ + # Check for tokens with a negative mass (i.e. neutral loss). + aa_neg_mass = [None] + for aa, mass in self.peptide_mass_calculator.masses.items(): + if mass < 0: + aa_neg_mass.append(aa) + + # Terminate beams that exceed the precursor m/z. + for beam_i in range(len(tokens)): + # Check only non-terminated beams. + if self.stop_token not in tokens[beam_i]: + # Finish if dummy was predicted at the previous step. + if tokens[beam_i][idx - 1] == 0: + tokens[beam_i][idx - 1] = self.stop_token + # Terminate the beam if it exceeds the precursor m/z tolerance. + else: + precursor_charge = precursors[beam_i, 1].item() + precursor_mz = precursors[beam_i, 2].item() + # Only terminate if the m/z difference cannot be corrected + # anymore by a subsequently predicted AA with negative mass. 
+ matches_precursor_mz = exceeds_precursor_mz = False + for aa in aa_neg_mass: + peptide = self.decoder.detokenize(tokens[beam_i][:idx]) + if aa is not None: + peptide.append(aa) + try: + calc_mz = self.peptide_mass_calculator.mass( + seq=peptide, charge=precursor_charge + ) + delta_mass_ppm = [ + _calc_mass_error( + calc_mz, + precursor_mz, + precursor_charge, + isotope, + ) + for isotope in range( + self.isotope_error_range[0], + self.isotope_error_range[1] + 1, + ) + ] + # Terminate the beam if the calculated m/z for the + # predicted peptide (without potential additional + # AAs with negative mass) is within the precursor + # m/z tolerance. + matches_precursor_mz = aa is None and any( + abs(d) < self.precursor_mass_tol + for d in delta_mass_ppm + ) + # Terminate the beam if the calculated m/z exceeds + # the precursor m/z + tolerance and hasn't been + # corrected by a subsequently predicted AA with + # negative mass. + exceeds_precursor_mz = aa is not None and all( + d > self.precursor_mass_tol + for d in delta_mass_ppm + ) + if matches_precursor_mz or exceeds_precursor_mz: + break + except KeyError: + matches_precursor_mz = exceeds_precursor_mz = False + if matches_precursor_mz or exceeds_precursor_mz: + tokens[beam_i][idx] = self.stop_token + is_beam_prec_fit[beam_i] = matches_precursor_mz + + # Get the indices of finished beams. + finished_idx = torch.where((tokens == self.stop_token).any(dim=1))[0] + return finished_idx, tokens + + def _cache_finished_beams( + self, + finished_beams_idx: torch.Tensor, + cache_next_idx: Dict[int, int], + cache_pred_seq: Dict[int, Set[torch.Tensor]], + cache_pred_score: Dict[int, List[List[Tuple[float, int]]]], + cache_tokens: torch.Tensor, + cache_scores: torch.Tensor, + tokens: torch.Tensor, + scores: torch.Tensor, + is_beam_prec_fit: torch.Tensor, + idx: int, + ): + """ + Cache terminated beams. + + Group and order by fitting precursor m/z and confidence score. 
+ + Parameters + ---------- + finished_beams_idx : torch.Tensor + Indices of all finished beams on tokens tensor. + cache_next_idx : Dict[int, int] + Next available tensor index to cache peptides for each spectrum. + cache_pred_seq : Dict[int, Set[torch.Tensor]] + Set of decoded peptide tokens for each spectrum. + cache_pred_score : Dict[int, List[List[Tuple[float, int]]]] + Confidence score for each decoded peptide, separated as + precursor m/z fitting vs not, for each spectrum. + cache_tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) + The token for each amino acid in cached peptides. + cache_scores : torch.Tensor of shape + (n_spectra * n_beams, max_length, n_amino_acids) + The score for each amino acid in cached peptides. + tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) + Output token of the model corresponding to amino acid sequences. + scores : torch.Tensor of shape + (n_spectra * n_beams, max_length, n_amino_acids) + Output scores of the model. + is_beam_prec_fit: torch.Tensor of shape (n_spectra * n_beams) + Boolean tensor indicating if current beams are within the precursor + m/z tolerance. + idx : int + Index to be considered in the current decoding step. + """ + beam = self.n_beams + # Store finished beams in the cache. + for i in finished_beams_idx: + i = i.item() + spec_idx = i // beam # Find the starting index of the spectrum. + # Check position of stop token (changes in case stopped early). + stop_token_idx = idx - (not tokens[i][idx] == self.stop_token) + # Check if predicted peptide already in cache. + pred_seq = tokens[i][:stop_token_idx] + is_peptide_cached = any( + torch.equal(pep, pred_seq) for pep in cache_pred_seq[spec_idx] + ) + # Don't cache this peptide if it was already predicted previously. 
+ if is_peptide_cached: + continue + smx = self.softmax(scores) + aa_scores = [smx[i, j, k].item() for j, k in enumerate(pred_seq)] + pep_score = _aa_to_pep_score(aa_scores) + # Cache peptides with fitting (idx=0) or non-fitting (idx=1) + # precursor m/z separately. + cache_pred_score_idx = cache_pred_score[spec_idx] + cache_i = int(not is_beam_prec_fit[i]) + # Directly cache if we don't already have k peptides cached. + if cache_next_idx[spec_idx] < (spec_idx + 1) * beam: + insert_idx = cache_next_idx[spec_idx] + cache_next_idx[spec_idx] += 1 # Move the pointer. + heap_update = heapq.heappush + # If any prediction has a non-fitting precursor m/z and this + # prediction has a fitting precursor m/z, replace the non-fitting + # peptide with the lowest score, irrespective of the current + # predicted score. + elif is_beam_prec_fit[i] and len(cache_pred_score_idx[1]) > 0: + _, insert_idx = heapq.heappop(cache_pred_score_idx[1]) + heap_update = heapq.heappush + # Else, replace the lowest-scoring peptide with corresponding + # fitting or non-fitting precursor m/z if the current predicted + # score is higher. + elif len(cache_pred_score_idx[cache_i]) > 0: + # Peek at the top of the heap (lowest score). + pop_pep_score, insert_idx = cache_pred_score_idx[cache_i][0] + heap_update = heapq.heappushpop + # Don't store this prediction if it has a lower score than all + # previous predictions. + if pep_score <= pop_pep_score: + continue + # Finally, no matching cache found (we should never get here). + else: + continue + # Store the current prediction in its relevant cache. 
+ cache_tokens[insert_idx, :] = tokens[i, :] + cache_scores[insert_idx, :, :] = scores[i, :, :] + heap_update(cache_pred_score_idx[cache_i], (pep_score, insert_idx)) + cache_pred_seq[spec_idx].add(pred_seq) + + def _get_top_peptide( + self, + cache_pred_score: Dict[int, List[List[Tuple[float, int]]]], + cache_tokens: torch.tensor, + cache_scores: torch.tensor, + batch: int, + ) -> Tuple[torch.tensor, torch.tensor]: + """ + Return the peptide with the highest confidence score for each spectrum. + + If there are no peptides within the precursor m/z tolerance, return the + highest-scoring peptide among the non-fitting predictions. + + Parameters + ---------- + cache_pred_score : Dict[int, List[List[Tuple[float, int]]]] + Confidence score for each decoded peptide, separated as + precursor m/z fitting vs not, for each spectrum. + cache_tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) + The token for each amino acid in cached peptides. + cache_scores : torch.Tensor of shape + (n_spectra * n_beams, max_length, n_amino_acids) + The score for each amino acid in cached peptides. + batch: int + Number of spectra in the batch. + + Returns + ------- + output_tokens : torch.Tensor of shape (n_spectra, max_length) + The token for each amino acid in the output peptides. + output_scores : torch.Tensor of shape + (n_spectra, max_length, n_amino_acids) + The score for each amino acid in the output peptides. + """ + # Sizes. + length = self.max_length + 1 # L + vocab = self.decoder.vocab_size + 1 # V + + # Create output tensors for top scoring peptides and their scores. + output_scores = torch.full( + size=(batch, length, vocab), fill_value=torch.nan + ) + output_scores = output_scores.type_as(cache_scores) + output_tokens = torch.zeros(batch, length).type_as(cache_tokens) + + # Return the top scoring peptide (fitting precursor mass if possible). 
+ for spec_idx in range(batch): + cache = cache_pred_score[spec_idx][ + len(cache_pred_score[spec_idx][0]) == 0 + ] + # Skip this spectrum if it doesn't have any finished beams. + if len(cache) == 0: + continue + _, top_score_idx = max(cache, key=operator.itemgetter(0)) + output_tokens[spec_idx, :] = cache_tokens[top_score_idx, :] + output_scores[spec_idx, :, :] = cache_scores[top_score_idx, :, :] + + return output_tokens, output_scores + + def _get_topk_beams( + self, scores: torch.tensor, tokens: torch.tensor, batch: int, idx: int + ) -> Tuple[torch.tensor, torch.tensor]: + """ + Find top-k beams with highest confidences and continue decoding those. + + Discontinue decoding for beams where the stop token was predicted. + + Parameters + ---------- + scores : torch.Tensor of shape + (n_spectra * n_beams, max_length, n_amino_acids) + Output scores of the model. + tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) + Output token of the model corresponding to amino acid sequences. + batch: int + Number of spectra in the batch. + idx : int + Index to be considered in the current decoding step. + + Returns + ------- + scores : torch.Tensor of shape + (n_spectra * n_beams, max_length, n_amino_acids) + Output scores of the model. + tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) + Output token of the model corresponding to amino acid sequences. + """ + beam = self.n_beams # S + vocab = self.decoder.vocab_size + 1 # V + + # Reshape to group by spectrum (B for "batch"). + scores = einops.rearrange(scores, "(B S) L V -> B L V S", S=beam) + tokens = einops.rearrange(tokens, "(B S) L -> B L S", S=beam) + prev_tokens = einops.repeat( + tokens[:, :idx, :], "B L S -> B L V S", V=vocab + ) + + # Get the previous tokens and scores. 
+ prev_scores = torch.gather( + scores[:, :idx, :, :], dim=2, index=prev_tokens + ) + prev_scores = einops.repeat( + prev_scores[:, :, 0, :], "B L S -> B L (V S)", V=vocab + ) + + # Get scores for all possible beams at this step. + step_scores = torch.zeros(batch, idx + 1, beam * vocab).type_as(scores) + step_scores[:, :idx, :] = prev_scores + step_scores[:, idx, :] = einops.rearrange( + scores[:, idx, :, :], "B V S -> B (V S)" + ) + + # Mask out terminated beams. Include delta mass induced termination. + extended_prev_tokens = einops.repeat( + tokens[:, : idx + 1, :], "B L S -> B L V S", V=vocab + ) + finished_mask = ( + einops.rearrange(extended_prev_tokens, "B L V S -> B L (V S)") + == self.stop_token + ).any(axis=1) + # Mask out the index '0', i.e. padding token, by default. + finished_mask[:, :beam] = True + + # Figure out the top K decodings. + _, top_idx = torch.topk( + step_scores.nanmean(dim=1) * (~finished_mask).float(), beam + ) + v_idx, s_idx = np.unravel_index(top_idx.cpu(), (vocab, beam)) + s_idx = einops.rearrange(s_idx, "B S -> (B S)") + b_idx = einops.repeat(torch.arange(batch), "B -> (B S)", S=beam) + + # Record the top K decodings. + tokens[:, :idx, :] = einops.rearrange( + prev_tokens[b_idx, :, 0, s_idx], "(B S) L -> B L S", S=beam + ) + tokens[:, idx, :] = torch.tensor(v_idx) + scores[:, : idx + 1, :, :] = einops.rearrange( + scores[b_idx, : idx + 1, :, s_idx], "(B S) L V -> B L V S", S=beam + ) + scores = einops.rearrange(scores, "B L V S -> (B S) L V") + tokens = einops.rearrange(tokens, "B L S -> (B S) L") + return scores, tokens def _forward_step( self, @@ -346,7 +813,7 @@ def validation_step( def predict_step( self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args - ) -> Tuple[torch.Tensor, torch.Tensor, List[str], torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[str]], torch.Tensor]: """ A single prediction step. @@ -362,7 +829,7 @@ def predict_step( The spectrum identifiers. 
precursors : torch.Tensor Precursor information for each spectrum. - peptides : List[str] + peptides : List[List[str]] The predicted peptide sequences for each spectrum. aa_scores : torch.Tensor of shape (n_spectra, length, n_amino_acids) The individual amino acid scores for each prediction. @@ -408,21 +875,24 @@ def on_predict_epoch_end( return for batch in results: for step in batch: - for spectrum_i, precursor, peptide, aa_scores in zip(*step): - peptide = peptide[1:] - peptide_tokens = re.split(r"(?<=.)(?=[A-Z])", peptide) - # Take the scores of the most probable amino acids. - top_aa_scores = torch.max( - aa_scores[1 : len(peptide_tokens) + 1], axis=1 - )[0] - peptide_score = torch.mean(top_aa_scores).detach().item() + for spectrum_i, precursor, aa_tokens, aa_scores in zip(*step): + # Get peptide sequence, amino acid and peptide-level + # confidence scores to write to output file. + ( + peptide, + aa_tokens, + peptide_score, + aa_scores, + ) = self._get_output_peptide_and_scores( + aa_tokens, aa_scores + ) # Compare the experimental vs calculated precursor m/z. _, precursor_charge, precursor_mz = precursor precursor_charge = int(precursor_charge.item()) precursor_mz = precursor_mz.item() try: calc_mz = self.peptide_mass_calculator.mass( - peptide_tokens, precursor_charge + aa_tokens, precursor_charge ) delta_mass_ppm = [ _calc_mass_error( @@ -445,9 +915,7 @@ def on_predict_epoch_end( # Subtract one if the precursor m/z tolerance is violated. if not is_within_precursor_mz_tol: peptide_score -= 1 - aa_scores = ",".join( - reversed(list(map("{:.5f}".format, top_aa_scores))) - ) + self.out_writer.psms.append( ( peptide, @@ -460,6 +928,52 @@ def on_predict_epoch_end( ), ) + def _get_output_peptide_and_scores( + self, aa_tokens: List[str], aa_scores: torch.Tensor + ) -> Tuple[str, List[str], float, str]: + """ + Get peptide to output, amino acid and peptide-level confidence scores. 
+ + Parameters + ---------- + aa_tokens : List[str] + Amino acid tokens of the peptide sequence. + aa_scores : torch.Tensor + Amino acid-level confidence scores for the predicted sequence. + + Returns + ------- + peptide : str + Peptide sequence. + aa_tokens : List[str] + Amino acid tokens of the peptide sequence. + peptide_score : float + Peptide-level confidence score. + aa_scores : str + Amino acid-level confidence scores for the predicted sequence. + """ + # Omit stop token. + aa_tokens = aa_tokens[1:] if self.decoder.reverse else aa_tokens[:-1] + peptide = "".join(aa_tokens) + + # If this is a non-finished beam (after exceeding `max_length`), return + # a dummy (empty) peptide and NaN scores. + if len(peptide) == 0: + aa_tokens = [] + + # Take scores corresponding to the predicted amino acids. Reverse tokens + # to correspond with correct amino acids as needed. + step = -1 if self.decoder.reverse else 1 + top_aa_scores = [ + aa_score[self.decoder._aa2idx[aa_token]].item() + for aa_score, aa_token in zip(aa_scores, aa_tokens[::step]) + ][::step] + + # Get peptide-level score from amino acid-level scores. + peptide_score = _aa_to_pep_score(top_aa_scores) + aa_scores = ",".join(list(map("{:.5f}".format, top_aa_scores))) + return peptide, aa_tokens, peptide_score, aa_scores + def _log_history(self) -> None: """ Write log to console, if requested. @@ -572,3 +1086,20 @@ def _calc_mass_error( The mass error in ppm. """ return (calc_mz - (obs_mz - isotope * 1.00335 / charge)) / obs_mz * 10**6 + + +def _aa_to_pep_score(aa_scores: List[float]) -> float: + """ + Calculate peptide-level confidence score from amino acid level scores. + + Parameters + ---------- + aa_scores : List[float] + Amino acid level confidence scores. + + Returns + ------- + float + Peptide confidence score. 
+ """ + return np.mean(aa_scores) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index fd80df1a..e05726e3 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -110,6 +110,7 @@ def _execute_existing( max_charge=config["max_charge"], precursor_mass_tol=config["precursor_mass_tol"], isotope_error_range=config["isotope_error_range"], + n_beams=config["n_beams"], n_log=config["n_log"], out_writer=out_writer, ) @@ -261,6 +262,7 @@ def train( max_charge=config["max_charge"], precursor_mass_tol=config["precursor_mass_tol"], isotope_error_range=config["isotope_error_range"], + n_beams=config["n_beams"], n_log=config["n_log"], tb_summarywriter=config["tb_summarywriter"], warmup_iters=config["warmup_iters"], diff --git a/setup.cfg b/setup.cfg index b151fa69..c417605a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ python_requires = >=3.8 install_requires = appdirs click - depthcharge-ms>=0.0.1 + depthcharge-ms>=0.1.0 numpy pandas psutil diff --git a/tests/test_unit.py b/tests/test_unit.py index 33bccf68..2302d8f7 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -3,12 +3,14 @@ import tempfile import github +import numpy as np import pytest +import torch from casanovo import casanovo from casanovo import utils -from casanovo.denovo.model import Spec2Pep from casanovo.denovo.evaluate import aa_match_batch, aa_match_metrics +from casanovo.denovo.model import Spec2Pep, _aa_to_pep_score def test_version(): @@ -102,6 +104,9 @@ def request(self, *args, **kwargs): def test_tensorboard(): + """ + Test tensorboard.SummaryWriter object created only when folder path passed + """ model = Spec2Pep(tb_summarywriter="test_path") assert model.tb_summarywriter is not None @@ -109,13 +114,480 @@ def test_tensorboard(): assert model.tb_summarywriter is None +def test_aa_to_pep_score(): + """ + Test how peptide confidence scores are derived from amino acid scores. + Currently, AA scores are just averaged. 
+ """ + assert ( + _aa_to_pep_score( + [ + 0.0, + 0.5, + 1.0, + ] + ) + == 0.5 + ) + + +def test_beam_search_decode(): + """ + Test beam search decoding and its sub-functions + """ + model = Spec2Pep(n_beams=4, residues="massivekb") + + # Sizes. + batch = 1 # B + length = model.max_length + 1 # L + vocab = model.decoder.vocab_size + 1 # V + beam = model.n_beams # S + idx = 4 + + # Initialize scores and tokens. + scores = torch.full( + size=(batch, length, vocab, beam), fill_value=torch.nan + ) + is_beam_prec_fit = torch.zeros(batch * beam, dtype=torch.bool) + + # Ground truth peptide is "PEPK". + precursors = torch.tensor([469.2536487, 2.0, 235.63410081688]).repeat( + beam * batch, 1 + ) + tokens = torch.zeros(batch * beam, length).long() + + tokens[0, :idx] = torch.tensor( + [ + model.decoder._aa2idx["P"], + model.decoder._aa2idx["E"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["K"], + ] + ) + tokens[1, :idx] = torch.tensor( + [ + model.decoder._aa2idx["P"], + model.decoder._aa2idx["E"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["R"], + ] + ) + tokens[2, :idx] = torch.tensor( + [ + model.decoder._aa2idx["P"], + model.decoder._aa2idx["E"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["G"], + ] + ) + tokens[3, :idx] = torch.tensor( + [ + model.decoder._aa2idx["P"], + model.decoder._aa2idx["E"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["$"], + ] + ) + + # Test _terminate_finished_beams(). + finished_beams_idx, updated_tokens = model._terminate_finished_beams( + tokens=tokens, + precursors=precursors, + is_beam_prec_fit=is_beam_prec_fit, + idx=idx, + ) + + assert torch.equal(finished_beams_idx, torch.tensor([0, 1, 3])) + assert torch.equal( + updated_tokens[:, idx], + torch.tensor([model.stop_token, model.stop_token, 0, 0]), + ) + + # Test _create_beamsearch_cache() and _cache_finished_beams(). 
+ tokens = torch.zeros(batch, length, beam).long() + + ( + cache_scores, + cache_tokens, + cache_next_idx, + cache_pred_seq, + cache_pred_score, + ) = model._create_beamsearch_cache(scores, tokens) + + scores = cache_scores.clone() + for i in range(idx): + scores[:, i, :] = 1 + scores[1, i, updated_tokens[1, i].item()] = 2 + + model._cache_finished_beams( + finished_beams_idx, + cache_next_idx, + cache_pred_seq, + cache_pred_score, + cache_tokens, + cache_scores, + updated_tokens, + scores, + is_beam_prec_fit, + idx, + ) + + assert cache_next_idx[0] == 3 + # Keep track of peptides that should be in cache. + correct_pep = 0 + for pep in cache_pred_seq[0]: + correct_pep += ( + torch.equal(pep, torch.tensor([4, 14, 4, 13])) + or torch.equal(pep, torch.tensor([4, 14, 4, 18])) + or torch.equal(pep, torch.tensor([4, 14, 4])) + ) + assert correct_pep == 3 + # Check if precursor fitting and non-fitting peptides cached correctly. + assert len(cache_pred_score[0][0]) == 1 + assert len(cache_pred_score[0][1]) == 2 + + # Test _get_top_peptide(). + output_tokens, output_scores = model._get_top_peptide( + cache_pred_score, cache_tokens, cache_scores, batch + ) + + # Check if output equivalent to "PEPK". 
+ assert torch.equal(output_tokens[0], cache_tokens[0]) + + # If no peptides are finished + dummy_cache_pred_score = {0: [[], []]} + + dummy_output_tokens, dummy_output_scores = model._get_top_peptide( + dummy_cache_pred_score, cache_tokens, cache_scores, batch + ) + + # Check if output equivalent to zero tensor + assert sum(dummy_output_tokens[0]).item() == 0 + + # Test _get_topk_beams() + # Generate scores for the non-terminated beam + scores[2, idx, :] = 1 + + for i in range(1, 5): + scores[2, idx, i] = i + 1 + + new_scores, new_tokens = model._get_topk_beams( + scores=scores, tokens=updated_tokens, batch=batch, idx=idx + ) + + expected_tokens = torch.tensor( + [ + [4, 14, 4, 1, 4], + [4, 14, 4, 1, 3], + [4, 14, 4, 1, 2], + [4, 14, 4, 1, 1], + ] + ) + + expected_scores = torch.ones(beam, vocab) + + for i in range(1, 5): + expected_scores[:, i] = i + 1 + + assert torch.equal(new_tokens[:, : idx + 1], expected_tokens) + assert torch.equal(new_scores[:, idx, :], expected_scores) + + # Test beam_search_decode(). + spectra = torch.zeros(1, 5, 2) + precursors = torch.tensor([[469.2536487, 2.0, 235.63410081688]]) + model_scores, model_tokens = model.beam_search_decode(spectra, precursors) + + assert model_tokens.shape[0] == 1 + assert model.stop_token in model_tokens + + # Test output if decoding loop isn't stopped with termination of all beams + model.max_length = 0 + model_scores, model_tokens = model.beam_search_decode(spectra, precursors) + assert torch.equal(model_tokens, torch.tensor([[0]])) + model.max_length = 100 + + # Re-initialize scores and tokens to further test caching functionality. 
+ scores_v2 = torch.full( + size=(batch * beam, length, vocab), fill_value=torch.nan + ) + tokens_v2 = torch.zeros(batch * beam, length).long() + + tokens_v2[0, : idx + 1] = torch.tensor( + [ + model.decoder._aa2idx["P"], + model.decoder._aa2idx["K"], + model.decoder._aa2idx["K"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["$"], + ] + ) + tokens_v2[1, : idx + 1] = torch.tensor( + [ + model.decoder._aa2idx["E"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["K"], + model.decoder._aa2idx["$"], + ] + ) + tokens_v2[2, : idx + 1] = torch.tensor( + [ + model.decoder._aa2idx["P"], + model.decoder._aa2idx["E"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["R"], + model.decoder._aa2idx["$"], + ] + ) + tokens_v2[3, : idx + 1] = torch.tensor( + [ + model.decoder._aa2idx["P"], + model.decoder._aa2idx["M"], + model.decoder._aa2idx["K"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["$"], + ] + ) + + # Test if fitting replaces non-fitting in the cache and only higher scoring + # non-fitting replaces non-fitting. + for i in range(idx + 1): + scores_v2[:, i, :] = 1 + scores_v2[0, i, tokens_v2[0, i].item()] = 4 + scores_v2[1, i, tokens_v2[1, i].item()] = 0.5 + scores_v2[2, i, tokens_v2[2, i].item()] = 3 + scores_v2[3, i, tokens_v2[3, i].item()] = 0.4 + + finished_beams_idx_v2 = torch.tensor([0, 1, 2, 3]) + is_beam_prec_fit_v2 = torch.BoolTensor([False, True, False, False]) + + model._cache_finished_beams( + finished_beams_idx_v2, + cache_next_idx, + cache_pred_seq, + cache_pred_score, + cache_tokens, + cache_scores, + tokens_v2, + scores_v2, + is_beam_prec_fit_v2, + idx + 1, + ) + + assert cache_next_idx[0] == 4 + # Check if precursor fitting and non-fitting peptides cached correctly. + assert len(cache_pred_score[0][0]) == 2 + assert len(cache_pred_score[0][1]) == 2 + + # Keep track of peptides that should (not) be in cache. 
+ correct_pep = 0 + wrong_pep = 0 + + for pep in cache_pred_seq[0]: + if ( + torch.equal(pep, torch.tensor([4, 13, 13, 4])) + or torch.equal(pep, torch.tensor([14, 4, 4, 13])) + or torch.equal(pep, torch.tensor([4, 14, 4, 18])) + ): + correct_pep += 1 + elif torch.equal(pep, torch.tensor([4, 15, 13, 4])): + wrong_pep += 1 + assert correct_pep == 3 + assert wrong_pep == 0 + + # Test for a single beam. + model = Spec2Pep(n_beams=1) + + # Sizes. + batch = 1 # B + length = model.max_length + 1 # L + vocab = model.decoder.vocab_size + 1 # V + beam = model.n_beams # S + idx = 4 + + # Initialize scores and tokens. + scores = torch.full( + size=(batch, length, vocab, beam), fill_value=torch.nan + ) + is_beam_prec_fit = torch.zeros(batch * beam, dtype=torch.bool) + + # Ground truth peptide is "PEPK" + precursors = torch.tensor([469.2536487, 2.0, 235.63410081688]).repeat( + beam * batch, 1 + ) + tokens = torch.zeros(batch * beam, length).long() + + tokens[0, :idx] = torch.tensor( + [ + model.decoder._aa2idx["P"], + model.decoder._aa2idx["E"], + model.decoder._aa2idx["P"], + model.decoder._aa2idx["K"], + ] + ) + + # Test _terminate_finished_beams(). + finished_beams_idx, updated_tokens = model._terminate_finished_beams( + tokens=tokens, + precursors=precursors, + is_beam_prec_fit=is_beam_prec_fit, + idx=idx, + ) + + assert torch.equal(finished_beams_idx, torch.tensor([0])) + assert torch.equal( + updated_tokens[:, idx], torch.tensor([model.stop_token]) + ) + + # Test _create_beamsearch_cache() and _cache_finished_beams(). 
+ tokens = torch.zeros(batch, length, beam).long() + + ( + cache_scores, + cache_tokens, + cache_next_idx, + cache_pred_seq, + cache_pred_score, + ) = model._create_beamsearch_cache(scores, tokens) + + scores = cache_scores.clone() + for i in range(idx): + scores[:, i, :] = 1 + + model._cache_finished_beams( + finished_beams_idx, + cache_next_idx, + cache_pred_seq, + cache_pred_score, + cache_tokens, + cache_scores, + updated_tokens, + scores, + is_beam_prec_fit, + idx, + ) + + assert cache_next_idx[0] == 1 + # Keep track of peptides that should be in cache. + correct_pep = 0 + for pep in cache_pred_seq[0]: + correct_pep += torch.equal(pep, torch.tensor([4, 14, 4, 13])) + assert correct_pep == 1 + # Check if precursor fitting and non-fitting peptides cached correctly. + assert len(cache_pred_score[0][0]) == 1 + assert len(cache_pred_score[0][1]) == 0 + + # Test _get_top_peptide(). + output_tokens, output_scores = model._get_top_peptide( + cache_pred_score, cache_tokens, cache_scores, batch + ) + + # Check if output equivalent to "PEPK". 
+ assert torch.equal(output_tokens[0], cache_tokens[0]) + + # Test _terminate_finished_beams for tokens with negative mass + model = Spec2Pep(n_beams=2, residues="massivekb") + # Sizes: + batch = 1 # B + length = model.max_length + 1 # L + vocab = model.decoder.vocab_size + 1 # V + beam = model.n_beams # S + idx = 2 + # Initialize scores and tokens: + scores = torch.full( + size=(batch, length, vocab, beam), fill_value=torch.nan + ) + is_beam_prec_fit = (batch * beam) * [False] + # Ground truth peptide is "-17.027GK" + precursors = torch.tensor([186.100442485, 2.0, 94.05749770938]).repeat( + beam * batch, 1 + ) + tokens = torch.zeros(batch * beam, length).long() + tokens[0, :idx] = torch.tensor( + [ + model.decoder._aa2idx["G"], + model.decoder._aa2idx["K"], + ] + ) + + tokens[1, :idx] = torch.tensor( + [ + model.decoder._aa2idx["A"], + model.decoder._aa2idx["K"], + ] + ) + # Test _terminate_finished_beams() + finished_beams_idx, updated_tokens = model._terminate_finished_beams( + tokens=tokens, + precursors=precursors, + is_beam_prec_fit=is_beam_prec_fit, + idx=idx, + ) + assert torch.equal(finished_beams_idx, torch.tensor([1])) + assert torch.equal( + updated_tokens[:, idx], + torch.tensor([0, model.stop_token]), + ) + + +def test_get_output_peptide_and_scores(): + """ + Test output peptides and amino acid/peptide-level scores have correct format. 
+ """ + # Test a common case with reverse decoding (C- to N-terminus) + model = Spec2Pep() + aa_tokens = [model.decoder._idx2aa[model.stop_token], "G", "K"] + aa_scores = torch.zeros(model.max_length, model.decoder.vocab_size + 1) + aa_scores[0][model.decoder._aa2idx["K"]] = 1 + aa_scores[1][model.decoder._aa2idx["G"]] = 1 + + ( + peptide, + aa_tokens, + peptide_score, + aa_scores, + ) = model._get_output_peptide_and_scores(aa_tokens, aa_scores) + assert peptide == "GK" + assert peptide_score == 1 + assert aa_scores == "1.00000,1.00000" + + # Test a case with straight decoding (N- to C-terminus) + model.decoder.reverse = False + aa_tokens = ["G", "K", model.decoder._idx2aa[model.stop_token]] + aa_scores = torch.zeros(model.max_length, model.decoder.vocab_size + 1) + aa_scores[0][model.decoder._aa2idx["G"]] = 1 + aa_scores[1][model.decoder._aa2idx["K"]] = 1 + + ( + peptide, + aa_tokens, + peptide_score, + aa_scores, + ) = model._get_output_peptide_and_scores(aa_tokens, aa_scores) + assert peptide == "GK" + assert peptide_score == 1 + assert aa_scores == "1.00000,1.00000" + + # Test when predicted peptide is empty + aa_tokens = ["", ""] + + ( + peptide, + aa_tokens, + peptide_score, + aa_scores, + ) = model._get_output_peptide_and_scores(aa_tokens, aa_scores) + assert peptide == "" + assert np.isnan(peptide_score) + assert aa_scores == "" + + def test_eval_metrics(): """ - Test that peptide and amino acid-level evaluation metrics. - Predicted AAs are considered correct match if they're <0.1Da from - the corresponding ground truth (GT) AA with either a suffix or - prefix <0.5Da from GT. A peptide prediction is correct if all - its AA are correct matches. + Test peptide and amino acid-level evaluation metrics. + Predicted AAs are considered correct if they are <0.1Da from the + corresponding ground truth (GT) AA with either a suffix or prefix <0.5Da + from GT. A peptide prediction is correct if all its AA are correct matches. 
""" model = Spec2Pep() @@ -145,6 +617,6 @@ def test_eval_metrics(): aa_matches, n_gt_aa, n_pred_aa ) - assert round(2 / 8, 3) == round(pep_precision, 3) - assert round(26 / 40, 3) == round(aa_recall, 3) - assert round(26 / 41, 3) == round(aa_precision, 3) + assert 2 / 8 == pytest.approx(pep_precision) + assert 26 / 40 == pytest.approx(aa_recall) + assert 26 / 41 == pytest.approx(aa_precision)