LlamaTokenizerFast: wrong word_id references based on batch encoding (#29617)
Comments
Alright, this is a valid bug! Pretty sure this comes from the lack of a `Metaspace` pre-tokenizer.
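(Not part of the original comment — a minimal sketch, assuming the `hf-internal-testing/llama-tokenizer` checkpoint used below, to inspect what the fast tokenizer's backend is actually configured with:)

```python
from transformers import AutoTokenizer

# Minimal sketch: peek at the fast tokenizer's backend components.
tok = AutoTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer',
                                    from_slow=True, legacy=False)

# On the affected versions discussed in this thread, the pre-tokenizer is
# reportedly not set (None), so there are no word boundaries from which
# word_ids could be derived.
print(tok._tokenizer.pre_tokenizer)
print(tok._tokenizer.normalizer)
```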
Here is what I am getting with the fix:

```python
In [1]: from transformers import AutoTokenizer
   ...:
   ...: # any llama (1) based tokenizer
   ...: llama_tokenizer = AutoTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer',
   ...:                                                 from_slow=True, legacy=False)
   ...:
   ...: sentence = 'Sampel to demonstrate the issue'
   ...:
   ...: input_encoding = llama_tokenizer(sentence, add_special_tokens=True)
   ...: tokens = input_encoding.tokens()
   ...: word_ids = input_encoding.word_ids()
   ...:
   ...: print(f'Tokens produced: {tokens}\nReferenced word ids: {word_ids}')
Tokens produced: ['<s>', '▁S', 'amp', 'el', '▁to', '▁demonstrate', '▁the', '▁issue']
Referenced word ids: [None, 0, 0, 0, 1, 2, 3, 4]
```
Thank you, this is looking good!
Hi @ArthurZucker, I am still facing the same issue even after trying your code. I tried to use your PR, but it seems to have been merged into main already.
You need to make sure you are using the correct `tokenizers` version.
I am using the latest version. Besides, could you please explain how to disable the pre-tokenizer split?
It is the correct version. Setting `split = False` on the pre-tokenizer should disable it.
```python
from tokenizers import Tokenizer

llama_tokenizer = Tokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
# fails: tokenizers.Tokenizer has no `_tokenizer` attribute
llama_tokenizer._tokenizer.pre_tokenizer.split = False
```

This raises an error:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', from_slow=True, legacy=False)
# fails: the pre-tokenizer is not set on this version, so it is None
tok._tokenizer.pre_tokenizer.split = False
```

This raises an error as well. Could you please help?
Before I start: I use the most recent tokenizers version from the main branch and the transformers version from PR #28881 (with a slight modification locally on my end).

Hey @phusroyal, I've just checked it myself again and you don't need to set `split = False`:

```python
from tokenizers import Tokenizer

# any llama (1) based tokenizer
llama_tokenizer = Tokenizer.from_pretrained('hf-internal-testing/llama-new-metaspace')

# just for demonstration purposes, it's also here
# llama_tokenizer.pre_tokenizer.split = False

sentence = ['Sampel to demonstrate the issue']

input_encoding = llama_tokenizer.encode_batch(sentence, add_special_tokens=True)
tokens = input_encoding[0].tokens
word_ids = input_encoding[0].word_ids

print(f'Tokens produced: {tokens}\nReferenced word ids: {word_ids}')
```

For the transformers variant, you would also not want to modify the pre-tokenizer.
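(For completeness, a minimal sketch of what that transformers variant could look like — assuming the same `hf-internal-testing/llama-new-metaspace` test checkpoint also loads through `AutoTokenizer`, which the thread does not verify:)

```python
from transformers import AutoTokenizer

# Same check through transformers, leaving the pre-tokenizer untouched.
llama_tokenizer = AutoTokenizer.from_pretrained('hf-internal-testing/llama-new-metaspace')

sentence = 'Sampel to demonstrate the issue'
input_encoding = llama_tokenizer(sentence, add_special_tokens=True)

print(f'Tokens produced: {input_encoding.tokens()}\n'
      f'Referenced word ids: {input_encoding.word_ids()}')
```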
Ah, and another small thing: this issue is related to llama 1 tokenizers. Not sure how much it applies to llama 2 tokenizers.
Great, thank you for your help! And yes, I also face the same issue with llama 2 tokenizers.
I see. I don't think llama 2 has a pre-tokenizer in the current implementation of tokenizers, but when testing it with transformers, the same issue shows up. Regarding tokenizers, it seems the pre-tokenizer isn't set in that case, which causes errors. Might be a separate bug/issue.
Llama 1 and 2 don't have a pre_tokenizer, but there are issues with the normalizer, thus the fix.
I'll get the PR ready in a bit.
I ended up here from a separate problem with Gemma and pre-tokenizers not working at all. But this codebase has so much debt, and it is so incredibly messy, guys...
Could you share a reproducer?
@ArthurZucker, it's the same issue as here with llama. Since the model is gated I've used an auth token:

```python
from transformers import AutoTokenizer

# gated checkpoint: pass your Hugging Face auth token
tokenizer = AutoTokenizer.from_pretrained('google/gemma-7b', token=True)

sentence = 'Sampel to demonstrate the issue'

input_encoding = tokenizer(sentence, add_special_tokens=True)
tokens = input_encoding.tokens()
word_ids = input_encoding.word_ids()

print(f'Tokens produced: {tokens}\nReferenced word ids: {word_ids}')
```

Results in the same wrong word ids. Also tried with the versions above (i.e. the PR #28881 branch + the tokenizers main branch).
Yep, it's the same issue here; the normalizer is the culprit. Overriding it and setting a `Metaspace` pre-tokenizer gives the expected tokens:

```python
>>> from tokenizers import pre_tokenizers, normalizers
>>> tokenizer._tokenizer.normalizer = normalizers.Sequence([])
>>> tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(prepend_scheme="never", split=True)
>>> input_encoding = tokenizer(sentence, add_special_tokens=True)
>>> input_encoding.tokens()
['<bos>', 'Samp', 'el', '▁to', '▁demonstrate', '▁the', '▁issue']
>>> input_encoding.word_ids()
```
This was not really possible before without adding extra bugs, but it should work now.
Also, now that …
Thank you 🙏 Is there already a separate issue open for Gemma, or should I (or someone else) open one?
Mmm, feel free to open one, I'll fix it in the same PR!
Hi all, it seems this issue was fixed, so perhaps I'm doing something wrong, but the problem persists whenever I try to retrieve the word_ids for a sequence with the Llama tokenizer. The example below still returns the same all-zero word ids. I deleted and re-downloaded the tokenizer files, but that didn't help.
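(The snippet referenced in this comment didn't survive extraction; below is a hypothetical reconstruction of that kind of repro, assuming a stock Llama-style fast tokenizer:)

```python
from transformers import AutoTokenizer

# Hypothetical reconstruction of the missing example: ask a Llama fast
# tokenizer for the word ids of a short sequence.
tokenizer = AutoTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')

encoding = tokenizer('Sampel to demonstrate the issue', add_special_tokens=True)
print(encoding.tokens())
print(encoding.word_ids())  # the comment reports these still come back all 0
```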
Hey! You are right, it is not fixed, but I recommend you to use the offsets instead:

```python
In [30]: tokenizer("Hi my name is Adam").encodings[0].offsets
Out[30]: [(0, 0), (0, 2), (2, 5), (5, 10), (10, 13), (13, 18)]
```

because the concept of a `word` is not really well defined for these tokenizers. The pre_tokenizer I recommend you to use:

```python
from tokenizers import Regex, pre_tokenizers

tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Split(Regex('(?<!▁)▁'), "merged_with_next")
```

(on top of `Metaspace`).
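(To make the offsets route concrete, here is a minimal sketch — my own helper, not from the thread, and assuming you can load the gated Llama-2 tokenizer; any Llama-style fast tokenizer works — that rebuilds whitespace-level word indices from the character offsets:)

```python
from transformers import AutoTokenizer

# Sketch only: derive word indices from character offsets instead of
# relying on word_ids().
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

sentence = 'Hi my name is Adam'
offsets = tokenizer(sentence).encodings[0].offsets

word_ids, current = [], -1
for start, end in offsets:
    if start == end:  # zero-width spans belong to special tokens like <s>
        word_ids.append(None)
        continue
    # a new word starts at the beginning of the string, or when the span
    # starts on (or right after) a space in the original sentence
    if start == 0 or sentence[start] == ' ' or sentence[start - 1] == ' ':
        current += 1
    word_ids.append(current)

print(word_ids)  # expected: [None, 0, 1, 2, 3, 4]
```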
System Info
`transformers` version: 4.39.0.dev0 (commit hash 923733c)

Who can help?
@ArthurZucker
Information

Tasks

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction
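(The reproduction snippet didn't survive extraction; judging from the output below, it is presumably the same script quoted at the top of the thread:)

```python
from transformers import AutoTokenizer

# any llama (1) based tokenizer
llama_tokenizer = AutoTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer',
                                                from_slow=True, legacy=False)

sentence = 'Sampel to demonstrate the issue'

input_encoding = llama_tokenizer(sentence, add_special_tokens=True)
tokens = input_encoding.tokens()
word_ids = input_encoding.word_ids()

print(f'Tokens produced: {tokens}\nReferenced word ids: {word_ids}')
```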
Prints:

```
Tokens produced: ['<s>', '▁S', 'amp', 'el', '▁to', '▁demonstrate', '▁the', '▁issue']
Referenced word ids: [None, 0, 0, 0, 0, 0, 0, 0]
```
Expected behavior

Given a fast tokenizer, batch encodings should be able to return the tokenized tokens as well as the corresponding word_ids. The first `None` is to be expected though (see this example).