
T5Tokenizer Fast and Slow give different results with AddedTokens #16334

Closed
patrickvonplaten opened this issue Mar 22, 2022 · 5 comments · Fixed by #23909


@patrickvonplaten
Contributor

When adding a new token to T5TokenizerFast and/or T5Tokenizer, the two tokenizers unexpectedly produce different results.

E.g. running the following code:

from transformers import AutoTokenizer, AddedToken

tok = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
tok_fast = AutoTokenizer.from_pretrained("t5-small", use_fast=True)

tok.add_tokens("$$$")
tok_fast.add_tokens(AddedToken("$$$", lstrip=False))

prompt = "Hello what is going on $$$ no ? We should"

print("Slow")
print(tok.decode(tok(prompt).input_ids))

print("Fast")
print(tok_fast.decode(tok_fast(prompt).input_ids))

yields a different result for each tokenizer:

Slow
Hello what is going on $$$ no? We should</s>
Fast
Hello what is going on$$$ no? We should</s>

Environment info

  • transformers version: 4.18.0.dev0
  • Platform: Linux-5.15.15-76051515-generic-x86_64-with-glibc2.34
  • Python version: 3.9.7
  • Huggingface_hub version: 0.4.0.dev0
  • PyTorch version (GPU?): 1.10.2+cu102 (True)
  • Tensorflow version (GPU?): 2.8.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.4.0 (cpu)
  • Jax version: 0.3.1
  • JaxLib version: 0.3.0
@patrickvonplaten
Contributor Author

cc @Narsil @SaulLu

@Narsil
Contributor

Narsil commented Mar 23, 2022

Hi, the behavior can be explained as follows: the encoder splits on whitespace and discards it, and the decoder then uses Metaspace (which matches the SentencePiece behavior), which does not prefix tokens with spaces, not even added tokens. The spaces are expected to already be contained within the tokens themselves.

We could certainly have parity on this, at least!

But I am not sure which tokenizer is right here; both decoded values look OK to me. The proposed AddedToken carries no information about spaces, so it is reasonable not to place one back by default (doing so would break cases where added tokens are specifically intended for content that contains no spaces).
In this particular instance, because we are coming from a sentence with a space, it of course makes more sense to put one back so the original string is recovered. But for decode([999, 998]) with 999="$(" and 998=")$", it is unclear to me whether a user wants "$( )$" or "$()$" when decoding. (I am just trying to pick a plausible example where the answer is unclear.)
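The explanation above can be illustrated with a small, self-contained sketch. This is not the actual transformers/tokenizers implementation, only a minimal imitation of the two decoding strategies under stated assumptions: a Metaspace-style decoder (the fast path), where the SentencePiece marker "▁" (U+2581) denotes a leading space and added tokens carry no marker, and a slow-path decoder that splits out added tokens and re-joins the pieces with spaces.

```python
# Minimal sketch (NOT the real implementation) of the two decode strategies.

def metaspace_decode(tokens):
    """Fast path: Metaspace-style decoding. '\u2581' marks a leading space;
    tokens without it (e.g. added tokens) are concatenated with no space."""
    return "".join(t.replace("\u2581", " ") for t in tokens).lstrip(" ")

def slow_decode(tokens, added_tokens):
    """Slow path: added tokens are split out and the pieces are
    re-joined with single spaces."""
    parts, current = [], []
    for t in tokens:
        if t in added_tokens:
            if current:
                parts.append(metaspace_decode(current))
                current = []
            parts.append(t)
        else:
            current.append(t)
    if current:
        parts.append(metaspace_decode(current))
    return " ".join(parts)

tokens = ["\u2581Hello", "\u2581what", "\u2581is", "\u2581going",
          "\u2581on", "$$$", "\u2581no"]
print(metaspace_decode(tokens))      # Hello what is going on$$$ no
print(slow_decode(tokens, {"$$$"}))  # Hello what is going on $$$ no
```

The added token "$$$" carries no "▁" marker, so the Metaspace-style path has no signal that a space preceded it, reproducing the fast/slow discrepancy reported above.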

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@wise-east

Should this be reopened if it's not resolved yet?

@SaulLu SaulLu reopened this Aug 8, 2022
@github-actions

github-actions bot commented Sep 1, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Labels: none yet · Projects: none yet · 4 participants