Mistral Tokenizer.decode() adds a space when use_fast=True #29452
Hey! Thanks for reporting.

```python
In [4]: slow.convert_ids_to_tokens([1, 12014])
Out[4]: ['<s>', '▁hi']
```

This issue emerges because

```python
In [14]: slow.convert_tokens_to_string(['<s>', '▁hi'])
Out[14]: '<s>hi'
```

```diff
 def convert_tokens_to_string(self, tokens):
     """Converts a sequence of tokens (string) in a single string."""
     # since we manually add the prefix space, we have to remove it when decoding
     if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
         tokens[0] = tokens[0][1:]
     current_sub_tokens = []
     out_string = ""
     prev_is_special = False
     for i, token in enumerate(tokens):
         # make sure that special tokens are not decoded using sentencepiece model
         if token in self.all_special_tokens:
             if not prev_is_special and i != 0 and self.legacy:
                 out_string += " "
             out_string += self.sp_model.decode(current_sub_tokens) + token
             prev_is_special = True
             current_sub_tokens = []
         else:
+            if prev_is_special and i == 1 and self.add_prefix_space:
+                out_string += " "
             current_sub_tokens.append(token)
             prev_is_special = False
     out_string += self.sp_model.decode(current_sub_tokens)
     return out_string
```

This should fix it:

```python
In [4]: slow.convert_tokens_to_string(['▁hi'])
Out[4]: 'hi'
```
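For anyone reproducing the diagnosis locally, here is a minimal sketch of the same intermediate calls. The ids `[1, 12014]` come from the snippet above; the loading arguments (`add_bos_token=False`) are illustrative, and the outputs depend on the installed transformers version, so they are not asserted here:

```python
from transformers import AutoTokenizer

# Slow (SentencePiece-based) tokenizer; add_bos_token=False is an illustrative choice.
slow = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False, add_bos_token=False)

ids = [1, 12014]                               # ['<s>', '▁hi'] per the output shown above
tokens = slow.convert_ids_to_tokens(ids)
print(tokens)
print(slow.convert_tokens_to_string(tokens))   # where the space around special tokens is (or isn't) added
print(slow.decode(ids))                        # decode() on the slow tokenizer ends up calling convert_tokens_to_string
```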
Thanks for your reply! Here is an example:

```python
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-v0.1"
fast = AutoTokenizer.from_pretrained(model_path, use_fast=True, add_bos_token=False)
slow = AutoTokenizer.from_pretrained(model_path, use_fast=False, add_bos_token=False)
text = "<s>user: hi</s>assistant: hello</s>"
print(f"text={text}")
print(f"fast tokenize={fast.encode(text)}")
print(f"slow tokenize={slow.encode(text)}")
fast_decode_text = fast.decode(fast.encode(text))
slow_decode_text = slow.decode(slow.encode(text))
print(f"fast decode text={fast_decode_text}")
print(f"slow decode text={slow_decode_text}")
```

Output:

Both the fast and the slow decoded text add a space after special tokens. Is there any parameter that can be set so that no space is added after special tokens?
That is expected 😉

```python
slow = AutoTokenizer.from_pretrained(model_path, use_fast=False, add_bos_token=False, legacy=False)
```

Check the doc for this argument 🤗
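To make the effect of the `legacy` flag concrete, here is a minimal side-by-side sketch; the exact outputs depend on the installed transformers version, so they are not reproduced here:

```python
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-v0.1"
legacy_tok = AutoTokenizer.from_pretrained(model_path, use_fast=False, add_bos_token=False, legacy=True)
new_tok = AutoTokenizer.from_pretrained(model_path, use_fast=False, add_bos_token=False, legacy=False)

text = "<s>user: hi</s>assistant: hello</s>"
print("legacy=True :", legacy_tok.decode(legacy_tok.encode(text)))
print("legacy=False:", new_tok.decode(new_tok.encode(text)))
```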
Thank you for your prompt response. But when use_fast=True, the space still appears after special tokens. Here is an example:

```python
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-v0.1"
fast = AutoTokenizer.from_pretrained(model_path, use_fast=True, add_bos_token=False, legacy=False)
slow = AutoTokenizer.from_pretrained(model_path, use_fast=False, add_bos_token=False, legacy=False)
text = "<s>user: hi</s>assistant: hello</s>"
print(f"text={text}")
print(f"fast tokenize={fast.encode(text)}")
print(f"slow tokenize={slow.encode(text)}")
fast_decode_text = fast.decode(fast.encode(text))
slow_decode_text = slow.decode(slow.encode(text))
# a space appears between "</s>" and "assistant"
print(f"fast decode text={fast_decode_text}")
print(f"slow decode text={slow_decode_text}")
```

The output is:

Also, I found that there is a difference between the fast and slow encodings when legacy=False. Which one does mistralai/Mistral-7B-v0.1 use?
Yes, as I said, this is being fixed by #28881, but it is not in main yet.
The output is:

The fast and slow tokenizers give the same encode result, but the fast tokenizer adds a space after "<s>" when decoding.

I also noticed issue huggingface/tokenizers#1448, where @ArthurZucker said to "use metaspace with prepend_scheme="first" and no normalizer"; this already exists in transformers version 4.38.2 and does not seem to work. Is there any useful info I missed? How can I remove the space after "<s>" when use_fast=True?
Thank you to those who have contributed to the Transformers lib.
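For reference, the Metaspace suggestion quoted above would look roughly like the sketch below on the fast tokenizer. This is only an illustration under assumptions, not a confirmed fix: it assumes a tokenizers version whose `Metaspace` pre-tokenizer and decoder accept `prepend_scheme`, and whether the normalizer can simply be set to `None` also depends on the installed version.

```python
from transformers import AutoTokenizer
from tokenizers import pre_tokenizers, decoders

fast = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True, add_bos_token=False)
backend = fast.backend_tokenizer  # the underlying tokenizers.Tokenizer

# "use metaspace with prepend_scheme='first'": only the very first piece gets the "▁" prefix.
backend.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="first")
backend.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="first")
# "no normalizer": versions that reject assigning None here may need another mechanism.
backend.normalizer = None

text = "<s>user: hi</s>assistant: hello</s>"
print(fast.decode(fast.encode(text)))
```

Whether this removes the space after "<s>" on 4.38.2 would still need to be checked; the reply above points to #28881 as the actual fix.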