🚨 🚨 🚨 Fix Issue 15003: SentencePiece Tokenizers Not Adding Special Tokens in convert_tokens_to_string #15775

Merged Nov 2, 2022 (29 commits)

Changes from 7 commits

Commits
c0c15bb
Add test for SentencePiece not adding special tokens to strings
beneyal Feb 22, 2022
08d47a7
Add SentencePieceStringConversionMixin to fix issue 15003
beneyal Feb 22, 2022
6c82c09
Fix conversion from tokens to string for most SentencePiece tokenizers
beneyal Feb 22, 2022
7bfe422
Fix MarianTokenizer, adjust SentencePiece test to accommodate vocab
beneyal Feb 22, 2022
cb4c824
Fix DebertaV2Tokenizer
beneyal Feb 22, 2022
e6795b9
Ignore LayoutXLMTokenizer in SentencePiece string conversion test
beneyal Feb 22, 2022
17e0921
Run 'make style' and 'make quality'
beneyal Feb 22, 2022
c29bcdd
Clean convert_tokens_to_string test
beneyal Feb 24, 2022
e11cf2a
Remove commented out code
beneyal Feb 24, 2022
fb1c273
Improve robustness of convert_tokens_to_string test
beneyal Feb 24, 2022
91413e5
Inline and remove SentencePieceStringConversionMixin
beneyal Feb 24, 2022
0743ae0
Run 'make style' and 'make quality'
beneyal Feb 24, 2022
bad0f43
Revert removal of space in convert_tokens_to_string
beneyal Feb 25, 2022
8cb264e
Remove redundant import
beneyal Feb 25, 2022
c14ebcb
Revert test text to original
beneyal Feb 25, 2022
f021c2d
Uncomment the lowercasing of the reverse_text variable
beneyal Feb 25, 2022
cee809c
Mimic Rust tokenizer behavior for tokenizers
beneyal Mar 4, 2022
adc06ff
Fix accidentally skipping test in wrong tokenizer
beneyal Mar 4, 2022
6b3cd77
Add test for equivalent Rust and slow tokenizer behavior
beneyal Mar 4, 2022
e94d85e
Override _decode in BigBirdTokenizer to mimic Rust behavior
beneyal Mar 4, 2022
3273b1a
Override _decode in FNetTokenizer to mimic Rust behavior
beneyal Mar 4, 2022
a89dba6
Override _decode in XLNetTokenizer to mimic Rust behavior
beneyal Mar 4, 2022
cb53d92
Merge 'main' into the 15003 fix branch
beneyal Apr 28, 2022
205e133
Remove unused 're' import
beneyal Apr 28, 2022
c271a63
Update DebertaV2Tokenizer to mimic Rust tokenizer
beneyal Apr 29, 2022
0a8c149
Merge branch 'main' of https://github.com/huggingface/transformers in…
beneyal Nov 2, 2022
073b749
Deberta tokenizer now behaves like Albert and its `convert_tokens_to_…
beneyal Nov 2, 2022
048bc65
Ignore problematic tests in Deberta V2
beneyal Nov 2, 2022
21c3f6d
Add comment on why the Deberta V2 tests are skipped
beneyal Nov 2, 2022
7 changes: 2 additions & 5 deletions src/transformers/models/albert/tokenization_albert.py
@@ -22,7 +22,7 @@

import sentencepiece as spm

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils import AddedToken, PreTrainedTokenizer, SentencePieceStringConversionMixin
from ...utils import logging


@@ -56,7 +56,7 @@
SPIECE_UNDERLINE = "▁"


class AlbertTokenizer(PreTrainedTokenizer):
class AlbertTokenizer(SentencePieceStringConversionMixin, PreTrainedTokenizer):
"""
Construct an ALBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

@@ -249,9 +249,6 @@ def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.sp_model.IdToPiece(index)

def convert_tokens_to_string(self, tokens):
return self.sp_model.decode(tokens)

def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
8 changes: 2 additions & 6 deletions src/transformers/models/barthez/tokenization_barthez.py
@@ -21,7 +21,7 @@

import sentencepiece as spm

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils import AddedToken, PreTrainedTokenizer, SentencePieceStringConversionMixin
from ...utils import logging


@@ -46,7 +46,7 @@
SPIECE_UNDERLINE = "▁"


class BarthezTokenizer(PreTrainedTokenizer):
class BarthezTokenizer(SentencePieceStringConversionMixin, PreTrainedTokenizer):
"""
Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a BARThez tokenizer. Based on
[SentencePiece](https://github.com/google/sentencepiece).
@@ -276,10 +276,6 @@ def __setstate__(self, d):
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)

def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
return self.sp_model.decode(tokens)

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
8 changes: 2 additions & 6 deletions src/transformers/models/camembert/tokenization_camembert.py
@@ -21,7 +21,7 @@

import sentencepiece as spm

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils import AddedToken, PreTrainedTokenizer, SentencePieceStringConversionMixin
from ...utils import logging


@@ -42,7 +42,7 @@
SPIECE_UNDERLINE = "▁"


class CamembertTokenizer(PreTrainedTokenizer):
class CamembertTokenizer(SentencePieceStringConversionMixin, PreTrainedTokenizer):
"""
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Construct a CamemBERT tokenizer. Based on
[SentencePiece](https://github.com/google/sentencepiece).
@@ -276,10 +276,6 @@ def __setstate__(self, d):
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)

def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
return self.sp_model.decode(tokens)

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
21 changes: 18 additions & 3 deletions src/transformers/models/deberta_v2/tokenization_deberta_v2.py
@@ -142,7 +142,9 @@ def __init__(
)
self.do_lower_case = do_lower_case
self.split_by_punct = split_by_punct
self._tokenizer = SPMTokenizer(vocab_file, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs)
self._tokenizer = SPMTokenizer(
vocab_file, self.all_special_tokens, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs
)

@property
def vocab_size(self):
@@ -287,7 +289,9 @@ class SPMTokenizer:
BPE-dropout.
"""

def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None):
def __init__(
self, vocab_file, special_tokens, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None
):
self.split_by_punct = split_by_punct
self.vocab_file = vocab_file
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
@@ -307,6 +311,7 @@ def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[D
# self.vocab['[UNK]'] = 3

self.spm = spm
self.special_tokens = special_tokens

def __getstate__(self):
state = self.__dict__.copy()
@@ -343,7 +348,17 @@ def convert_ids_to_tokens(self, ids):

def decode(self, tokens, start=-1, end=-1, raw_text=None):
if raw_text is None:
return self.spm.decode_pieces([t for t in tokens])
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.special_tokens:
out_string += self.spm.decode_pieces(current_sub_tokens) + token + " "
current_sub_tokens = []
else:
current_sub_tokens.append(token)
out_string += self.spm.decode_pieces(current_sub_tokens)
return out_string.strip()
else:
words = self.split_to_words(raw_text)
word_tokens = [self.tokenize(w) for w in words]
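For orientation, here is a minimal usage sketch of the updated SPMTokenizer (not part of the diff): the model path is a placeholder, and the special-token list simply mirrors DeBERTa V2's default special tokens.

# Sketch only: "spiece.model" is a placeholder for a real SentencePiece model file,
# and the special-token list is an assumption mirroring DeBERTa V2's defaults.
from transformers.models.deberta_v2.tokenization_deberta_v2 import SPMTokenizer

special_tokens = ["[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"]
spm_tok = SPMTokenizer("spiece.model", special_tokens, split_by_punct=False)

# With raw_text=None, decode() now appends special tokens verbatim and only feeds the
# surrounding pieces to spm.decode_pieces(), so "[CLS]"/"[SEP]" survive the round trip.
print(spm_tok.decode(["[CLS]", "▁hello", "▁world", "[SEP]"]))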
7 changes: 2 additions & 5 deletions src/transformers/models/fnet/tokenization_fnet.py
@@ -21,7 +21,7 @@

import sentencepiece as spm

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils import AddedToken, PreTrainedTokenizer, SentencePieceStringConversionMixin
from ...utils import logging


@@ -43,7 +43,7 @@
SPIECE_UNDERLINE = "▁"


class FNetTokenizer(PreTrainedTokenizer):
class FNetTokenizer(SentencePieceStringConversionMixin, PreTrainedTokenizer):
"""
Construct an FNet tokenizer. Adapted from [`AlbertTokenizer`]. Based on
[SentencePiece](https://github.com/google/sentencepiece). This tokenizer inherits from [`PreTrainedTokenizer`]
@@ -212,9 +212,6 @@ def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.sp_model.IdToPiece(index)

def convert_tokens_to_string(self, tokens):
return self.sp_model.decode(tokens)

def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
8 changes: 2 additions & 6 deletions src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -21,7 +21,7 @@

import sentencepiece

from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer, SentencePieceStringConversionMixin
from ...utils import logging


@@ -62,7 +62,7 @@
# fmt: on


class M2M100Tokenizer(PreTrainedTokenizer):
class M2M100Tokenizer(SentencePieceStringConversionMixin, PreTrainedTokenizer):
"""
Construct an M2M100 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

@@ -221,10 +221,6 @@ def _convert_id_to_token(self, index: int) -> str:
return self.id_to_lang_token[index]
return self.decoder.get(index, self.unk_token)

def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
return self.sp_model.decode(tokens)

def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
16 changes: 12 additions & 4 deletions src/transformers/models/marian/tokenization_marian.py
@@ -253,10 +253,18 @@ def decode(self, token_ids, **kwargs):

def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
if self._decode_use_source_tokenizer:
return self.spm_source.DecodePieces(tokens)
else:
return self.spm_target.DecodePieces(tokens)
sp_model = self.spm_source if self._decode_use_source_tokenizer else self.spm_target
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
out_string += sp_model.decode_pieces(current_sub_tokens) + token + " "
current_sub_tokens = []
else:
current_sub_tokens.append(token)
out_string += sp_model.decode_pieces(current_sub_tokens)
return out_string.strip()

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
10 changes: 5 additions & 5 deletions src/transformers/models/mbart50/tokenization_mbart50.py
@@ -20,7 +20,7 @@

import sentencepiece as spm

from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer, SentencePieceStringConversionMixin
from ...utils import logging


@@ -45,7 +45,7 @@
# fmt: on


class MBart50Tokenizer(PreTrainedTokenizer):
class MBart50Tokenizer(SentencePieceStringConversionMixin, PreTrainedTokenizer):
"""
Construct a MBart50 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

@@ -233,9 +233,9 @@ def _convert_id_to_token(self, index: int) -> str:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)

def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
return self.sp_model.decode(tokens)
# def convert_tokens_to_string(self, tokens: List[str]) -> str:
# """Converts a sequence of tokens (strings for sub-words) in a single string."""
# return self.sp_model.decode(tokens)

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
12 changes: 6 additions & 6 deletions src/transformers/models/pegasus/tokenization_pegasus.py
@@ -18,7 +18,7 @@

import sentencepiece as spm

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils import PreTrainedTokenizer, SentencePieceStringConversionMixin
from ...utils import logging


@@ -38,7 +38,7 @@
logger = logging.get_logger(__name__)


class PegasusTokenizer(PreTrainedTokenizer):
class PegasusTokenizer(SentencePieceStringConversionMixin, PreTrainedTokenizer):
r"""
Construct a PEGASUS tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

@@ -226,10 +226,10 @@ def _convert_id_to_token(self, index: int) -> str:
token = self.sp_model.IdToPiece(index - self.offset)
return token

def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
out_string = self.sp_model.decode_pieces(tokens)
return out_string
# def convert_tokens_to_string(self, tokens):
# """Converts a sequence of tokens (string) in a single string."""
# out_string = self.sp_model.decode_pieces(tokens)
# return out_string

def num_special_tokens_to_add(self, pair=False):
"""Just EOS"""
@@ -186,11 +186,19 @@ def _convert_id_to_token(self, index: int) -> str:

def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = self.sp_model.decode(tokens)

if self.do_upper_case:
out_string = out_string.upper()
return out_string
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
decoded = self.sp_model.decode(current_sub_tokens)
out_string += (decoded.upper() if self.do_upper_case else decoded) + token + " "
current_sub_tokens = []
else:
current_sub_tokens.append(token)
decoded = self.sp_model.decode(current_sub_tokens)
out_string += decoded.upper() if self.do_upper_case else decoded
return out_string.strip()

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
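To make the control flow above easier to follow, here is a self-contained sketch with a stand-in sp_decode in place of the real SentencePiece model and made-up tokens; it shows that upper-casing is applied only to the SentencePiece-decoded segments, never to the special tokens themselves.

# Stand-in for self.sp_model.decode(); a real SentencePiece model joins the pieces and
# turns the "▁" (U+2581) word-boundary marker back into spaces.
def sp_decode(pieces):
    return "".join(pieces).replace("▁", " ").strip()

def convert_tokens_to_string(tokens, all_special_tokens, do_upper_case):
    current_sub_tokens, out_string = [], ""
    for token in tokens:
        # Special tokens bypass the SentencePiece model and the upper-casing.
        if token in all_special_tokens:
            decoded = sp_decode(current_sub_tokens)
            out_string += (decoded.upper() if do_upper_case else decoded) + token + " "
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
    decoded = sp_decode(current_sub_tokens)
    out_string += decoded.upper() if do_upper_case else decoded
    return out_string.strip()

print(convert_tokens_to_string(["▁hello", "▁world", "</s>"], {"</s>", "<pad>"}, True))
# -> "HELLO WORLD</s>"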
16 changes: 16 additions & 0 deletions src/transformers/tokenization_utils.py
@@ -956,3 +956,19 @@ def _decode(
return clean_text
else:
return text


class SentencePieceStringConversionMixin:
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
out_string += self.sp_model.decode(current_sub_tokens) + token + " "
Contributor: This is related to my previous comment, but I'm not sure what the behavior should be here to add spaces. How did you choose this strategy?

Contributor Author: This is the exact code from PR #8435; I assumed this implementation would be relevant for all SPM tokenizers, but it may well be that this assumption is incorrect.

beneyal (Contributor Author), Feb 24, 2022: After running the test you suggested, I do believe that this extra space is an error, at least compared to the Rust tokenizers. Removing the space passes the test 👍

EDIT:
I forgot to run the whole test suite for the T5Tokenizer and now I see there's a problem. If I keep the space, test_fast_and_slow_same_result passes but test_sentencepiece_tokenize_and_convert_tokens_to_string fails. If I remove the space, then it's vice-versa, and I'm not sure what the correct behavior should be.

Contributor: I won't be able to look at this T5Tokenizer issue again until next Monday, but I promise to come back to you early next week.

Contributor Author: Thank you very much! I will be waiting patiently 😄

Contributor Author: All of them 😅 Once the testing finishes (currently 50%), I'll know exactly why. For context, I accepted the changes in main over what was in the file of my branch. Didn't do anything else with that file.

[screenshot of the test run]

Contributor Author: So, after feeling dumb for not noticing the pytest -n auto argument, which is why the tests took almost 3 hours to run, I found the problem: when accepting the change, it removed a constructor argument I passed. I fixed that, and now only one test is failing:

[screenshot of the failing test]

This is weird, since using the "microsoft/deberta-v2-xlarge" tokenizer works on both slow and fast. I believe the problem is the test vocab. Is there something I can do to fix it, other than changing the test sentence (which seems quite drastic)?

SaulLu (Contributor), Apr 28, 2022: I see! Can you try to change this line (in test_tokenization_deberta_v2.py):

tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB)

to tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB, unk_token="<unk>").

I think it's because the SAMPLE_VOCAB defines the unknown token as "<unk>" and not as "[UNK]" which is the default value for DebertaV2Tokenizer.

EDIT: I've inspected the content inside and it's probably even better to replace it by

tokenizer = DebertaV2Tokenizer(
    SAMPLE_VOCAB,
    bos_token="[CLS]",
    eos_token="[SEP]",
    unk_token="<unk>",
    sep_token="[SEP]",
    pad_token="<pad>",
    cls_token="[CLS]",
    mask_token="[MASK]",
)

as that's how the special tokens are defined inside SAMPLE_VOCAB.

beneyal (Contributor Author), Apr 28, 2022: Interesting. This change breaks test_get_vocab, which makes sense, but test_sentencepiece_tokenize_and_decode at least now shows the "real" problem, which is what I fixed for the other slow tokenizers:

 AssertionError: '[CLS]<unk>his is text to test the tokenizer.[SEP]' != '[CLS] <unk>his is text to test the tokenizer.[SEP]'

It's the spacing thing. For some tokenizers it was a simple fix; for others it required overriding _decode, but I can fix this. I'll need to check how the fast tokenizer handles spaces around special tokens so I can mimic that behavior.

EDIT:
The test_get_vocab test is fixed, I just did tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB, unk_token="<unk>").

Contributor Author: Hi @SaulLu,

I made DeBERTa v2 act like the fast tokenizer. It's very ad-hoc, since the behavior of the special tokens is a bit weird, but it passes all the tests 🙂

Would love to hear your feedback!

current_sub_tokens = []
else:
current_sub_tokens.append(token)
out_string += self.sp_model.decode(current_sub_tokens)
return out_string.strip()
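A hedged usage sketch of what the mixin should produce for any slow tokenizer that now inherits it; the checkpoint name and the exact token strings are illustrative assumptions, not taken from the PR.

from transformers import AlbertTokenizer  # AlbertTokenizer inherits the mixin after this change

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")  # assumed checkpoint
tokens = ["[CLS]", "▁this", "▁is", "▁a", "▁test", "[SEP]"]
print(tokenizer.convert_tokens_to_string(tokens))
# Expected shape of the output: "[CLS] this is a test[SEP]"; the special tokens are kept
# verbatim, and the implementation shown here adds a space after each special token.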
17 changes: 14 additions & 3 deletions tests/test_tokenization_common.py
@@ -37,6 +37,7 @@
AutoTokenizer,
BertTokenizer,
BertTokenizerFast,
LayoutXLMTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
@@ -356,7 +357,7 @@ def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
return

tokenizer = self.get_tokenizer()
text = "This is text to test the tokenizer."
text = "This is a test"

if self.test_sentencepiece_ignore_case:
text = text.lower()
@@ -368,11 +369,21 @@ def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
# check if converting back to original text works
reverse_text = tokenizer.convert_tokens_to_string(tokens)

if self.test_sentencepiece_ignore_case:
reverse_text = reverse_text.lower()
# All tokenizers pass this test without the below commented out code.
Collaborator: Suggested change (remove the line `# All tokenizers pass this test without the below commented out code.`). To clean up?

Contributor Author: I'll remove the comment 🙂

# if self.test_sentencepiece_ignore_case:
# reverse_text = reverse_text.lower()

self.assertEqual(reverse_text, text)

if isinstance(tokenizer, LayoutXLMTokenizer):
return

input_ids = tokenizer(text).input_ids
tokens_including_special = tokenizer.convert_ids_to_tokens(input_ids)
reverse_text = tokenizer.convert_tokens_to_string(tokens_including_special)

self.assertEqual(len(tokenizer.tokenize(reverse_text)), len(input_ids))
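The new check above amounts to the following manual round trip (a sketch; "camembert-base" is an assumed stand-in for any SentencePiece-backed slow tokenizer covered by this test):

# Sketch of the round trip the new assertion exercises.
from transformers import CamembertTokenizer

tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
input_ids = tokenizer("This is a test").input_ids          # includes <s> ... </s>
tokens = tokenizer.convert_ids_to_tokens(input_ids)
reverse_text = tokenizer.convert_tokens_to_string(tokens)  # special tokens must survive
# If convert_tokens_to_string dropped the special tokens (the bug in issue 15003),
# re-tokenizing reverse_text would yield fewer tokens than len(input_ids).
assert len(tokenizer.tokenize(reverse_text)) == len(input_ids)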

Contributor: What I don't necessarily know how to handle in this convert_tokens_to_string function is how to re-add spaces around each special token. I wonder if we should be aligned with what the fast version of the tokenizer would output. If it's something that seems relevant to you, I would add this test:

Suggested change:

if self.test_rust_tokenizer:
    rust_tokenizer = self.get_rust_tokenizer()
    special_tokens_string_rust = rust_tokenizer.convert_tokens_to_string(special_tokens)
    self.assertEqual(special_tokens_string, special_tokens_string_rust)

Contributor Author: Initially, I didn't go for the "compare to Rust tokenizer" approach since not all the tokenizers in question have this luxury. Thinking about it now that you mention it, this code will definitely add another layer of reliability to the existing test.

Contributor Author: Okay, after adding this test, a few previously passing tests failed. This is because of the space you mentioned in another comment. I removed the space, so now the code looks like this:

def convert_tokens_to_string(self, tokens):
    ...
    out_string += self.sp_model.decode(current_sub_tokens) + token  # << a space used to be here
    ...

And now all the tests pass, including T5Tokenizer.
I will note that I changed calls to self.sp_model.decode_pieces to self.sp_model.decode, as decode, if needed, delegates to decode_pieces, and this makes the code uniform, in my opinion.

def test_subword_regularization_tokenizer(self) -> None:
if not self.test_sentencepiece:
return