Skip to content

Commit

Permalink
[LlamaTokenizerFast] nit update post_processor on the fly (huggingf…
Browse files Browse the repository at this point in the history
…ace#23855)

* Update the processor when changing add_eos and add_bos

* fixup

* update

* add a test

* fix failing tests

* fixup
  • Loading branch information
ArthurZucker authored and sheonhan committed Jun 1, 2023
1 parent e27ae62 commit ca330a1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
44 changes: 44 additions & 0 deletions src/transformers/models/llama/tokenization_llama_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from shutil import copyfile
from typing import Optional, Tuple

from tokenizers import processors

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging
from ...utils.versions import require_version
Expand Down Expand Up @@ -84,6 +86,8 @@ def __init__(
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
add_bos_token=True,
add_eos_token=False,
**kwargs,
):
super().__init__(
Expand All @@ -95,10 +99,50 @@ def __init__(
eos_token=eos_token,
**kwargs,
)
self._add_bos_token = add_bos_token
self._add_eos_token = add_eos_token
self.update_post_processor()

self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True

def update_post_processor(self):
bos = self.bos_token
bos_token_id = self.bos_token_id

eos = self.eos_token
eos_token_id = self.eos_token_id

single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"

special_tokens = []
if self.add_bos_token:
special_tokens.append((bos, bos_token_id))
if self.add_eos_token:
special_tokens.append((eos, eos_token_id))
self._tokenizer.post_processor = processors.TemplateProcessing(
single=single, pair=pair, special_tokens=special_tokens
)

@property
def add_eos_token(self):
return self._add_eos_token

@property
def add_bos_token(self):
return self._add_bos_token

@add_eos_token.setter
def add_eos_token(self, value):
self._add_eos_token = value
self.update_post_processor()

@add_bos_token.setter
def add_bos_token(self, value):
self._add_bos_token = value
self.update_post_processor()

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
Expand Down
33 changes: 33 additions & 0 deletions tests/models/llama/test_tokenization_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,39 @@ def integration_tests(self):
},
)

def test_fast_special_tokens(self):
slow_tokenizer = self.tokenizer
fast_tokenizer = self.rust_tokenizer
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
assert slow == [1, 319, 4559, 1243]

fast_tokenizer.add_eos_token = False
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
assert fast == [1, 319, 4559, 1243]

fast_tokenizer.add_eos_token = True
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
assert fast == [1, 319, 4559, 1243, 2]

slow_tokenizer.add_eos_token = True
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
assert slow == [1, 319, 4559, 1243, 2]

fast_tokenizer = LlamaTokenizerFast.from_pretrained(
"hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
)
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
assert fast == [319, 4559, 1243, 2]

slow_tokenzier = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
)
slow = slow_tokenzier.encode("A sample test", add_special_tokens=True)
assert slow == [319, 4559, 1243, 2]

self.tokenizer.add_eos_token = False
self.rust_tokenizer.add_eos_token = False

@slow
def test_conversion(self):
# This is excruciatingly slow since it has to recreate the entire merge
Expand Down

0 comments on commit ca330a1

Please sign in to comment.