Add stop tokens (#53)
wfondrie authored May 9, 2024
1 parent 3ca2297 commit 1b53a35
Showing 6 changed files with 61 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,10 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+## [v0.4.7]
+### Fixed
+- Add stop and start tokens for `AnnotatedSpectrumDataset`, when available.
+- When `reverse` is used for the `PeptideTokenizer`, automatically reverse the decoded peptide.
 
 ## [v0.4.6]
 ### Added
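Taken together, the two fixes are easiest to see at the tokenizer boundary. Below is a minimal sketch of the intended round trip, mirroring the updated unit tests further down; the peptide and keyword defaults follow this diff, so treat it as an illustration rather than canonical API documentation:

```python
from depthcharge.tokenizers import PeptideTokenizer

# Build a tokenizer from example ProForma peptides. With the new defaults,
# stop_token is "$" and reversal must be requested explicitly.
tokenizer = PeptideTokenizer.from_proforma(["LESLIEK"], reverse=True)

# Stop tokens are appended only when asked for:
tokens = tokenizer.tokenize(["LESLIEK"], add_stop=True)

# With reverse=True, detokenize now un-reverses the decoded peptide, so the
# round trip returns the sequence in reading order:
assert tokenizer.detokenize(tokens) == ["LESLIEK"]
```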
4 changes: 3 additions & 1 deletion depthcharge/data/spectrum_datasets.py
@@ -367,7 +367,9 @@ def _to_tensor(
         """
         batch = super()._to_tensor(batch)
         batch[self.annotations] = self.tokenizer.tokenize(
-            batch[self.annotations]
+            batch[self.annotations],
+            add_start=self.tokenizer.start_token is not None,
+            add_stop=self.tokenizer.stop_token is not None,
         )
         return batch

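The dataset change derives both flags from the tokenizer's configuration instead of hard-coding them, so a tokenizer built without a start token never receives `add_start=True`. Here is a standalone sketch of the same gating pattern; the helper name is hypothetical and exists only to illustrate the `is not None` checks:

```python
# Hypothetical helper mirroring the gating in _to_tensor above: the
# add_start/add_stop flags track whether the tokenizer defines those tokens.
def annotation_token_flags(tokenizer) -> dict[str, bool]:
    return {
        "add_start": tokenizer.start_token is not None,
        "add_stop": tokenizer.stop_token is not None,
    }

# With the defaults in this diff (start_token=None, stop_token="$"), this
# yields {"add_start": False, "add_stop": True}, matching the updated tests
# that now expect tokenize(..., add_stop=True).
```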
42 changes: 40 additions & 2 deletions depthcharge/tokenizers/peptides.py
@@ -164,12 +164,50 @@ def split(self, sequence: str) -> list[str]:
 
         return pep
 
+    def detokenize(
+        self,
+        tokens: torch.Tensor,
+        join: bool = True,
+        trim_start_token: bool = True,
+        trim_stop_token: bool = True,
+    ) -> list[str] | list[list[str]]:
+        """Retrieve sequences from tokens.
+
+        Parameters
+        ----------
+        tokens : torch.Tensor of shape (n_sequences, max_length)
+            The zero-padded tensor of integerized tokens to decode.
+        join : bool, optional
+            Join tokens into strings?
+        trim_start_token : bool, optional
+            Remove the start token from the beginning of a sequence.
+        trim_stop_token : bool, optional
+            Remove the stop token from the end of a sequence.
+
+        Returns
+        -------
+        list[str] or list[list[str]]
+            The decoded sequences, each as a string or a list of strings.
+        """
+        decoded = super().detokenize(
+            tokens=tokens,
+            join=join,
+            trim_start_token=trim_start_token,
+            trim_stop_token=trim_stop_token,
+        )
+
+        if self.reverse:
+            decoded = [d[::-1] for d in decoded]
+
+        return decoded
+
     @classmethod
     def from_proforma(
         cls,
         sequences: Iterable[str],
         replace_isoleucine_with_leucine: bool = False,
-        reverse: bool = True,
+        reverse: bool = False,
         start_token: str | None = None,
         stop_token: str | None = "$",
     ) -> PeptideTokenizer:
@@ -238,7 +276,7 @@ def from_proforma(
     @staticmethod
     def from_massivekb(
         replace_isoleucine_with_leucine: bool = False,
-        reverse: bool = True,
+        reverse: bool = False,
         start_token: str | None = None,
         stop_token: str | None = "$",
     ) -> MskbPeptideTokenizer:
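A usage sketch for the new `detokenize` override, assuming a tokenizer like the one in the example above; note that the un-reversing step applies whether or not the tokens are joined into strings:

```python
from depthcharge.tokenizers import PeptideTokenizer

tokenizer = PeptideTokenizer.from_proforma(["LESLIEK"], reverse=True)
tokens = tokenizer.tokenize(["LESLIEK"], add_stop=True)

# The defaults trim the start/stop tokens and join residues into strings:
tokenizer.detokenize(tokens)              # ["LESLIEK"]

# join=False keeps per-residue tokens, still restored to reading order:
tokenizer.detokenize(tokens, join=False)  # [["L", "E", "S", "L", "I", "E", "K"]]
```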
16 changes: 12 additions & 4 deletions tests/unit_tests/test_data/test_datasets.py
@@ -73,14 +73,18 @@ def test_indexing(tokenizer, mgf_small, tmp_path):
         1,
         14,
     )
-    torch.testing.assert_close(spec["seq"], tokenizer.tokenize(["LESLIEK"]))
+    torch.testing.assert_close(
+        spec["seq"], tokenizer.tokenize(["LESLIEK"], add_stop=True)
+    )
 
     spec2 = dataset[3]
     assert spec2["mz_array"].shape == (
         1,
         24,
     )
-    torch.testing.assert_close(spec2["seq"], tokenizer.tokenize(["EDITHR"]))
+    torch.testing.assert_close(
+        spec2["seq"], tokenizer.tokenize(["EDITHR"], add_stop=True)
+    )


@@ -106,11 +110,15 @@ def test_load(tokenizer, tmp_path, mgf_small):
     spec = dataset[0]
     assert len(spec) == 8
     assert spec["mz_array"].shape == (1, 14)
-    torch.testing.assert_close(spec["seq"], tokenizer.tokenize(["LESLIEK"]))
+    torch.testing.assert_close(
+        spec["seq"], tokenizer.tokenize(["LESLIEK"], add_stop=True)
+    )
 
     spec2 = dataset[1]
     assert spec2["mz_array"].shape == (1, 24)
-    torch.testing.assert_close(spec2["seq"], tokenizer.tokenize(["EDITHR"]))
+    torch.testing.assert_close(
+        spec2["seq"], tokenizer.tokenize(["EDITHR"], add_stop=True)
+    )
 
     dataset = SpectrumDataset.from_lance(db_path, 1)
     spec = dataset[0]
2 changes: 1 addition & 1 deletion tests/unit_tests/test_data/test_loaders.py
@@ -74,7 +74,7 @@ def test_ann_spectrum_loader(mgf_small):
     assert isinstance(batch["mz_array"], torch.Tensor)
     torch.testing.assert_close(
         batch["seq"][0, ...],
-        tokenizer.tokenize(["LESLIEK"]),
+        tokenizer.tokenize(["LESLIEK"], add_stop=True),
     )


2 changes: 1 addition & 1 deletion tests/unit_tests/test_tokenizers/test_peptides.py
@@ -62,7 +62,7 @@ def test_proforma_init():
     assert tokens == list("KEILSEL")
     tokens = proforma.tokenize(["LESLIEK"])
     orig = proforma.detokenize(tokens)
-    assert orig == ["KEILSEL"]
+    assert orig == ["LESLIEK"]
 
     tokens = proforma.tokenize("LESLIEK", True, True, True)[0]
     assert "".join(tokens) == "KEILSEL$"
