Improve PreTrainedTokenizerFast loading time when there are many added tokens #31404

Merged
ydshieh merged 3 commits into main from speedy_fast_token_loading
Jun 18, 2024

Conversation

ydshieh
Collaborator

@ydshieh ydshieh commented Jun 13, 2024

What does this PR do?

The condition check

token not in self.added_tokens_decoder

is very slow, especially when there are many added tokens.

Loading time (see code snippet below): 2048 added tokens

before this PR:

3.909789 seconds

after this PR:

0.23008 seconds

import datetime

from transformers import XLMRobertaTokenizerFast

# Checkpoint used for the timing measurement.
ckpt = "ydshieh/dummy_tok"

# Measure the wall-clock time needed to load the fast tokenizer.
s = datetime.datetime.now()
p = XLMRobertaTokenizerFast.from_pretrained(ckpt)
e = datetime.datetime.now()
print((e - s).total_seconds())
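
The general pattern behind this change can be sketched in isolation. This is only an illustration, not the transformers code: `SlowStore` and its `decoder` property are hypothetical stand-ins, and the idea that the container is expensive to obtain on every access is an assumption made for the sketch (the description above does not state why the check is slow).

```python
import time


class SlowStore:
    """Hypothetical stand-in: rebuilds its mapping on every property access."""

    def __init__(self, n):
        self._items = [f"<tok_{i}>" for i in range(n)]

    @property
    def decoder(self):
        # Assumption for this sketch: each access pays an O(n) rebuild.
        return {i: tok for i, tok in enumerate(self._items)}


def missing_slow(store, candidates):
    # Accesses the property once per candidate, then scans the values.
    return [c for c in candidates if c not in store.decoder.values()]


def missing_fast(store, candidates):
    # Build the lookup set a single time, then do O(1) membership checks.
    known = {hash(tok) for tok in store.decoder.values()}
    return [c for c in candidates if hash(c) not in known]


store = SlowStore(2048)
candidates = [f"<tok_{i}>" for i in range(2040, 2056)]

for fn in (missing_slow, missing_fast):
    start = time.perf_counter()
    result = fn(store, candidates)
    elapsed = time.perf_counter() - start
    print(fn.__name__, len(result), f"{elapsed:.6f}s")
```

Both functions return the same list; the second one simply moves the expensive work out of the loop.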


@ydshieh ydshieh requested a review from ArthurZucker June 13, 2024 12:20
Comment on lines 175 to 176
# Use hash to speed up the very slow operation `token not in added_tokens_decoder`.
added_tokens_decoder_hash = {hash(token) for token in self.added_tokens_decoder}
Collaborator Author

there is a very very tiny chance of hash collision. Do we want to address that possibility?

Collaborator

no 😉
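
For context on the collision question: a set that stores only hash values cannot distinguish two different objects that share a hash, whereas a set that stores the objects themselves (or their `repr` strings) falls back to an equality check. The sketch below uses a well-known CPython integer collision purely as an illustration. It does not involve `AddedToken`, and for strings such collisions are very unlikely.

```python
# CPython detail: -1 is reserved as an error code, so hash(-1) is -2.
assert hash(-1) == hash(-2)

seen_hashes = {hash(-1)}  # stores only the hash value
seen_values = {-1}        # stores the object itself

# A set of raw hashes reports a false positive for -2 ...
false_positive = hash(-2) in seen_hashes
# ... while a set of values also compares with == after hashing.
exact = -2 in seen_values

print(false_positive, exact)
```

Storing `repr(token)` strings instead of `hash(repr(token))` would therefore rule out false positives entirely, at the cost of keeping the strings in memory.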


@@ -172,10 +172,12 @@ def __init__(self, *args, **kwargs):
# allows converting a slow -> fast, non-legacy: if the `tokenizer.json` does not have all the added tokens
# uses the information stored in `added_tokens_decoder`.
# this is costly for fast tokenizers as we re-compute the regex again. But not all tokens are added tokens
# Use hash to speed up the very slow operation `token not in added_tokens_decoder`.
added_tokens_decoder_hash = {hash(token) for token in self.added_tokens_decoder}
Collaborator

Suggested change
- added_tokens_decoder_hash = {hash(token) for token in self.added_tokens_decoder}
+ added_tokens_decoder_hash = {hash(token.__str__()) for token in self.added_tokens_decoder}

would that be even faster? hash the string rep?

Collaborator

Also, I am wondering whether it would be faster to implement this at the class level. It would most probably be computed when you init the object.

Collaborator Author

It's on par :-) Both are OK for me, but if using str, I would just write str(token).

Collaborator

No, you need to make sure that tokens differing in normalized, lstrip, rstrip, or single_word are not treated as the same.

Collaborator Author

@ydshieh ydshieh Jun 13, 2024

str and __str__ will give the same thing. Do you mean to use __repr__?

That will give

'AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)'

while str and __str__ will give, for example,

> </s>

note str will call __str__ under the hood.
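
The difference can be shown with a small stand-in class. `FakeAddedToken` is hypothetical and is not the `AddedToken` class from tokenizers; it only mimics the behaviour described in this comment (`str` gives the content, `repr` lists every field), with field names taken from the `repr` output quoted above.

```python
from dataclasses import dataclass


@dataclass
class FakeAddedToken:
    """Hypothetical stand-in for AddedToken; not the tokenizers class."""

    content: str
    rstrip: bool = False
    lstrip: bool = False
    single_word: bool = False
    normalized: bool = False
    special: bool = True

    def __str__(self):
        # Mirrors the behaviour described above: str() yields only the content.
        return self.content


a = FakeAddedToken("<s>", normalized=False)
b = FakeAddedToken("<s>", normalized=True)

# Same content, different flags: str() cannot tell them apart, repr() can.
same_under_str = hash(str(a)) == hash(str(b))
same_under_repr = hash(repr(a)) == hash(repr(b))
print(str(a), repr(a))
print(same_under_str, same_under_repr)
```

Hashing the `str` form would treat `a` and `b` as the same token even though their flags differ, which is the concern raised in the review.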

Collaborator

yeah sorry repr

@ArthurZucker
Collaborator

Super nice BTW

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ydshieh ydshieh requested a review from ArthurZucker June 13, 2024 12:55
Collaborator

@ArthurZucker ArthurZucker left a comment

is it on par to have the init at the AddedTokens class level? (make sure to not import it from tokenizers)


@ydshieh
Collaborator Author

ydshieh commented Jun 13, 2024

is it on par to have the init at the AddedTokens class level? (make sure to not import it from tokenizers)

sorry, I don't understand this. Could you elaborate a bit more?

@ArthurZucker
Collaborator

What I mean is implement the hash for this class
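
One possible reading of this suggestion is to give the token class its own `__hash__` (and a matching `__eq__`) so that tokens can be placed in a set directly. The thread does not show such an implementation, so the sketch below is only an illustration: `HashableToken` is a hypothetical class, with field names borrowed from the `repr` output quoted earlier in the conversation.

```python
class HashableToken:
    """Hypothetical token type that defines its own equality and hash."""

    def __init__(self, content, rstrip=False, lstrip=False,
                 single_word=False, normalized=False, special=False):
        self.content = content
        self.rstrip = rstrip
        self.lstrip = lstrip
        self.single_word = single_word
        self.normalized = normalized
        self.special = special

    def _key(self):
        # Every field takes part, so tokens that differ only in a flag differ.
        return (self.content, self.rstrip, self.lstrip,
                self.single_word, self.normalized, self.special)

    def __eq__(self, other):
        if not isinstance(other, HashableToken):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())


existing = {HashableToken("<s>"), HashableToken("</s>")}
candidates = [
    HashableToken("<s>"),                   # already present
    HashableToken("<s>", normalized=True),  # same content, different flag
    HashableToken("<new>"),                 # new content
]

# Set membership now uses __hash__ and __eq__ of the class itself.
to_add = [tok for tok in candidates if tok not in existing]
print([(tok.content, tok.normalized) for tok in to_add])
```

With this approach the set stores the tokens themselves, so a hash collision between two different tokens is resolved by `__eq__` rather than producing a false match.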

@@ -172,10 +172,12 @@ def __init__(self, *args, **kwargs):
# allows converting a slow -> fast, non-legacy: if the `tokenizer.json` does not have all the added tokens
# uses the information stored in `added_tokens_decoder`.
# this is costly for fast tokenizers as we re-compute the regex again. But not all tokens are added tokens
# Use hash to speed up the very slow operation `token not in added_tokens_decoder`.
added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
Collaborator Author

repr used

@ydshieh ydshieh requested a review from ArthurZucker June 13, 2024 15:56
@ydshieh
Collaborator Author

ydshieh commented Jun 14, 2024

@ArthurZucker Everything is addressed.

Collaborator

@ArthurZucker ArthurZucker left a comment

sounds good! Thanks for this update

@ArthurZucker
Collaborator

huggingface/tokenizers#1521 a related PR

@ydshieh ydshieh merged commit 1c7c34b into main Jun 18, 2024
21 checks passed
@ydshieh ydshieh deleted the speedy_fast_token_loading branch June 18, 2024 13:20
itazap pushed a commit that referenced this pull request Jun 18, 2024
…ded tokens (#31404)

* use hash

* use hash

* update

---------

Co-authored-by: ydshieh <[email protected]>
itazap pushed a commit that referenced this pull request Jun 20, 2024
…ded tokens (#31404)

* use hash

* use hash

* update

---------

Co-authored-by: ydshieh <[email protected]>
3 participants