Adding Llama FastTokenizer support. #22264

Merged: 4 commits into huggingface:main from fast_llama_tokenizer, Apr 6, 2023

Conversation

@Narsil (Contributor) commented Mar 20, 2023

How to test:

#! pip install -e git+https://github.com/huggingface/tokenizers@byte_fallback#egg=tokenizers

from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import AutoTokenizer
from tokenizers import Tokenizer

tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

if False:
    new_tokenizer = Tokenizer.from_file("tok.json")
else:
    new_tokenizer = convert_slow_tokenizer(tokenizer)
    new_tokenizer.save("tok.json")

strings = [
    "This is a test",
    "生活的真谛是",
    "生活的真谛是[MASK]。",
    # XXX: This one is problematic because of special tokens
    # "<s> Something something",
]

for string in strings:
    encoded = tokenizer(string)["input_ids"]
    encoded2 = new_tokenizer.encode(string).ids

    assert encoded == encoded2, f"{encoded} != {encoded2}"

    decoded = tokenizer.decode(encoded)
    decoded2 = new_tokenizer.decode(encoded2)

    assert decoded.strip() == decoded2, f"{repr(decoded)} != {repr(decoded2)}"

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev commented Mar 20, 2023

The documentation is not available anymore as the PR was closed or merged.

@Narsil Narsil changed the title from "Adding Llama FastTokenizer support." to "[WIP] Adding Llama FastTokenizer support." on Mar 20, 2023
@Narsil Narsil marked this pull request as draft March 20, 2023 11:17
@sgugger (Collaborator) commented Mar 20, 2023

Thanks for the ping. We'll need the actual fast tokenizer file to merge this though 😅

@Narsil (Contributor, Author) commented Mar 20, 2023

True. I uncovered more issues around multiple-space handling; I'm nailing down the right pre_tokenizer combination for it.

@Narsil (Contributor, Author) commented Mar 21, 2023

More troublesome than anticipated.

When encoding " Hello" from a pure BPE perspective, tokenizers produces [259, 10994] (" " + "Hello"),
whereas spm produces [29871, 15043] (" " + " Hello"), which from a pure ids & merges perspective seems worse.

I thought of fixing that with a pre_tokenizer that splits each word into its own piece.

However, when encoding " ird", spm this time DOES produce [259, 1823].
It seems this is where the scores come into play.
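A quick way to see the discrepancy side by side (a rough sketch, assuming the slow checkpoint and the converted tok.json from the script above are available; exact ids depend on the vocabulary):

```python
from transformers import AutoTokenizer
from tokenizers import Tokenizer

slow = AutoTokenizer.from_pretrained("huggingface/llama-7b")  # spm-backed slow tokenizer
fast = Tokenizer.from_file("tok.json")                        # converted fast tokenizer

for text in [" Hello", " ird"]:
    # Compare raw ids without special tokens to isolate the BPE vs spm behavior.
    slow_ids = slow(text, add_special_tokens=False)["input_ids"]
    fast_ids = fast.encode(text, add_special_tokens=False).ids
    print(repr(text), slow_ids, fast_ids)
```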

@ArthurZucker (Collaborator) left a comment

Nicely done! 😄 I have to take care of a few things on the slow side and should be done

src/transformers/convert_slow_tokenizer.py (resolved)
# Options to consider in order to implement:
# - Change `add_prefix_space` to ternary, False, True, "force", "force" being
# the new version which always prefixes
# - Add a new extra pre_tokenizer which doesn't pretokenize but does this job.
Collaborator:

  • Add a new extra pre_tokenizer which doesn't pretokenize but does this job.
    Since this was added, we don't need that comment anymore, no?

Comment on lines 123 to 127
# These are known differences
self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا.")
# XXX Extra space
# self.assertEqual(rust_tokenizer._tokenizer.decode([30112, 869]), "ا .")
self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا.")
Collaborator:

Should go away with the cleanup_tokenization_space changes in #22341.

Collaborator:

(flagging to take care of this test if this is merged first)

tests/models/llama/test_tokenization_llama.py (resolved)
@OlivierDehaene (Member) commented:

What is the status of this PR?

@Narsil Narsil force-pushed the fast_llama_tokenizer branch from 3f20703 to 3819151 on April 4, 2023 10:04
@Narsil Narsil changed the title from "[WIP] Adding Llama FastTokenizer support." to "Adding Llama FastTokenizer support." on Apr 4, 2023
@Narsil Narsil marked this pull request as ready for review April 4, 2023 10:04
setup.py (Outdated)
@@ -176,7 +176,7 @@
     "tf2onnx",
     "timeout-decorator",
     "timm",
-    "tokenizers>=0.11.1,!=0.11.3,<0.14",
+    "tokenizers==0.13.3rc1",
Collaborator:

Will need to be changed to a minimum pin.

Comment on lines 55 to 62
piece_score = vocab_scores.get(merge, None)
if piece_score:
    merges += [(piece_l, piece_r, piece_score)]
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
Collaborator:

This needs to be in its own PR, flagged as a breaking change.

Contributor (Author):

It's not breaking anymore.

Collaborator:

It still has a very strong potential to be breaking as it touches code functionality, and it will be easier to isolate it in a git bisect if it goes in its own PR. So I insist.
You can just reopen the PR you closed and amend it with those changes.

Contributor (Author):

Fair enough.

Comment on lines 1 to 15
from ...tokenization_utils_fast import PreTrainedTokenizerFast


class LlamaTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
    """

    def __init__(
        self,
        *args,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        super().__init__(*args, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
Collaborator:

Needs copyright, doc, etc.

Comment on lines +320 to +321
# This is excruciatingly slow since it has to recreate the entire merge
# list from the original vocabulary in spm
Collaborator:

Maybe we need a smaller tokenizer then?

Contributor (Author):

Well, we could create a dummy one, but then we're never going to be sure we have every argument exactly the same.

This is supposed to be a sanity check that the conversion matches some static reference value. I'm not sure checking this conversion all the time is necessary, but it's a nice test to have if regressions ever happen.

@sgugger (Collaborator) left a comment

Thanks for isolating the change in the conversion in #22582, this PR will need to be rebased after it's merged.

Still one comment on building a smaller tokenizer for the tests if possible, and on fleshing out the fast tokenizer module.

@@ -0,0 +1,19 @@
from ...tokenization_utils_fast import PreTrainedTokenizerFast
Collaborator:

Still missing copyright here.

Contributor (Author):

Done.


class LlamaTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
Collaborator:

Doc could be expanded here

Contributor (Author):

Expanded.

Comment on lines 15 to 44
        *args,
        clean_up_tokenization_spaces=False,
        **kwargs,
Collaborator:

We usually show the args and at least the special tokens kwargs in the signature of those.

Contributor (Author):

I did add the special tokens.

I have no idea what the args are supposed to be. PreTrainedTokenizerFast is also using *args.

Collaborator:

Here is what XLNet does:

class XLNetTokenizerFast(PreTrainedTokenizerFast):

Contributor (Author):

Loosely copied from there.
I removed the arguments we're not using and added clean_up_tokenization_spaces.

@Narsil (Contributor, Author) commented Apr 5, 2023

For the doc builder, we're going to need an update to the docker image so that it pulls tokenizers 0.13.3 to generate the doc.

Narsil added 2 commits April 5, 2023 16:07
- Requires huggingface/tokenizers#1183 version
- Only support byte_fallback for llama, raise otherwise (safety net).
- Lots of open questions around special tokens

The converter + some test script.

The test script.

Tmp save.

Adding Fast tokenizer + tests.

Adding the tokenization tests.

Correct combination.

Small fix.

Fixing tests.

Fixing with latest update.

Rebased.

fix copies + normalized added tokens  + copies.

Adding doc.

TMP.

Doc + split files.

Doc.

Versions + try import.

Fix Camembert + warnings -> Error.

Fix by ArthurZucker.

Not a decorator.
@Narsil Narsil force-pushed the fast_llama_tokenizer branch from 7e257e5 to ea90a99 on April 5, 2023 14:15
@sgugger (Collaborator) left a comment

Good to go once all tests pass. Thanks!

@Narsil Narsil merged commit 1670be4 into huggingface:main Apr 6, 2023
@Narsil Narsil deleted the fast_llama_tokenizer branch April 6, 2023 07:53
@stefan-it (Collaborator) commented Apr 6, 2023

Hi @Narsil,

the warnings.warn to raise RuntimeError change in src/transformers/convert_slow_tokenizer.py breaks a lot of things: I wanted to fine-tune an mT5 model and it is now no longer possible (I'm using the PyTorch example from the documentation).

How is it possible to robustify it? Also, DeBERTa v3 has a byte-fallback vocab (but I didn't test it yet) 🤔

@Narsil (Contributor, Author) commented Apr 6, 2023

> Hi @Narsil,
>
> the warnings.warn to raise RuntimeError change in src/transformers/convert_slow_tokenizer.py breaks a lot of things: I wanted to fine-tune an mT5 model and it is now no longer possible (I'm using the PyTorch example from the documentation).
>
> How is it possible to robustify it? Also, DeBERTa v3 has a byte-fallback vocab (but I didn't test it yet) 🤔

First of all, we could revert by all means, but since tokenizers now has ByteFallback we could make it 1-to-1 for those; that was the idea behind turning the warning into an error.

It's a relatively sizeable issue if there are models deployed out there which have inconsistent behavior regarding this though (slow using byte fallback, fast not using it). I'm not sure why it was a warning in the first place.

> DeBERTa v3

Let's have a look too.

As a user, what's your opinion here: should we just fix the various conversion scripts, or would you rather keep the warning, with the previous pitfalls?

@Narsil (Contributor, Author) commented Apr 6, 2023

Both are using Unigram with ByteFallback, which isn't supported yet.

@fxmarty (Contributor) commented Apr 7, 2023

@Narsil After this commit, AutoTokenizer.from_pretrained is extremely slow, spending time in convert_slow_tokenizer.py at every call. Is this expected? Or am I doing something wrong?

@Narsil (Contributor, Author) commented Apr 7, 2023

Which repo are you using? We need to create the fast files on the repo.

Converting from slow is super slow and there's nothing to be done about it (tokenizers needs to recreate a structure by doing an O(n²) search over the vocab, because spm does not store this information).
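
For intuition, here is a rough sketch of why that reconstruction is quadratic (it mirrors the shape of the conversion loop but is not the exact converter code; vocab_scores is assumed to map each spm token to its score):

```python
def reconstruct_merges(vocab_scores, reverse=True):
    # spm stores only tokens and scores, not the merge list BPE needs, so every
    # pair of vocabulary entries has to be tried: O(n^2) lookups over the vocab.
    merges = []
    for piece_l in vocab_scores:
        for piece_r in vocab_scores:
            merge = piece_l + piece_r
            piece_score = vocab_scores.get(merge, None)
            if piece_score:
                merges += [(piece_l, piece_r, piece_score)]
    # Higher-scored merges come first so they are applied first.
    merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
    return [(piece_l, piece_r) for piece_l, piece_r, _ in merges]
```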

@Narsil (Contributor, Author) commented Apr 7, 2023

@ArthurZucker

@fxmarty (Contributor) commented Apr 7, 2023

I see, thanks!
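
One way to avoid paying the conversion cost on every call is to convert once and save the fast files locally (a sketch, assuming write access to a local directory; after saving, tokenizer.json is reused and the slow conversion is skipped):

```python
from transformers import AutoTokenizer

# First load triggers the slow spm -> fast conversion once.
tok = AutoTokenizer.from_pretrained("huggingface/llama-7b", use_fast=True)

# Persist the converted tokenizer (writes tokenizer.json next to the other files).
tok.save_pretrained("./llama-7b-fast-tokenizer")

# Later loads from the saved directory pick up tokenizer.json directly.
tok = AutoTokenizer.from_pretrained("./llama-7b-fast-tokenizer")
```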

novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* Adding Llama FastTokenizer support.

* Fixing comments.

* Adding more to docstring.

* Doc rewriting.