Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add beam search peptide decoding #87

Merged
merged 44 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
0710ee8
Add beam search
melihyilmaz Nov 2, 2022
0476618
Delete print statements
melihyilmaz Nov 2, 2022
bfbffa6
Automatically download model weights (#68) (#88)
melihyilmaz Nov 3, 2022
bc7d350
Automatically download model weights (#68) (#89)
melihyilmaz Nov 3, 2022
dc77f1b
Break beam search to testable subfunctions
melihyilmaz Nov 5, 2022
bab7072
Automatically download model weights (#68)
melihyilmaz Nov 5, 2022
3f0576e
Fix precursor m/z termination and filtering
melihyilmaz Nov 6, 2022
6f39e1c
Add unit testing for beam search
melihyilmaz Nov 6, 2022
5d61868
Add beamsearch comments and fix formatting
melihyilmaz Nov 6, 2022
be05748
Merge branch 'main' into beamsearch_melih
melihyilmaz Nov 8, 2022
3c81755
Address requested changes and minor fixes
melihyilmaz Nov 9, 2022
75c9b50
Add more unit tests for beam search
melihyilmaz Nov 9, 2022
b962453
Check NH3 loss for early stopping
melihyilmaz Nov 14, 2022
9050fc8
Consistent parameter order
bittremieux Nov 14, 2022
592efb0
Update docstrings
bittremieux Nov 14, 2022
426ece8
Remove unused precursors parameter
bittremieux Nov 14, 2022
fcb006c
Update beam matching mask in a level higher
bittremieux Nov 14, 2022
6bc2ba4
Minor refactoring to avoid code duplication
bittremieux Nov 14, 2022
cbdcacc
Update imports
bittremieux Nov 14, 2022
646f9dc
Simplification refactoring
bittremieux Nov 14, 2022
616c0c4
Fix unit tests
bittremieux Nov 14, 2022
66c1b2e
Merge remote-tracking branch 'origin/main' into beamsearch_melih
bittremieux Nov 15, 2022
c705b0e
Simplify predicted peptide caching
bittremieux Nov 15, 2022
cbeefa7
Simplify predicted peptide caching
bittremieux Nov 15, 2022
6e3b6da
Simplify predicted peptide caching
bittremieux Nov 15, 2022
1a3bcb1
Unify predicted peptide caching
bittremieux Nov 15, 2022
c2ec4d2
Restrict tensor reshape to subfunction and minor fixes
melihyilmaz Nov 15, 2022
62c51ae
Finish beams when all isotopes exceed the precursor m/z tolerance
bittremieux Nov 15, 2022
57b5b31
Generalize look-ahead for tokens with negative mass
bittremieux Nov 15, 2022
b65aaca
Remove greedy decoding functionality
bittremieux Nov 15, 2022
eb08c5b
Merge branch 'main' into beamsearch_melih
melihyilmaz Nov 15, 2022
a2f9a3d
Handle case with unfinished beams and add test
melihyilmaz Nov 15, 2022
bc64949
Merge branch 'main' into beamsearch_melih
melihyilmaz Nov 16, 2022
63262e9
Upgrade required depthcharge version
melihyilmaz Nov 16, 2022
03b9172
Use detokenize function
melihyilmaz Nov 16, 2022
412093b
Add test for negative mass-aware termination
melihyilmaz Nov 16, 2022
3126e14
Fix negative mass-aware beam termination
melihyilmaz Nov 16, 2022
053967e
Minor refactoring
bittremieux Nov 16, 2022
bdfa915
Add test for dummy output at max length
melihyilmaz Nov 17, 2022
0c6c0f0
Fixed and refactored peptide and score mzTab outputs
melihyilmaz Nov 17, 2022
e50bf90
Add tests for peptide and score output formatting
melihyilmaz Nov 17, 2022
7e934f5
Small fixes
bittremieux Nov 17, 2022
10df5e0
Update changelog
bittremieux Nov 18, 2022
f4fa6c8
Fix changelog update
bittremieux Nov 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions casanovo/denovo/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,15 @@ def aa_match_batch(
"""
aa_matches_batch, n_aa1, n_aa2 = [], 0, 0
for peptide1, peptide2 in zip(peptides1, peptides2):
tokens1 = re.split(r"(?<=.)(?=[A-Z])", peptide1)
tokens2 = re.split(r"(?<=.)(?=[A-Z])", peptide2)
n_aa1, n_aa2 = n_aa1 + len(tokens1), n_aa2 + len(tokens2)
if isinstance(peptide1, str):
peptide1 = re.split(r"(?<=.)(?=[A-Z])", peptide1)
if isinstance(peptide2, str):
peptide2 = re.split(r"(?<=.)(?=[A-Z])", peptide2)
n_aa1, n_aa2 = n_aa1 + len(peptide1), n_aa2 + len(peptide2)
aa_matches_batch.append(
aa_match(
tokens1,
tokens2,
peptide1,
peptide2,
aa_dict,
cum_mass_threshold,
ind_mass_threshold,
Expand Down
72 changes: 60 additions & 12 deletions casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Union,
)
from heapq import heappop, heappush, heappushpop
from operator import itemgetter

import depthcharge.masses
import einops
Expand Down Expand Up @@ -168,7 +169,7 @@ def __init__(

def forward(
self, spectra: torch.Tensor, precursors: torch.Tensor
) -> Tuple[List[str], torch.Tensor]:
) -> Tuple[List[List[str]], torch.Tensor]:
"""
Predict peptide sequences for a batch of MS/MS spectra.

Expand All @@ -186,7 +187,7 @@ def forward(

Returns
-------
peptides : List[str]
peptides : List[List[str]]
The predicted peptide sequences for each spectrum.
aa_scores : torch.Tensor of shape (n_spectra, length, n_amino_acids)
The individual amino acid scores for each prediction.
Expand All @@ -195,7 +196,18 @@ def forward(
spectra.to(self.encoder.device),
precursors.to(self.decoder.device),
)
return [self.decoder.detokenize(t) for t in tokens], aa_scores
peptides = []
melihyilmaz marked this conversation as resolved.
Show resolved Hide resolved
for t in tokens:
sequence = [self.decoder._idx2aa.get(i.item(), "") for i in t]
if "$" in sequence:
idx = sequence.index("$")
sequence = sequence[: idx + 1]

if self.decoder.reverse:
sequence = list(reversed(sequence))
peptides += [sequence]

return peptides, aa_scores

def beam_search_decode(
self, spectra: torch.Tensor, precursors: torch.Tensor
Expand Down Expand Up @@ -269,13 +281,11 @@ def beam_search_decode(
for idx in range(1, self.max_length + 1):
scores = einops.rearrange(scores, "B L V S -> (B S) L V")
melihyilmaz marked this conversation as resolved.
Show resolved Hide resolved
tokens = einops.rearrange(tokens, "B L S -> (B S) L")

# Terminate beams exceeding precursor m/z tolerance
# and track all terminated beams
finished_beams_idx, tokens = self._terminate_finished_beams(
tokens, precursors, is_beam_prec_fit, idx
)

# Cache terminated beams, group and order by fitting precursor m/z
# and confidence score
self._cache_finished_beams(
Expand Down Expand Up @@ -407,6 +417,11 @@ def _terminate_finished_beams(
tokens : torch.Tensor of size (n_spectra * n_beams, max_length)
Output token of the model corresponding to amino acid sequences.
"""
# Check if N-terminal NH3 loss is in the vocabulary
is_nh3_loss_in_vocab = (
self.peptide_mass_calculator.masses.get("-17.027", "") != ""
melihyilmaz marked this conversation as resolved.
Show resolved Hide resolved
)

# Terminate beams which exceed the precursor m/z
for beam_i in range(len(tokens)):
# Check only non-terminated beams
Expand All @@ -418,8 +433,10 @@ def _terminate_finished_beams(
else:
precursor_mz = precursors[beam_i, 2].item()
precursor_charge = precursors[beam_i, 1].item()
peptide_seq = self.decoder.detokenize(tokens[beam_i][:idx])

peptide_seq = [
melihyilmaz marked this conversation as resolved.
Show resolved Hide resolved
self.decoder._idx2aa.get(i.item(), "")
for i in tokens[beam_i][:idx]
]
try:
pred_mz = self.peptide_mass_calculator.mass(
seq=peptide_seq, charge=precursor_charge
Expand All @@ -444,6 +461,37 @@ def _terminate_finished_beams(
abs(d) < self.precursor_mass_tol
for d in delta_mass_ppm
)

# Don't terminate if m/z difference is by NH3 loss
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
if (
exceeds_precursor_mz_tol
and not is_within_precursor_mz_tol
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
and is_nh3_loss_in_vocab
):
alt_peptide_seq = peptide_seq + ["-17.027"]
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
pred_mz = self.peptide_mass_calculator.mass(
seq=alt_peptide_seq, charge=precursor_charge
)
delta_mass_ppm = [
_calc_mass_error(
pred_mz,
precursor_mz,
precursor_charge,
isotope,
)
for isotope in range(
self.isotope_error_range[0],
self.isotope_error_range[1] + 1,
)
]
is_precursor_mz_tol_w_loss = any(
abs(d) < self.precursor_mass_tol
for d in delta_mass_ppm
)
exceeds_precursor_mz_tol = (
melihyilmaz marked this conversation as resolved.
Show resolved Hide resolved
not is_precursor_mz_tol_w_loss
)

except KeyError:
(
pred_mz,
Expand Down Expand Up @@ -684,7 +732,7 @@ def _get_top_peptide(
if len(cached_fitting) > 0
else cached_nonfitting
)
_, top_score_idx = max(cache, key=lambda item: item[0])
_, top_score_idx = max(cache, key=itemgetter(1))

output_tokens[spec_idx, :] = cache_tokens[top_score_idx, :]
output_scores[spec_idx, :, :] = cache_scores[top_score_idx, :, :]
Expand Down Expand Up @@ -951,7 +999,7 @@ def validation_step(

def predict_step(
self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args
) -> Tuple[torch.Tensor, torch.Tensor, List[str], torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, List[List[str]], torch.Tensor]:
"""
A single prediction step.

Expand All @@ -967,7 +1015,7 @@ def predict_step(
The spectrum identifiers.
precursors : torch.Tensor
Precursor information for each spectrum.
peptides : List[str]
peptides : List[List[str]]
The predicted peptide sequences for each spectrum.
aa_scores : torch.Tensor of shape (n_spectra, length, n_amino_acids)
The individual amino acid scores for each prediction.
Expand Down Expand Up @@ -1014,8 +1062,8 @@ def on_predict_epoch_end(
for batch in results:
for step in batch:
for spectrum_i, precursor, peptide, aa_scores in zip(*step):
peptide = peptide[1:]
peptide_tokens = re.split(r"(?<=.)(?=[A-Z])", peptide)
peptide_tokens = peptide[1:]
peptide = "".join(peptide_tokens)
# Take scores corresponding to the predicted amino acids.
top_aa_scores = [
aa_scores[idx][self.decoder._aa2idx[aa]].item()
Expand Down