🚨 🚨 🚨 Fix Issue 15003: SentencePiece Tokenizers Not Adding Special Tokens in `convert_tokens_to_string` (huggingface#15775)

* Add test for SentencePiece not adding special tokens to strings

* Add SentencePieceStringConversionMixin to fix issue 15003

* Fix conversion from tokens to string for most SentencePiece tokenizers

Tokenizers fixed:
- AlbertTokenizer
- BarthezTokenizer
- CamembertTokenizer
- FNetTokenizer
- M2M100Tokenizer
- MBart50Tokenizer
- PegasusTokenizer
- Speech2TextTokenizer

* Fix MarianTokenizer, adjust SentencePiece test to accommodate vocab

* Fix DebertaV2Tokenizer

* Ignore LayoutXLMTokenizer in SentencePiece string conversion test

* Run 'make style' and 'make quality'

* Clean convert_tokens_to_string test

Instead of explicitly ignoring LayoutXLMTokenizer in the test,
override the test in LayoutXLMTokenizationTest and make it a no-op.

* Remove commented out code

* Improve robustness of convert_tokens_to_string test

Instead of comparing the lengths of the re-tokenized text and the
input_ids, check that converting all special tokens to a string
yields a string containing all of the special tokens.

* Inline and remove SentencePieceStringConversionMixin

The convert_tokens_to_string method is now implemented
in each relevant SentencePiece tokenizer.

* Run 'make style' and 'make quality'

* Revert removal of space in convert_tokens_to_string

* Remove redundant import

* Revert test text to original

* Uncomment the lowercasing of the reverse_text variable

* Mimic Rust tokenizer behavior for tokenizers

- Albert
- Barthez
- Camembert
- MBart50
- T5

* Fix accidentally skipping test in wrong tokenizer

* Add test for equivalent Rust and slow tokenizer behavior

* Override _decode in BigBirdTokenizer to mimic Rust behavior

* Override _decode in FNetTokenizer to mimic Rust behavior

* Override _decode in XLNetTokenizer to mimic Rust behavior

* Remove unused 're' import

* Update DebertaV2Tokenizer to mimic Rust tokenizer

Deberta tokenizer now behaves like Albert and its `convert_tokens_to_string` is not tested.

* Ignore problematic tests in Deberta V2

* Add comment on why the Deberta V2 tests are skipped
beneyal authored and Magnus Pierrau committed Dec 15, 2022
1 parent ef662f3 commit 19c0ae2
Showing 18 changed files with 379 additions and 40 deletions.
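
The bug being fixed (issue 15003): the slow, SentencePiece-backed tokenizers fed special tokens straight into `sp_model.decode`, whose vocabulary does not contain them, so the special tokens were dropped or mangled in the round-tripped string. A minimal sketch of the before/after behavior, assuming the `albert-base-v2` checkpoint (exact outputs are illustrative):

```python
from transformers import AlbertTokenizer

tok = AlbertTokenizer.from_pretrained("albert-base-v2")
ids = tok.encode("Hello world")          # adds [CLS] ... [SEP]
tokens = tok.convert_ids_to_tokens(ids)  # ['[CLS]', '▁hello', '▁world', '[SEP]']

# Before this commit, special tokens went through sp_model.decode and were
# lost (or replaced by SentencePiece's unknown symbol):
#   tok.convert_tokens_to_string(tokens) -> 'hello world'
# After this commit, special tokens are spliced back in verbatim:
#   tok.convert_tokens_to_string(tokens) -> '[CLS] hello world[SEP]'
```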
18 changes: 17 additions & 1 deletion src/transformers/models/albert/tokenization_albert.py
@@ -250,7 +250,23 @@ def _convert_id_to_token(self, index):
         return self.sp_model.IdToPiece(index)

     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
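This loop is the pattern repeated, with small variations, in the files below. A self-contained sketch of its behavior, where `sp_decode` is a stand-in for `self.sp_model.decode` (real SentencePiece detokenization is model-driven; here the '▁' word-boundary marker simply becomes a space):

```python
def sp_decode(pieces):
    # Stand-in for SentencePiece decoding: join the pieces and turn
    # word-boundary markers into spaces.
    return "".join(pieces).replace("\u2581", " ").strip()

def convert_tokens_to_string(tokens, all_special_tokens):
    current_sub_tokens = []
    out_string = ""
    prev_is_special = False
    for token in tokens:
        if token in all_special_tokens:
            # Flush the accumulated ordinary pieces through the SentencePiece
            # decoder, then append the special token verbatim; a run of
            # special tokens gets a single leading space.
            if not prev_is_special:
                out_string += " "
            out_string += sp_decode(current_sub_tokens) + token
            prev_is_special = True
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
            prev_is_special = False
    out_string += sp_decode(current_sub_tokens)
    return out_string.strip()

print(convert_tokens_to_string(
    ["[CLS]", "\u2581hello", "\u2581world", "[SEP]"],
    {"[CLS]", "[SEP]", "[MASK]"},
))  # -> '[CLS] hello world[SEP]'
```

The `prev_is_special` flag ensures a run of consecutive special tokens is preceded by exactly one space, matching the Rust tokenizers' output.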
23 changes: 19 additions & 4 deletions src/transformers/models/barthez/tokenization_barthez.py
@@ -263,6 +263,25 @@ def _convert_id_to_token(self, index):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index)

+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

@@ -278,10 +297,6 @@ def __setstate__(self, d):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
@@ -151,8 +151,17 @@ def _convert_id_to_token(self, index):

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (strings) into a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
62 changes: 60 additions & 2 deletions src/transformers/models/big_bird/tokenization_big_bird.py
@@ -16,6 +16,7 @@


 import os
+import re
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple

@@ -182,8 +183,65 @@ def _convert_id_to_token(self, index):

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (strings) into a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPE
+        # we need to build the string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_tokens:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space before [MASK] and [SEP]
+        if spaces_between_special_tokens:
+            text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
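In isolation, the Rust-mimicking regex in the `_decode` override above deletes the space that `" ".join` placed before `[MASK]` and `[SEP]`:

```python
import re

sub_texts = ["hello world", "[MASK]", "and", "[SEP]"]
text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts))
print(text)  # -> 'hello world[MASK] and[SEP]'
```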
23 changes: 19 additions & 4 deletions src/transformers/models/camembert/tokenization_camembert.py
@@ -261,6 +261,25 @@ def _convert_id_to_token(self, index):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)

+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

@@ -276,10 +295,6 @@ def __setstate__(self, d):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
26 changes: 23 additions & 3 deletions src/transformers/models/deberta_v2/tokenization_deberta_v2.py
@@ -146,7 +146,9 @@ def __init__(
         self.do_lower_case = do_lower_case
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
-        self._tokenizer = SPMTokenizer(vocab_file, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs)
+        self._tokenizer = SPMTokenizer(
+            vocab_file, self.all_special_tokens, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs
+        )

     @property
     def vocab_size(self):

@@ -291,7 +293,9 @@ class SPMTokenizer:
     BPE-dropout.
     """

-    def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self, vocab_file, special_tokens, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None
+    ):
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

@@ -312,6 +316,7 @@ def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None):
         # self.vocab['[UNK]'] = 3

         self.spm = spm
+        self.special_tokens = special_tokens

     def __getstate__(self):
         state = self.__dict__.copy()

@@ -339,7 +344,22 @@ def convert_ids_to_tokens(self, ids):

     def decode(self, tokens, start=-1, end=-1, raw_text=None):
         if raw_text is None:
-            return self.spm.decode_pieces([t for t in tokens])
+            current_sub_tokens = []
+            out_string = ""
+            prev_is_special = False
+            for token in tokens:
+                # make sure that special tokens are not decoded using sentencepiece model
+                if token in self.special_tokens:
+                    if not prev_is_special:
+                        out_string += " "
+                    out_string += self.spm.decode_pieces(current_sub_tokens) + token
+                    prev_is_special = True
+                    current_sub_tokens = []
+                else:
+                    current_sub_tokens.append(token)
+                    prev_is_special = False
+            out_string += self.spm.decode_pieces(current_sub_tokens)
+            return out_string.strip()
         else:
             words = self.split_to_words(raw_text)
             word_tokens = [self.tokenize(w) for w in words]
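Any caller constructing `SPMTokenizer` directly must now pass the special tokens explicitly, since the helper class cannot see the wrapping tokenizer's attributes. A hypothetical instantiation (the model path and token list are illustrative, not taken from this diff):

```python
from transformers.models.deberta_v2.tokenization_deberta_v2 import SPMTokenizer

spm_tok = SPMTokenizer(
    "path/to/spm.model",  # hypothetical local SentencePiece model file
    special_tokens=["[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"],
    split_by_punct=False,
)
```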
62 changes: 61 additions & 1 deletion src/transformers/models/fnet/tokenization_fnet.py
@@ -15,6 +15,7 @@
""" Tokenization classes for FNet model."""

import os
import re
import unicodedata
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
@@ -213,7 +214,66 @@ def _convert_id_to_token(self, index):
         return self.sp_model.IdToPiece(index)

     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPE
+        # we need to build the string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_tokens:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space after <unk>
+        if spaces_between_special_tokens:
+            text = re.sub(r"(<unk>) ", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
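FNet's regex differs from BigBird's: it strips the space after `<unk>` rather than the space before `[MASK]`/`[SEP]`:

```python
import re

sub_texts = ["hello", "<unk>", "world"]
print(re.sub(r"(<unk>) ", r"\1", " ".join(sub_texts)))  # -> 'hello <unk>world'
```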
16 changes: 13 additions & 3 deletions src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -218,9 +218,19 @@ def _convert_id_to_token(self, index: int) -> str:
             return self.id_to_lang_token[index]
         return self.decoder.get(index, self.unk_token)

-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
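Unlike the ALBERT-style implementation, M2M100 omits the `prev_is_special` bookkeeping, so a special token is concatenated directly against the decoded text with no surrounding space. A stand-in trace (the tokens and `sp_decode` are illustrative):

```python
def sp_decode(pieces):
    # Stand-in for SentencePiece decoding, as in the earlier sketch.
    return "".join(pieces).replace("\u2581", " ").strip()

def m2m_style(tokens, specials):
    current_sub_tokens, out_string = [], ""
    for token in tokens:
        if token in specials:
            out_string += sp_decode(current_sub_tokens) + token
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
    return (out_string + sp_decode(current_sub_tokens)).strip()

print(m2m_style(["__en__", "\u2581hello", "</s>"], {"__en__", "</s>"}))
# -> '__en__hello</s>'
```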
16 changes: 12 additions & 4 deletions src/transformers/models/marian/tokenization_marian.py
@@ -265,10 +265,18 @@ def decode(self, token_ids, **kwargs):

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
-        if self._decode_use_source_tokenizer:
-            return self.spm_source.DecodePieces(tokens)
-        else:
-            return self.spm_target.DecodePieces(tokens)
+        sp_model = self.spm_source if self._decode_use_source_tokenizer else self.spm_target
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += sp_model.decode_pieces(current_sub_tokens) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += sp_model.decode_pieces(current_sub_tokens)
+        return out_string.strip()

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
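Marian's variant instead appends a space after each special token, and picks the source or target SentencePiece model via `_decode_use_source_tokenizer`. The spacing effect, again with a stand-in decoder:

```python
def sp_decode(pieces):
    # Stand-in for SentencePiece decoding, as in the earlier sketches.
    return "".join(pieces).replace("\u2581", " ").strip()

def marian_style(tokens, specials):
    out_string, current = "", []
    for token in tokens:
        if token in specials:
            out_string += sp_decode(current) + token + " "
            current = []
        else:
            current.append(token)
    return (out_string + sp_decode(current)).strip()

print(marian_style(["<pad>", "\u2581Hallo", "</s>"], {"<pad>", "</s>"}))
# -> '<pad> Hallo</s>'
```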
21 changes: 18 additions & 3 deletions src/transformers/models/mbart50/tokenization_mbart50.py
@@ -232,9 +232,24 @@ def _convert_id_to_token(self, index: int) -> str:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)

-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):