#32184 save total_vocab_size (#32240)
* save total_vocab_size = vocab_size + user-added tokens, to speed up the len(tokenizer) operation (see the sketch after this list)

* update the cached length when added_tokens_decoder is set

* add a test for len(tokenizer)
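
The first bullet describes a simple caching pattern: instead of rebuilding the full vocabulary dict every time `len(tokenizer)` is called, the size is computed once, stored in `total_vocab_size`, and refreshed only when tokens are added. Below is a minimal standalone sketch of that idea (a toy class for illustration, not the transformers implementation; `ToyTokenizer` and its attributes are hypothetical names):

class ToyTokenizer:
    def __init__(self, vocab):
        self._vocab = dict(vocab)        # base vocabulary: token -> id
        self._added_tokens = {}          # user-added tokens: token -> id
        self._update_total_vocab_size()  # cache the size up front

    def get_vocab(self):
        # Rebuilding the merged dict is the slow path the cache avoids in __len__.
        return {**self._vocab, **self._added_tokens}

    def _update_total_vocab_size(self):
        # Count keys (unique tokens), not max id + 1, so holes in the id space
        # do not inflate the count.
        self.total_vocab_size = len(self.get_vocab())

    def add_tokens(self, tokens):
        added = 0
        for token in tokens:
            if token not in self.get_vocab():
                self._added_tokens[token] = len(self.get_vocab())  # next free id (toy scheme)
                added += 1
        self._update_total_vocab_size()  # refresh the cache once per add_tokens call
        return added

    def __len__(self):
        return self.total_vocab_size  # O(1) attribute read instead of rebuilding the vocab


tok = ToyTokenizer({"[PAD]": 0, "hello": 1, "world": 2})
assert len(tok) == 3
tok.add_tokens(["<test_token>"])
assert len(tok) == 4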
itazap authored and nbroad1881 committed Aug 7, 2024
1 parent 3d581de commit acc920f
Showing 2 changed files with 24 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/transformers/tokenization_utils.py
@@ -480,6 +480,7 @@ def added_tokens_decoder(self, value: Dict[int, Union[AddedToken, str]]) -> Dict
 
             self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
             self._added_tokens_encoder[str(token)] = index
+        self._update_total_vocab_size()
 
     def get_added_vocab(self) -> Dict[str, int]:
         """
@@ -494,10 +495,17 @@ def get_added_vocab(self) -> Dict[str, int]:
 
     def __len__(self):
         """
-        Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise if
-        there is a hole in the vocab, we will add tokenizers at a wrong index.
+        Size of the full vocabulary with the added tokens.
         """
-        return len(set(self.get_vocab().keys()))
+        return self.total_vocab_size
+
+    def _update_total_vocab_size(self):
+        """
+        Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because
+        otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and
+        is only updated when adding tokens.
+        """
+        self.total_vocab_size = len(self.get_vocab())
 
     def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
         """
@@ -574,6 +582,7 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
                 logger.info(f"Adding {token} to the vocabulary")
 
         self._update_trie()
+        self._update_total_vocab_size()
         return added_tokens
 
     def _update_trie(self, unique_no_split_tokens: Optional[str] = []):
12 changes: 12 additions & 0 deletions tests/tokenization/test_tokenization_utils.py
@@ -284,3 +284,15 @@ def test_instantiation_from_tokenizers_json_file(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             bert_tokenizer.save(os.path.join(tmpdirname, "tokenizer.json"))
             PreTrainedTokenizerFast(tokenizer_file=os.path.join(tmpdirname, "tokenizer.json"))
+
+    def test_len_tokenizer(self):
+        for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
+            with self.subTest(f"{tokenizer_class}"):
+                tokenizer = tokenizer_class.from_pretrained("bert-base-uncased")
+                added_tokens_size = len(tokenizer.added_tokens_decoder)
+                self.assertEqual(len(tokenizer), tokenizer.vocab_size)
+
+                tokenizer.add_tokens(["<test_token>"])
+                self.assertEqual(len(tokenizer), tokenizer.vocab_size + 1)
+                self.assertEqual(len(tokenizer.added_tokens_decoder), added_tokens_size + 1)
+                self.assertEqual(len(tokenizer.added_tokens_encoder), added_tokens_size + 1)
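
For completeness, here is a rough way to see the behaviour (and the motivation) from the caller's side, assuming transformers is installed and the bert-base-uncased checkpoint can be downloaded; the exact timings are machine-dependent:

import timeit

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
print(len(tokenizer), tokenizer.vocab_size)  # 30522 30522 for this checkpoint

tokenizer.add_tokens(["<test_token>"])
print(len(tokenizer))  # 30523: add_tokens refreshed the cached total_vocab_size

# len() now reads the cached attribute instead of rebuilding the vocab dict,
# so calling it in a tight loop is cheap; get_vocab() remains the slow path.
print(timeit.timeit(lambda: len(tokenizer), number=10_000))
print(timeit.timeit(lambda: tokenizer.get_vocab(), number=100))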
