
Respect add_prefix_space option in LlamaTokenizerFast #29694

Closed
wants to merge 4 commits

Conversation

aoxolotl

@aoxolotl aoxolotl commented Mar 17, 2024

What does this PR do?

Respect add_prefix_space option in Llama tokenizer (Fixes #29625)

The add_prefix_space option in the Llama tokenizer was set but never passed on to the super().__init__() call of PreTrainedTokenizerFast. As a result, the SPIECE_UNDERLINE token was prepended even when add_prefix_space=False.
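A minimal sketch of the forwarding bug, using simplified stand-in classes (assumption: these are toy classes for illustration, not the actual transformers source):

```python
# Toy stand-ins for the real classes (assumption: simplified for illustration).
class PreTrainedTokenizerFast:
    def __init__(self, add_prefix_space=True, **kwargs):
        # The parent configures the backend tokenizer, so it must
        # actually receive the option for it to have any effect.
        self.add_prefix_space = add_prefix_space


class LlamaTokenizerFast(PreTrainedTokenizerFast):
    def __init__(self, add_prefix_space=False, **kwargs):
        # Before the fix, the option was accepted here but never forwarded,
        # so the parent silently fell back to its own default.
        super().__init__(add_prefix_space=add_prefix_space, **kwargs)


print(LlamaTokenizerFast(add_prefix_space=False).add_prefix_space)  # False
```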

Minimal example

When add_prefix_space is False

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_prefix_space=False)
>>> tokenizer.tokenize('overheard')
['over', 'he', 'ard']

When add_prefix_space is True

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_prefix_space=True)
>>> tokenizer.tokenize('overheard')
['\u2581over', 'he', 'ard']

Fixes #29625

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @scruel

@aoxolotl aoxolotl marked this pull request as draft March 17, 2024 00:26
@aoxolotl aoxolotl force-pushed the llama_add_prefix_space branch from 23763dc to 29ffaeb Compare March 18, 2024 14:04
@aoxolotl aoxolotl changed the title Respect add_prefix_space option in LlamaTokenizerFast and T5TokenizerFast Respect add_prefix_space option in LlamaTokenizerFast Mar 18, 2024
@aoxolotl aoxolotl marked this pull request as ready for review March 18, 2024 14:54
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks! Could you add a test to make sure this works? 🤗

@aoxolotl
Author

Sure! Will update the pull request with the tests

@aoxolotl
Author

aoxolotl commented Mar 24, 2024

Hey @ArthurZucker, I have a question about the tests. There already seems to be an add_prefix_space test present in test_tokenization_llama.py. However, this strangely passes even without the modifications above. The difference seems to come from using AutoTokenizer (as in the mentioned issue) vs using the LlamaTokenizer* classes directly. Is there a difference between the init routes we take in the above two scenarios?

Edit:
The difference actually comes from using hf-internal-testing/llama-tokenizer-non-normalized vs meta-llama/Llama-2-7b-hf. Examples:

>>> hf_tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer-non-normalized", add_prefix_space=False, legacy=False)
>>> hf_tokenizer.tokenize('overheard')
['over', 'he', 'ard']

>>> tokenizer = LlamaTokenizerFast.from_pretrained("meta-llama/Llama-2-7b-hf", add_prefix_space=False, legacy=False)
>>> tokenizer.tokenize('overheard')
['\u2581over', 'he', 'ard']
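The checkpoint dependence above can be mimicked with a toy model (assumption: hypothetical stand-in logic; the real difference lives in each checkpoint's serialized tokenizer.json, e.g. a normalization step that prepends the underline unconditionally):

```python
# Toy model of the two checkpoints (assumption: hypothetical stand-in logic,
# hard-coded to the 'overheard' example; not the real tokenization pipeline).
SPIECE_UNDERLINE = "\u2581"

def fast_tokenize(text, baked_in_prepend, add_prefix_space):
    """Toy tokenizer: always splits into over/he/ard after optional prefixing."""
    # A prepend step baked into the checkpoint's serialized config runs
    # unconditionally, so add_prefix_space=False cannot suppress it.
    if baked_in_prepend or add_prefix_space:
        text = SPIECE_UNDERLINE + text
    pieces = ["over", "he", "ard"]
    if text.startswith(SPIECE_UNDERLINE):
        pieces[0] = SPIECE_UNDERLINE + pieces[0]
    return pieces

# Behaves like the non-normalized testing checkpoint: the flag is respected.
print(fast_tokenize("overheard", baked_in_prepend=False, add_prefix_space=False))
# Behaves like the Llama-2-7b-hf checkpoint: the prefix appears regardless.
print(fast_tokenize("overheard", baked_in_prepend=True, add_prefix_space=False))
```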

@ArthurZucker
Collaborator

Actually I think we need to wait for #28881 to have a proper fix for Llama! Feel free to skip it. T5 should work!

@scruel
Contributor

scruel commented Mar 25, 2024

Could you take a look at my review?
(screenshot: review shown as "Pending")

@ArthurZucker
Collaborator

pending means it's not been submitted!

@scruel
Contributor

scruel commented Mar 25, 2024

pending means it's not been submitted!

Oh, I see. Is it that HF repos restrict review actions by others?

@ArthurZucker
Collaborator

No, pressing the submit review button should be enough

@@ -737,6 +737,7 @@ def as_tensor(value, dtype=None):

def is_tensor(obj):
return isinstance(obj, mx.array)

Contributor


Mind removing this empty line?

Author


Sure, updating in latest commit

@scruel
Contributor

scruel commented Mar 25, 2024

No, pressing the submit review button should be enough

It works! Thanks!

@aoxolotl
Author

@scruel Removed in latest commit!
@ArthurZucker I will make a separate PR for the T5 tokenizer with any necessary tests, leaving this PR for LlamaTokenizerFast-specific changes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker
Collaborator

#28881 should actually fix this!


Successfully merging this pull request may close these issues.

add_prefix_space won't be respected by Llama tokenizer
4 participants