diff --git a/CHANGELOG.md b/CHANGELOG.md
index e49edc2..5218f91 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,10 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+## [v0.4.7]
+### Fixed
+- Add stop and start tokens for `AnnotatedSpectrumDataset`, when available.
+- When `reverse` is used for the `PeptideTokenizer`, automatically reverse the decoded peptide.
 
 ## [v0.4.6]
 ### Added
diff --git a/depthcharge/data/spectrum_datasets.py b/depthcharge/data/spectrum_datasets.py
index c394332..1944ea9 100644
--- a/depthcharge/data/spectrum_datasets.py
+++ b/depthcharge/data/spectrum_datasets.py
@@ -367,7 +367,9 @@ def _to_tensor(
         """
         batch = super()._to_tensor(batch)
         batch[self.annotations] = self.tokenizer.tokenize(
-            batch[self.annotations]
+            batch[self.annotations],
+            add_start=self.tokenizer.start_token is not None,
+            add_stop=self.tokenizer.stop_token is not None,
         )
         return batch
 
diff --git a/depthcharge/tokenizers/peptides.py b/depthcharge/tokenizers/peptides.py
index 18254e3..9207f81 100644
--- a/depthcharge/tokenizers/peptides.py
+++ b/depthcharge/tokenizers/peptides.py
@@ -164,12 +164,50 @@ def split(self, sequence: str) -> list[str]:
 
         return pep
 
+    def detokenize(
+        self,
+        tokens: torch.Tensor,
+        join: bool = True,
+        trim_start_token: bool = True,
+        trim_stop_token: bool = True,
+    ) -> list[str] | list[list[str]]:
+        """Retrieve sequences from tokens.
+
+        Parameters
+        ----------
+        tokens : torch.Tensor of shape (n_sequences, max_length)
+            The zero-padded tensor of integerized tokens to decode.
+        join : bool, optional
+            Join tokens into strings?
+        trim_start_token : bool, optional
+            Remove the start token from the beginning of a sequence.
+        trim_stop_token : bool, optional
+            Remove the stop token from the end of a sequence.
+
+        Returns
+        -------
+        list[str] or list[list[str]]
+            The decoded sequences, each as a string or list of strings.
+
+        """
+        decoded = super().detokenize(
+            tokens=tokens,
+            join=join,
+            trim_start_token=trim_start_token,
+            trim_stop_token=trim_stop_token,
+        )
+
+        if self.reverse:
+            decoded = [d[::-1] for d in decoded]
+
+        return decoded
+
     @classmethod
     def from_proforma(
         cls,
         sequences: Iterable[str],
         replace_isoleucine_with_leucine: bool = False,
-        reverse: bool = True,
+        reverse: bool = False,
         start_token: str | None = None,
         stop_token: str | None = "$",
     ) -> PeptideTokenizer:
@@ -238,7 +276,7 @@ def from_proforma(
     @staticmethod
     def from_massivekb(
         replace_isoleucine_with_leucine: bool = False,
-        reverse: bool = True,
+        reverse: bool = False,
         start_token: str | None = None,
         stop_token: str | None = "$",
     ) -> MskbPeptideTokenizer:
diff --git a/tests/unit_tests/test_data/test_datasets.py b/tests/unit_tests/test_data/test_datasets.py
index 49b07e7..f9b7237 100644
--- a/tests/unit_tests/test_data/test_datasets.py
+++ b/tests/unit_tests/test_data/test_datasets.py
@@ -73,14 +73,18 @@ def test_indexing(tokenizer, mgf_small, tmp_path):
         1,
         14,
     )
-    torch.testing.assert_close(spec["seq"], tokenizer.tokenize(["LESLIEK"]))
+    torch.testing.assert_close(
+        spec["seq"], tokenizer.tokenize(["LESLIEK"], add_stop=True)
+    )
 
     spec2 = dataset[3]
     assert spec2["mz_array"].shape == (
         1,
         24,
     )
-    torch.testing.assert_close(spec2["seq"], tokenizer.tokenize(["EDITHR"]))
+    torch.testing.assert_close(
+        spec2["seq"], tokenizer.tokenize(["EDITHR"], add_stop=True)
+    )
 
 
 def test_load(tokenizer, tmp_path, mgf_small):
@@ -106,11 +110,15 @@ def test_load(tokenizer, tmp_path, mgf_small):
     spec = dataset[0]
     assert len(spec) == 8
     assert spec["mz_array"].shape == (1, 14)
-    torch.testing.assert_close(spec["seq"], tokenizer.tokenize(["LESLIEK"]))
+    torch.testing.assert_close(
+        spec["seq"], tokenizer.tokenize(["LESLIEK"], add_stop=True)
+    )
 
     spec2 = dataset[1]
     assert spec2["mz_array"].shape == (1, 24)
-    torch.testing.assert_close(spec2["seq"], tokenizer.tokenize(["EDITHR"]))
+    torch.testing.assert_close(
+        spec2["seq"], tokenizer.tokenize(["EDITHR"], add_stop=True)
+    )
 
     dataset = SpectrumDataset.from_lance(db_path, 1)
     spec = dataset[0]
diff --git a/tests/unit_tests/test_data/test_loaders.py b/tests/unit_tests/test_data/test_loaders.py
index ea7bb02..b692ec9 100644
--- a/tests/unit_tests/test_data/test_loaders.py
+++ b/tests/unit_tests/test_data/test_loaders.py
@@ -74,7 +74,7 @@ def test_ann_spectrum_loader(mgf_small):
     assert isinstance(batch["mz_array"], torch.Tensor)
     torch.testing.assert_close(
         batch["seq"][0, ...],
-        tokenizer.tokenize(["LESLIEK"]),
+        tokenizer.tokenize(["LESLIEK"], add_stop=True),
     )
 
 
diff --git a/tests/unit_tests/test_tokenizers/test_peptides.py b/tests/unit_tests/test_tokenizers/test_peptides.py
index b5f58e9..6e4a66f 100644
--- a/tests/unit_tests/test_tokenizers/test_peptides.py
+++ b/tests/unit_tests/test_tokenizers/test_peptides.py
@@ -62,7 +62,7 @@ def test_proforma_init():
     assert tokens == list("KEILSEL")
     tokens = proforma.tokenize(["LESLIEK"])
    orig = proforma.detokenize(tokens)
-    assert orig == ["KEILSEL"]
+    assert orig == ["LESLIEK"]
 
     tokens = proforma.tokenize("LESLIEK", True, True, True)[0]
     assert "".join(tokens) == "KEILSEL$"
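For reviewers, a minimal sketch of the round-trip behavior this patch establishes, mirroring the expectations in `test_peptides.py`. `PeptideTokenizer.from_proforma`, `tokenize(..., add_stop=True)`, and `detokenize` are taken from the diff above; the import path is an assumption based on the repository layout (`depthcharge/tokenizers/peptides.py`).

```python
# Round-trip sketch of the detokenize() change, based on the updated tests.
# Assumption: PeptideTokenizer is importable from depthcharge.tokenizers.
from depthcharge.tokenizers import PeptideTokenizer

# `reverse` now defaults to False, so C->N tokenization must be requested
# explicitly; the "$" stop token matches the default in the diff.
proforma = PeptideTokenizer.from_proforma(
    ["LESLIEK", "EDITHR"],
    reverse=True,
    stop_token="$",
)

# tokenize() still emits the residues reversed ("$"-terminated here)...
tokens = proforma.tokenize(["LESLIEK"], add_stop=True)

# ...but detokenize() now un-reverses the decoded sequence, so callers get
# the peptide back in its original N->C orientation rather than "KEILSEL".
assert proforma.detokenize(tokens) == ["LESLIEK"]
```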