
LlamaTokenizerFast wrong word_id references based on batch encoding #29617

Closed · 2 of 4 tasks
vasqu opened this issue Mar 12, 2024 · 24 comments · Fixed by #28881

Comments

@vasqu
Contributor

vasqu commented Mar 12, 2024

System Info

  • transformers version: 4.39.0.dev0 (commit hash 923733c)
  • Platform: Linux-6.5.0-10022-tuxedo-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer

# any llama (1) based tokenizer
llama_tokenizer = AutoTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')

sentence = 'Sampel to demonstrate the issue'

input_encoding = llama_tokenizer(sentence, add_special_tokens=True)
tokens = input_encoding.tokens()
word_ids = input_encoding.word_ids()

print(f'Tokens produced: {tokens}\nReferenced word ids: {word_ids}')

Prints:
Tokens produced: ['<s>', '▁S', 'amp', 'el', '▁to', '▁demonstrate', '▁the', '▁issue']
Referenced word ids: [None, 0, 0, 0, 0, 0, 0, 0]

Expected behavior

Given a fast tokenizer, batch encodings should return the tokens as well as the corresponding word_ids, one id per source word. The first None is expected, since it corresponds to the <s> special token (see this example).
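For comparison, here is a minimal sketch with gpt2 (my own illustrative choice, not part of the original report; the exact token splits may differ): its byte-level pre-tokenizer splits the input into words first, so word_ids() maps each token back to its source word.

from transformers import AutoTokenizer

# gpt2's pre-tokenizer records word boundaries, so word_ids() works as expected.
gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')

enc = gpt2_tokenizer('Sampel to demonstrate the issue')
print(enc.tokens())    # e.g. ['S', 'ampel', 'Ġto', 'Ġdemonstrate', 'Ġthe', 'Ġissue']
print(enc.word_ids())  # e.g. [0, 0, 1, 2, 3, 4] -- one id per source word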

@ArthurZucker
Collaborator

Alright, this is a valid bug! Pretty sure this comes from the lack of a pre_tokenizer for llama; it should thus be fixed by #28881

@ArthurZucker
Collaborator

ArthurZucker commented Mar 25, 2024

Here is what I am getting with the tokenizers PR, of course:

In [1]: from transformers import AutoTokenizer
   ...: 
   ...: # any llama (1) based tokenizer
   ...: llama_tokenizer = AutoTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer',
   ...: from_slow=True, legacy=False)
   ...: 
   ...: sentence = 'Sampel to demonstrate the issue'
   ...: 
   ...: input_encoding = llama_tokenizer(sentence, add_special_tokens=True)
   ...: tokens = input_encoding.tokens()
   ...: word_ids = input_encoding.word_ids()
   ...: 
   ...: print(f'Tokens produced: {tokens}\nReferenced word ids: {word_ids}')
Tokens produced: ['<s>', '▁S', 'amp', 'el', '▁to', '▁demonstrate', '▁the', '▁issue']
Referenced word ids: [None, 0, 0, 0, 1, 2, 3, 4]

@vasqu
Contributor Author

vasqu commented Mar 25, 2024

Thank you, this is looking good!

@phusroyal

phusroyal commented Apr 3, 2024

Hi @ArthurZucker, I am still facing the same issue, even after trying your code.

I tried to use your PR, but it seems to have been merged into main already.

@ArthurZucker
Collaborator

You need to make sure you are using the correct tokenizers version. Note that you should also disable split for Metaspace

@phusroyal

You need to make sure you are using the correct tokenizers version. Note that you should also disable split for Metaspace

I am using tokenizers version 0.16.0-dev.0. Is this the correct version?

Besides, could you please explain how to disable split for Metaspace?

@ArthurZucker
Collaborator

It is the correct version. Try setting tokenizer._tokenizer.pre_tokenizer.split = False.
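In context, a minimal sketch of that setting (assuming the PR branch discussed above, where the fast tokenizer actually carries a Metaspace pre_tokenizer; on stock versions pre_tokenizer may be None, as seen below):

from transformers import AutoTokenizer

# Sketch only: requires a build where the Llama fast tokenizer has a Metaspace
# pre_tokenizer; otherwise pre_tokenizer is None and there is nothing to set.
tokenizer = AutoTokenizer.from_pretrained(
    'hf-internal-testing/llama-tokenizer', from_slow=True, legacy=False
)
if tokenizer._tokenizer.pre_tokenizer is not None:
    tokenizer._tokenizer.pre_tokenizer.split = False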

@phusroyal

from tokenizers import Tokenizer
llama_tokenizer = Tokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
llama_tokenizer._tokenizer.pre_tokenizer.split = False

This raises AttributeError: 'tokenizers.Tokenizer' object has no attribute '_tokenizer'. Did you mean: 'pre_tokenizer'? And if I replace that line with llama_tokenizer.pre_tokenizer.split = False, it returns AttributeError: 'NoneType' object has no attribute 'split'

from tokenizers import Tokenizer
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', from_slow=True, legacy=False)
tok._tokenizer.pre_tokenizer.split = False

This raises the same AttributeError: 'NoneType' object has no attribute 'split'

Could you please help?

@vasqu
Contributor Author

vasqu commented Apr 3, 2024

Before I start: I'm using the most recent tokenizers version on the main branch and the transformers version from PR #28881 (with a slight modification locally on my end).

Hey @phusroyal, I've just checked it myself again, and you don't need to set split=False; in my case, it would cause unwanted behaviour. In the tokenizers library, if you need it in the future, you would set it via tokenizer.pre_tokenizer.split = False, I think.

from tokenizers import Tokenizer

# any llama (1) based tokenizer
llama_tokenizer = Tokenizer.from_pretrained('hf-internal-testing/llama-new-metaspace')
# just for demonstration purpose, it's also here
#llama_tokenizer.pre_tokenizer.split = False

sentence = ['Sampel to demonstrate the issue']

input_encoding = llama_tokenizer.encode_batch(sentence, add_special_tokens=True)
tokens = input_encoding[0].tokens
word_ids = input_encoding[0].word_ids

print(f'Tokens produced: {tokens}\nReferenced word ids: {word_ids}')

For the transformers variant, you would also not want to modify the split attribute (albeit you'd do it like tokenizer._tokenizer.pre_tokenizer.split = False, as @ArthurZucker showed). However, I encountered an issue when initializing the tokenizer: add_prefix_space wasn't recognized as a keyword argument anymore (see this line). Not sure if it's because I'm messing with dev versions or if the metaspace refactor caused some issues there.

@vasqu
Contributor Author

vasqu commented Apr 3, 2024

Ah, and another small thing: this issue is about llama1-based tokenizers. Not sure how much it applies to llama2 tokenizers.

@phusroyal

Great! Thank you for your help.

And yes, I also face the same issue with llama2 tokenizers.

@vasqu
Contributor Author

vasqu commented Apr 4, 2024

I see. I don't think llama2 has a pre_tokenizer in the current implementation of tokenizers. But when testing it with meta-llama/Llama-2-7b-hf, the same transformers settings worked for me, i.e. AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', from_slow=True, legacy=False).

Regarding tokenizers, it seems the pre_tokenizer isn't set in that case, which causes the errors above. Might be a separate bug/issue.
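A quick way to confirm that (hedged illustration, consistent with the AttributeError above):

from tokenizers import Tokenizer

# The stock Llama-2 tokenizer ships without a pre_tokenizer, which is why
# `.pre_tokenizer.split` fails and why no word boundaries are recorded.
llama2 = Tokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
print(llama2.pre_tokenizer)  # None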

@ArthurZucker
Collaborator

The Llama 2 and Llama 1 tokenizers don't have a pre_tokenizer, but there are issues with the normalizer, hence the fix.

@ArthurZucker
Collaborator

I'll get the PR ready in a bit

@epignatelli

epignatelli commented Apr 16, 2024

I ended up here from a separate problem with Gemma and pre-tokenizers not working at all.
I really appreciate all the hard work hf is putting into transformers; we wouldn't be here without you, and thanks for it.

But this codebase has so much debt, and it is so incredibly messy, guys...

@ArthurZucker
Collaborator

Could you share a reproducer? The Gemma tokenizer is completely different.

@vasqu
Contributor Author

vasqu commented Apr 19, 2024

@ArthurZucker, it's the same issue as with llama above. Since the model is gated, I used huggingface-cli login beforehand. The transformers version is the most recent one, 4.40.0.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google/gemma-7b', token=True)
sentence = 'Sampel to demonstrate the issue'

input_encoding = tokenizer(sentence, add_special_tokens=True)
tokens = input_encoding.tokens()
word_ids = input_encoding.word_ids()

print(f'Tokens produced: {tokens}\nReferenced word ids: {word_ids}')

Results in:
Tokens produced: ['<bos>', 'Samp', 'el', '▁to', '▁demonstrate', '▁the', '▁issue']
Referenced word ids: [None, 0, 0, 0, 0, 0, 0]

I also tried with the versions above (i.e. the PR #28881 branch + the tokenizers 0.16.0-dev.0 version) using legacy=False and from_slow=True, to no avail. I know it's not meant to fix it, I just gave it a try.

@ArthurZucker
Collaborator

Yep, it's the same issue here: the normalizer does not split the input on spaces the way the pre_tokenizer does. I'll see what I can do; basically, this should work:

>>> from tokenizers import pre_tokenizers, normalizers
>>> tokenizer._tokenizer.normalizer = normalizers.Sequence([])
>>> tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(prepend_scheme="never", split=True)
>>> input_encoding = tokenizer(sentence, add_special_tokens=True)
>>> input_encoding.tokens()
['<bos>', 'Samp', 'el', '▁to', '▁demonstrate', '▁the', '▁issue']
>>> input_encoding.word_ids()
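With that change, word_ids() should presumably come back per word for this input, i.e. something like [None, 0, 0, 1, 2, 3, 4] (with <bos> mapping to None) instead of all zeros.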

@ArthurZucker
Collaborator

This was not really possible before without adding extra bugs, but it should work now.

@ArthurZucker
Collaborator

Also, now that tokenizers has been updated, I'll officially fix the llama tokenizer.

@vasqu
Contributor Author

vasqu commented Apr 22, 2024

Thank you 🙏 Is there already a separate issue open for gemma, or should I (or someone else) open one?

@ArthurZucker
Collaborator

Mmm, feel free to open one, I'll fix it in the same PR!

@kooryan

kooryan commented Jul 31, 2024

Hi all,

It seems this issue was fixed, so perhaps I'm doing something wrong, but the problem persists whenever I try to retrieve the word_ids for a sequence with the Llama tokenizer. For example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token = tokenizer.eos_token
encoded = tokenizer("Hi my name is Adam")
print(encoded.tokens())
print("WORD IDS", encoded.word_ids())

still returns the following

['<s>', '▁Hi', '▁my', '▁name', '▁is', '▁Adam']
WORD IDS [None, 0, 0, 0, 0, 0]

I deleted and reinstalled the transformers package as well; my versions from pip list are below. Is anyone else able to replicate this issue?

tokenizers               0.19.1
torch                    2.0.1
transformers             4.43.3

@ArthurZucker
Collaborator

Hey! You are right, it is not fixed, but I recommend you use:

In [30]: tokenizer("Hi my name is Adam").encodings[0].offsets
Out[30]: [(0, 0), (0, 2), (2, 5), (5, 10), (10, 13), (13, 18)]
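If you do need word-like ids, one option (my own sketch, not from this thread) is to derive them from those offsets, starting a new "word" whenever a token's span begins at or right after whitespace in the original text:

def words_from_offsets(text, offsets):
    # Rough word ids from character offsets; special tokens have empty spans.
    word_ids, current = [], -1
    for start, end in offsets:
        if start == end:
            word_ids.append(None)          # special token, e.g. <s>
        elif current == -1 or text[start].isspace() or (start > 0 and text[start - 1].isspace()):
            current += 1                   # span begins a new whitespace-delimited word
            word_ids.append(current)
        else:
            word_ids.append(current)       # continuation of the current word
    return word_ids

text = "Hi my name is Adam"
offsets = [(0, 0), (0, 2), (2, 5), (5, 10), (10, 13), (13, 18)]
print(words_from_offsets(text, offsets))   # [None, 0, 1, 2, 3, 4]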

More fundamentally, the concept of a word is very arbitrary. The main reason it works for gpt2, for example, is that its Split pre-tokenizer splits the string into "words"; here we are not doing that, we are only merging whatever is mergeable.
For example:

print(tokenizer("Hi my name is   Adam. we'll").encodings[0].tokens)
['<s>', '▁Hi', '▁my', '▁name', '▁is', '▁▁', '▁Adam', '.', '▁we', "'", 'll']

Is '▁▁' a word? Or does it form a word by itself? What about ' and ll?

I recommend using pre_tokenizers.Split(Regex('(?<!▁)▁'), "merged_with_next"), for example:

tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Split(Regex('(?<!▁)▁'), "merged_with_next") (on top of the Metaspace step)
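Put together, a hedged sketch of that workaround (assuming recent transformers/tokenizers versions; depending on the version you may need to chain this with the existing Metaspace step via pre_tokenizers.Sequence):

from tokenizers import Regex, pre_tokenizers
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

# Split on every '▁' that is not preceded by another '▁', keeping the '▁'
# attached to the following piece, so each '▁'-prefixed piece becomes a "word".
tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Split(
    Regex('(?<!▁)▁'), 'merged_with_next'
)

encoded = tokenizer('Hi my name is Adam')
print(encoded.tokens())    # ['<s>', '▁Hi', '▁my', '▁name', '▁is', '▁Adam']
print(encoded.word_ids())  # e.g. [None, 0, 1, 2, 3, 4] -- one id per piece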
