
add_prefix_space won't be respected by Llama tokenizer #29625

Closed
2 of 4 tasks
scruel opened this issue Mar 13, 2024 · 19 comments · Fixed by #30964

@scruel (Contributor) commented Mar 13, 2024

System Info

  • transformers version: 4.38.2
  • Platform: Linux-6.5.0-14-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.21.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

With sentencepiece==0.2.0 and protobuf==4.25.3 installed

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", local_files_only=True, add_prefix_space=False)
>>> tokenizer.tokenize("overheard")
['▁over', 'he', 'ard']

I also tried add_dummy_prefix_space=False; the output is still the same.

Expected behavior

The tokenization result should not include a prefix space (SPIECE_UNDERLINE, '▁').
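
For illustration, here is the output the reporter expects (and which a later comment in this thread confirms once the flag is honored):

>>> tokenizer.tokenize("overheard")
['over', 'he', 'ard']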

@scruel scruel changed the title add_prefix_space can be set for Llama tokenizer add_prefix_space won't be respected by Llama tokenizer Mar 13, 2024
@aoxolotl commented Mar 15, 2024

Hey, I took a peek under the hood, and it looks like setting add_prefix_space only sets kwargs["from_slow"] = True in tokenization_llama_fast.py; the super().__init__() call should also receive this parameter when it is set (see the sketch at the end of this comment).
Passing it through seems to work in preliminary tests:

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_prefix_space=False)
>>> tokenizer.tokenize('overheard')
['over', 'he', 'ard']

Mind if I take this up @ArthurZucker & @scruel?

Edit: for completeness, showing that the behavior is unchanged when add_prefix_space=True:

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_prefix_space=True)
>>> tokenizer.tokenize('overheard')
['▁over', 'he', 'ard']
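
A minimal sketch of the kind of change this diagnosis implies, assuming the fix is simply to forward the parameter (the class name is hypothetical and this is not the actual upstream patch):

from transformers import PreTrainedTokenizerFast

class PatchedLlamaTokenizerFast(PreTrainedTokenizerFast):
    def __init__(self, add_prefix_space=None, **kwargs):
        if add_prefix_space is not None:
            # The flag only takes effect when the fast tokenizer is rebuilt
            # from the slow (SentencePiece) one.
            kwargs["from_slow"] = True
        # Key change: forward add_prefix_space instead of dropping it.
        super().__init__(add_prefix_space=add_prefix_space, **kwargs)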

@scruel (Contributor, Author) commented Mar 16, 2024

You can always take it by creating a PR.

@aoxolotl

Thank you, I made a pull request. This was happening in T5TokenizerFast as well.

@ArthurZucker (Collaborator)

Thanks, I'll review ASAP!

@huggingface huggingface deleted a comment from github-actions bot Apr 15, 2024
@huggingface huggingface deleted a comment from github-actions bot May 10, 2024
@ArthurZucker (Collaborator)

Closing as #28881 fixed it!

@psinger commented May 21, 2024

@ArthurZucker are you sure this is fixed? I am still experiencing this in 4.41.0:
[screenshot omitted]

I also still can't see it being used here:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py#L153

@ArthurZucker (Collaborator)

You need to set from_slow=True to trigger the conversion.
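
For anyone landing here, a minimal workaround sketch based on that hint; it assumes sentencepiece and protobuf are installed so the conversion can run:

from transformers import AutoTokenizer

# Forcing conversion from the slow tokenizer makes add_prefix_space take effect.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf", add_prefix_space=False, from_slow=True
)
print(tokenizer.tokenize("overheard"))  # expected: ['over', 'he', 'ard']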

@ArthurZucker (Collaborator)

It is used in convert_slow 😉

@psinger commented May 22, 2024

This is very confusing and not transparent to the user at all.
If I just use the AutoTokenizer class with default settings, I would expect this to work rather than silently do nothing.
It should at least emit a warning, or better, set from_slow automatically.

@ArthurZucker (Collaborator)

I agree with you; on main there is this:

        if add_prefix_space is not None:
            logger.warning_once(
                "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
            )
            kwargs["from_slow"] = True

which should give you a warning and automatically convert it.

@ArthurZucker (Collaborator)

But it does not seem to be taken into account. @itazap it would be nice if you could investigate and open a PR to make sure it forces from_slow:

In [1]: from transformers import AutoTokenizer
In [2]: tokenizer = AutoTokenizer.from_pretrained("meta-llama/llama-2-7b-hf",add_prefix_space=False)
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565

In [3]: tokenizer.encode("Hey")
Out[3]: [1, 18637]

In [4]: tokenizer.tokenize("Hey")
Out[4]: ['▁Hey']

In [5]: tokenizer = AutoTokenizer.from_pretrained("meta-llama/llama-2-7b-hf",add_prefix_space=False, from_slow=True)

In [6]: tokenizer.tokenize("Hey")
Out[6]: ['H', 'ey']

In [7]: tokenizer = AutoTokenizer.from_pretrained("meta-llama/llama-2-7b-hf",add_prefix_space=False)

In [8]: tokenizer.tokenize("Hey")
Out[8]: ['▁Hey']

@psinger commented May 22, 2024

^^ Thanks

Another thing I noted is that if I specify from_slow in tokenizer_config.json, it is ignored. Is this expected behavior?

itazap pushed a commit that referenced this issue May 22, 2024
@ArthurZucker (Collaborator)

I think it should be taken into account!

@psinger commented May 23, 2024

Apparently not; I need to set it manually.

@ArthurZucker (Collaborator)

When I manually add it to tokenizer_config.json on main, it works.
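
For reference, a hedged sketch of adding the flag to a local copy of tokenizer_config.json as described; the path is hypothetical and the keys are assumptions based on this thread:

import json
import pathlib

cfg_path = pathlib.Path("my-llama-checkout/tokenizer_config.json")  # hypothetical local path
cfg = json.loads(cfg_path.read_text())
cfg.update({"from_slow": True, "add_prefix_space": False})
cfg_path.write_text(json.dumps(cfg, indent=2))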

itazap added a commit that referenced this issue May 24, 2024
* add prefix space ignored in llama #29625

* adding test with add_prefix_space=False

* ruff

---------

Co-authored-by: Ita Zaporozhets <[email protected]>
@psinger commented Jun 5, 2024

I am still struggling to understand how exactly this works with the combination of all the different tokenizer settings. I believe a tutorial or docs section would be very helpful here.

For example, this breaks:

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/deepseek-coder-6.7b-base",
    add_prefix_space=False,
)

I assume it's because it forces from_slow=True but cannot find a SentencePiece model file. But why can't I use add_prefix_space with just a fast tokenizer?

@ArthurZucker

@ArthurZucker (Collaborator)

This will be supported very soon; @itazap is working on making all of this a lot simpler and clearer!
And yes, that is what's happening, but it should not: we should not need the sentencepiece dependency to update the prefix space. That's a mistake on my part, sorry about it 😢
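
Until that lands, a heavily hedged exploration sketch for fast-only checkpoints: inspect how the backend tokenizer applies the prefix space before changing anything. Metaspace's argument differs across tokenizers versions (older releases take add_prefix_space, newer ones prepend_scheme), so treat this as exploratory rather than an official API:

from transformers import AutoTokenizer
from tokenizers import pre_tokenizers

tok = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-6.7b-base")
# Some models apply the prefix space in a normalizer, others in a Metaspace
# pre-tokenizer; inspect both before touching anything.
print(tok.backend_tokenizer.normalizer)
print(tok.backend_tokenizer.pre_tokenizer)
if isinstance(tok.backend_tokenizer.pre_tokenizer, pre_tokenizers.Metaspace):
    # Newer tokenizers versions: prepend_scheme="never" disables the prefix space.
    tok.backend_tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(prepend_scheme="never")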

@psinger commented Jun 6, 2024

Looking forward to it :) @itazap
Actually, it would be great if one could pass add_prefix_space to the tokenize function itself, instead of needing to pass it when creating the tokenizer. Currently it is really inflexible if one wants to tokenize separate parts and then concatenate them afterwards.
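
In the meantime, a hedged sketch of that concatenation use case with what exists today: keep two tokenizer instances, one with and one without the prefix space, and merge the token ids (this piggybacks on the from_slow workaround shown earlier):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tok_no_prefix = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf", add_prefix_space=False, from_slow=True
)

# First part tokenized normally; the continuation gets no injected prefix space.
ids = tok.encode("over", add_special_tokens=False)
ids += tok_no_prefix.encode("heard", add_special_tokens=False)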

Please tag me on any PRs; happy to give feedback on this in general.

@ArthurZucker (Collaborator)

I kinda agree with you and will see what I can do on the tokenizers side, but this might need a lot of changes (supporting a new argument). It can already be done today by setting the attribute before each call, which is not super optimal.
