Adding new tokens to various models changes tokenization of adjacent elements in strings #14770
Hello! I have noticed the same with transformers (v3.3.1) with the BartTokenizer. Tokenization behavior on some existing words changes after adding new tokens, and the Ġ prefix disappears as well.
Thank you very much for the detailed issue; unfortunately it seems to us that there is no simple way to add tokens in the way you describe. Currently the added tokens are not added to the vocabulary of the tokenization model - here WordPiece - but are split out and preserved from the very beginning of the tokenization, no matter which tokenization model is used afterwards. However, if you see how an easy solution could be implemented, we would be happy to discuss it!
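For illustration, a minimal sketch of that behaviour (the checkpoint and the token "mynewword" are only illustrative): the added token is matched on the raw text before the WordPiece model ever sees it, and it is stored outside the WordPiece vocabulary.

```python
from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")
tok.add_tokens(["mynewword"])

print(tok.tokenize("mynewword"))   # the added token comes out whole, untouched by WordPiece
print("mynewword" in tok.vocab)    # False: it is not part of the WordPiece vocabulary
print(tok.get_added_vocab())       # it is tracked separately as an added token
```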
Hi! Putting aside a solution, maybe a warning message should be added. Example code:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(...)
new_tokens = [...]

vocab_before_add = list(tokenizer.vocab)
vocab_tokenization_before_add = [tuple(tokenizer.tokenize(w)) for w in vocab_before_add]

tokenizer.add_tokens(new_tokens)
vocab_tokenization_after_add = [tuple(tokenizer.tokenize(w)) for w in vocab_before_add]

in_vocab_tokens_changed = [
    (w, before, after)
    for w, before, after in zip(vocab_before_add, vocab_tokenization_before_add, vocab_tokenization_after_add)
    if before != after
]

Thanks!
Hey! Thanks for reporting. I'll have a look when I can.
Regarding the original issue as well as the second issue, it appears that a specific parameter exists to prevent the tokenizer from matching the new token in the middle of words. The spaces before and after the added token can also be controlled when it is added.
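A hedged sketch of what this could look like with the `AddedToken` options (the checkpoint and token names are illustrative; `single_word` avoids matches inside longer words, while `lstrip`/`rstrip` control the surrounding spaces):

```python
from transformers import AddedToken, AutoTokenizer

tok = AutoTokenizer.from_pretrained("roberta-base")

# single_word=True keeps "mynewword" from matching inside longer words such as
# "mynewwording"; lstrip/rstrip decide whether neighbouring whitespace is absorbed.
tok.add_tokens([AddedToken("mynewword", single_word=True, lstrip=False, rstrip=False)])

print(tok.tokenize("A sentence with mynewword and mynewwording."))
```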
Hi @ArthurZucker, I'm using the latest version of transformers (4.37.2), but I'm still seeing the same odd behaviour that @mawilson1234 described in the initial comment. I see that PR #23909 is merged into master and should already solve this. Do you have any thoughts? Thanks. @mawilson1234, did you manage to make it work? Thanks also :) Right now I'm reading another issue: google-research/bert#396. The solution there would be to change the [unused#] tokens in the vocabulary. I made a little experiment, changed a few [unused] tokens, and it seems to work correctly.
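For reference, a rough sketch of that kind of experiment (the paths, checkpoint, and token are illustrative, and this is a workaround rather than an official recommendation): replace one of the reserved `[unused#]` entries in `vocab.txt` and reload the tokenizer from the modified files.

```python
from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")
tok.save_pretrained("bert-modified")                    # writes vocab.txt locally

vocab_path = "bert-modified/vocab.txt"
with open(vocab_path, encoding="utf-8") as f:
    vocab = f.read().splitlines()

vocab[vocab.index("[unused1]")] = "mynewword"           # reuse a reserved slot

with open(vocab_path, "w", encoding="utf-8") as f:
    f.write("\n".join(vocab) + "\n")

tok = BertTokenizer.from_pretrained("bert-modified")
print(tok.tokenize("a sentence with mynewword in it"))  # now handled by WordPiece itself
```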
Do you have a reproducer, @tlapusan?
Sure, I just created a Colab notebook (it's basically the same code that @mawilson1234 wrote in the initial comment): https://colab.research.google.com/drive/1fnn9gZgjI-UJdkfp5tAPwF06NGhs4pGc?usp=sharing
It seems to me that the slow and fast tokenizers do not give the same behaviour here and do not add the token in the same way.
(Just for BERT, as it is a special case; `tokenize` can be changed a bit to add it.)
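A quick way to compare the slow and fast tokenizers side by side (checkpoint and token are illustrative):

```python
from transformers import BertTokenizer, BertTokenizerFast

text = "a sentence with mynewword in it"
for cls in (BertTokenizer, BertTokenizerFast):
    tok = cls.from_pretrained("bert-base-uncased")
    tok.add_tokens(["mynewword"])
    print(cls.__name__, tok.tokenize(text))
```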
Could you please add more details here? Is it somehow related to the pre-tokenization step which @SaulLu also mentioned in a previous comment? The pre-tokenization looks the same with and without the new 'mynewword' added to the vocabulary, so it seems to me that 'mynewword' will be taken into account by the WordPiece tokenizer.
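One way to inspect the pre-tokenization step directly with a fast tokenizer (checkpoint and token are illustrative):

```python
from transformers import AutoTokenizer

text = "a sentence with mynewword in it"
tok = AutoTokenizer.from_pretrained("bert-base-uncased")   # fast tokenizer

print(tok.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text))
tok.add_tokens(["mynewword"])
print(tok.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text))
# Added tokens are split out before this step ever runs, so both outputs look the
# same; the difference only shows up in the final tokenize() result.
```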
The issue is not with the
Environment info
transformers version: 4.13.0

Who can help
@LysandreJik @SaulLu
Information
Models I am using: DistilBERT, BERT, RoBERTa
To reproduce
Adding a new token to various models (so far found with DistilBERT, BERT, and RoBERTa) using the `add_tokens` function changes how adjacent parts of the string are tokenized in subtle ways. For DistilBERT and BERT, this might depend on `do_basic_tokenize` being set to `False` when creating the tokenizer, at least in the examples I've found. (This might be related to the issue reported in #11531, but that one specifically mentions T5.) See the code below for details.

This doesn't seem like intended behavior based on what I can tell from the documentation, but it's possible I'm misunderstanding something about the right way to add new tokens to produce the behavior I'd like. (Currently, to get the expected behavior, I've had to manually modify the vocab, plus the merges file for RoBERTa, using additional scripting, and load the tokenizer from the modified files. If it'd be of use, I could post the code for that workaround here, but I've left it out for now since it's a bit long and may not be relevant.)
Steps to reproduce the behavior:
(Distil)BERT:
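A minimal sketch of the kind of reproduction described above (the checkpoint and the token "mynewword" are illustrative):

```python
from transformers import BertTokenizer

text = "a sentence with mynewword in it"
tok = BertTokenizer.from_pretrained("bert-base-uncased", do_basic_tokenize=False)

print(tok.tokenize(text))        # before adding the token
tok.add_tokens(["mynewword"])
print(tok.tokenize(text))        # after: subword pieces adjacent to the new token may change
```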
RoBERTa:
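Similarly for RoBERTa (again only a sketch with an illustrative token; the symptom reported above is that the Ġ prefix on neighbouring tokens can disappear):

```python
from transformers import RobertaTokenizer

text = "A sentence with mynewword in the middle."
tok = RobertaTokenizer.from_pretrained("roberta-base")

print(tok.tokenize(text))        # before adding the token
tok.add_tokens(["mynewword"])
print(tok.tokenize(text))        # after: the Ġ prefix on adjacent tokens may disappear
```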
Expected behavior
Adding a token to a tokenizer should not affect tokenization of adjacent elements (when these are not part of the added token).