
"never_split" not working on BertTokenizer #23459

Closed
1 of 4 tasks
lllyyyqqq opened this issue May 19, 2023 · 35 comments · Fixed by #23909
Labels: Core: Tokenization (Internals of the library; Tokenization.), wontfix

Comments

@lllyyyqqq

System Info

transformers 4.28.1
python 3.8.13

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  • I load BertTokenizer using my own vocab.txt and add '[outline]' to never_split; the token is included in my vocab.txt. However, '[outline]' gets split. The following is my code:

tokenizer = BertTokenizer.from_pretrained(pretrained_path, never_split=['[outline]'])
input = "。[outline]"
print(tokenizer.tokenize(input))  # ['。', '[', 'out', '##line', ']']

  • I also ran:
    print(tokenizer.basic_tokenizer.tokenize(input)) #['。', '[', 'outline', ']']

Expected behavior

When I do:
tokenizer.tokenize("。[outline]")
I expect to get the result ['。', '[outline]'], i.e. the tokens in never_split should not be split.

@amyeroberts
Collaborator

cc @ArthurZucker @younesbelkada

@zspo
Contributor

zspo commented May 23, 2023

In BertTokenizer, '[' and ']' are punctuation, so they are split off first. And if outline or [outline] is not in the vocab, it will be set to [UNK], so it doesn't seem to make sense anymore.
Look at the code: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/tokenization_bert.py#L446
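For illustration, a minimal sketch of the punctuation splitting described above, using the library's BasicTokenizer directly (the expected outputs are shown as comments; they are an assumption based on the standard BERT tokenization rules, not output taken from the thread):

from transformers.models.bert.tokenization_bert import BasicTokenizer

# Punctuation such as '[' and ']' is split off at the basic-tokenization stage.
print(BasicTokenizer().tokenize("[outline]"))
# expected: ['[', 'outline', ']']

# A whitespace-delimited chunk that exactly matches a never_split entry is kept whole.
print(BasicTokenizer(never_split=["[outline]"]).tokenize("[outline]"))
# expected: ['[outline]']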

@lllyyyqqq
Author

> In BertTokenizer, '[' and ']' are punctuation, so they are split off first. And if outline or [outline] is not in the vocab, it will be set to [UNK], so it doesn't seem to make sense anymore. Look at the code: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/tokenization_bert.py#L446

Thanks for replying. As stated before, I am using my own vocab, and '[outline]' is in it:

tokenizer = BertTokenizer.from_pretrained(my_vocab_path, never_split='[outline]')
print(tokenizer.convert_tokens_to_ids('[outline]'))
print(tokenizer.convert_tokens_to_ids('。'))
print(tokenizer.tokenize('。[outline]'))


@ArthurZucker
Collaborator

Hey, reading the doc for the BertTokenizer, you should be using the do_basic_tokenize=True argument, as mentioned here.

@lllyyyqqq
Author

> Hey, reading the doc for the BertTokenizer, you should be using the do_basic_tokenize=True argument, as mentioned here.

Your link is broken, it says '404 - page not found'?
Plus, do_basic_tokenize=True is the default setting. Even if I pass it explicitly, the result stays the same.

tokenizer = BertTokenizer.from_pretrained(my_vocab_path, never_split=['[outline]'], do_basic_tokenize=True)
print(tokenizer.tokenize('。[outline]')) # ['。', '[', 'out', '##line', ']']

Correct me if I do anything wrong.

@ArthurZucker
Collaborator

Sorry; anyway, the argument is set to True by default, so that's not the problem.
Let me investigate. In the meantime, tokenizer.add_tokens("[outline]", special_tokens=True) should (I think) prevent it from being split.

@ArthurZucker
Collaborator

(The doc mentions:

 never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.

@ArthurZucker
Copy link
Collaborator

The best solution is to add the token to the list of special tokens using the add_tokens method (see the sketch below).
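A minimal sketch of that suggestion, using add_tokens with special_tokens=True and bert-base-uncased as a stand-in for the custom vocab (the expected output is an assumption, not taken from the thread):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Register the token as a special token so the tokenizer never splits it.
tokenizer.add_tokens(["[outline]"], special_tokens=True)
print(tokenizer.tokenize("。[outline]"))
# expected: ['。', '[outline]']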

@lllyyyqqq
Author

Yeah, adding it as a special token does take care of the splitting problem. But later in my pipeline I decode with skip_special_tokens=True, and then the token gets skipped, which I don't want. For now, I add it to the special token list, but I still suggest fixing the never_split argument.

@ArthurZucker
Collaborator

Then that means that the token you want to add is not special. I think that if you add it without special_tokens set to True it should not be split, no?

@lllyyyqqq
Author

Without special_tokens set to True, it still gets split.

@ArthurZucker
Collaborator

ArthurZucker commented May 25, 2023

No, it won't:

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
tokenizer.add_tokens("[outline]")
tokenizer.added_tokens_encoder
>>>  {'[outline]': 30522}

tokenizer.encode("[outline]")
>>> [101, 30522, 102]
tokenizer.decode(tokenizer.encode("[outline]"))
>>> '[CLS] [outline] [SEP]'
print(tokenizer.tokenize(". [outline]"))
>>> ['.', '[outline]']

tokenizer.decode(tokenizer.encode(". [outline]"), skip_special_tokens=True)
>>> '. [outline]'

@lllyyyqqq
Author

In your case, it won't. But I am using a different vocab.txt, and there it splits.

@lllyyyqqq
Author

Seems like '[outline]' will not be added anyway, since it's already in the vocab.

@ArthurZucker
Collaborator

I don't understand. You have a very specific usage, where you don't want to split [outline] that is already in your vocab.
The basic tokenizer works as expected: tokenizer.basic_tokenizer.tokenize("[outline]") will not split it.
When you call tokenize on the BertTokenizer class, the _tokenize function is called, which relies on all_special_ids. That means the token should be added to both lists.

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False, never_split= ["outline"])
tokenizer.add_tokens("[outline]")

I am guessing that this should work

@ArthurZucker
Collaborator

ArthurZucker commented May 25, 2023

Edit: I manually added "[outline]" to my vocab and it worked for both of the solutions I gave you.

@lllyyyqqq
Author

Unfortunately, it still doesn't work with my vocab. I think it is strictly related to the vocab. So far, only adding it to the special tokens works for me.
Also, what I posted before shows that the basic tokenizer does split it: tokenizer.basic_tokenizer.tokenize("[outline]") splits it into '[', 'outline', ']'. The tokenizer then sends the split pieces to WordPiece instead of restoring the original '[outline]'. I think that may be the reason.
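A minimal sketch of the two-stage pipeline described above, using bert-base-uncased as a stand-in for the custom vocab (the stage-2 output depends on the vocabulary, so only the mechanism is illustrated):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", never_split=["[outline]"])

# Stage 1: BasicTokenizer. The whitespace chunk "。[outline]" as a whole is not in
# never_split, so it is split on the punctuation characters.
pieces = tokenizer.basic_tokenizer.tokenize("。[outline]")
print(pieces)  # expected: ['。', '[', 'outline', ']']

# Stage 2: each piece goes through WordPiece independently, so the original
# "[outline]" string is never reassembled; with the reporter's vocab, 'outline'
# is further broken into 'out', '##line'.
print([wp for piece in pieces for wp in tokenizer.wordpiece_tokenizer.tokenize(piece)])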

@lllyyyqqq
Author

vocab.txt
Here is my vocab; you can try it out.

@ArthurZucker
Collaborator

I tried loading a tokenizer using your vocabulary and I cannot reproduce your issue.
Try installing the latest transformers version!

@lllyyyqqq
Author

Why......
I've updated transformers to 4.29.2, and I still get the same result....
Here is my code:

tokenizer = BertTokenizer.from_pretrained('../base_model/vocab.txt', never_split= ["[outline]"])
tokenizer.add_tokens("[outline]")
print(tokenizer.tokenize("。[outline]"))
# ['。', '[', 'out', '##line', ']']

@ArthurZucker
Collaborator

Can you try tokenizer = BertTokenizer.from_pretrained('../base_model', never_split=["[outline]"])?
Also, I would suggest you create a Colab; this will make sure that your cache is not interfering.

@lllyyyqqq
Author

Here is the Colab result:
[screenshot]

@ArthurZucker
Collaborator

Can you share a link to the Colab? I'll try to reproduce and modify a copy 😉

@ArthurZucker
Collaborator

Also, you did not add the token using add_tokens(..., special_tokens=False).

@ArthurZucker
Collaborator

Another solution is to initialise the tokenizer using ...from_pretrained(path, additional_special_tokens=["[outline]"]).

@lllyyyqqq
Author

https://colab.research.google.com/drive/1EStD5K_lQM0-PgMUQ8z273TzAUcgY2IV?usp=sharing
You need to scroll down to the bottom to see the code; add_tokens(...) is already included.

additional_special_tokens adds [outline] to the special tokens too, so it works fine, but it still hits the skip_special_tokens problem (see the sketch below). Anyway, this issue is about the never_split argument not working, so let's focus on that.
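A minimal sketch of that trade-off, again with bert-base-uncased standing in for the custom vocab (expected outputs, not taken from the thread):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(
    "bert-base-uncased", additional_special_tokens=["[outline]"]
)
print(tokenizer.tokenize("。[outline]"))
# expected: ['。', '[outline]'], so the token is no longer split

ids = tokenizer.encode("。[outline]")
print(tokenizer.decode(ids, skip_special_tokens=True))
# expected: '。', because "[outline]" is dropped along with [CLS] and [SEP]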

@ArthurZucker
Collaborator

ArthurZucker commented May 25, 2023

Thanks a lot.
Indeed the token is not added, because a check in _add_tokens prevents it from being added if it is already in the vocab.
Workaround:

# Manually register "[outline]" as an added token (85 is the id used for it here)
tokenizer.added_tokens_encoder.update({"[outline]": 85})
tokenizer.added_tokens_decoder.update({85: "[outline]"})
# Mark it as a no-split token and rebuild the trie used to match such tokens
tokenizer.unique_no_split_tokens = sorted(set(tokenizer.unique_no_split_tokens).union({"[outline]"}))
tokenizer._create_trie(tokenizer.unique_no_split_tokens)
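Continuing the snippet above, tokenization is then expected to keep the token whole (a hedged expectation, not output shown in the thread):

print(tokenizer.tokenize("。[outline]"))
# expected: ['。', '[outline]']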

It is not really elegant, indeed. Also, adding a token means that, whether or not it is in the vocab, we want it to be in the added tokens, so I think it makes sense to add it even if it already exists. WDYT @Narsil?
Edit: I think it comes down to a choice, and both options have pros and cons.

@ArthurZucker
Collaborator

About never_split: the last commit is 4 years old, it has never been touched, and I'd rather we find a way to work around your problem using new code than change legacy code!

@lllyyyqqq
Author

Glad we are on the same page in the end.

ArthurZucker added the Core: Tokenization label on May 26, 2023
@ArthurZucker
Collaborator

I am not entirely sure yet whether we will support this, as the fast tokenizers don't, and my few tests suggest it might not be optimal.

@ArthurZucker
Collaborator

For now, closing as wontfix; if more people require such usage, we will make it available.
TL;DR: the request is to add a token as an AddedToken when it is already in the vocab.
Will not fix because:

  • fast does not support this and we want to keep fast stable
  • it's a breaking change
  • seems to be a specific usage, not the one intended for add_tokens

@Hir98

Hir98 commented Oct 2, 2023

@lllyyyqqq @ArthurZucker
I am also facing a similar kind of issue: I have a product name, 'Blueberry',
but when it is tokenized it gets split into two tokens, 'blue' and '##berry'.

I don't want the word split into multiple tokens with AutoTokenizer or BertTokenizer.

Can you please help me with how to do this?

@ArthurZucker
Collaborator

You should add Blueberry using tokenizer.add_tokens("Blueberry"); this is now supported by the fast tokenizer (see the sketch below).
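A minimal sketch of that suggestion, assuming bert-base-uncased (an uncased checkpoint, so the lowercase form is used to sidestep case normalization; expected outputs shown as comments):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.tokenize("blueberry"))
# expected before adding: ['blue', '##berry']

# Add the product name as a regular (non-special) token so it is never split.
tokenizer.add_tokens(["blueberry"])
print(tokenizer.tokenize("blueberry"))
# expected after adding: ['blueberry']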

@Hir98

Hir98 commented Oct 2, 2023

@ArthurZucker Is this supported in BertTokenizer?

@ArthurZucker
Collaborator

ArthurZucker commented Oct 16, 2023

Actually yes!
Previously I said the following:

  • fast does not support this and we want to keep fast stable
  • it's a breaking change
  • seems to be a specific usage, not the one intended for add_tokens

but this is now possible 😉 see the following:

>>> from transformers import BertTokenizer
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> tokenizer.add_tokens("[outline]")
>>> print(tokenizer.tokenize("。[outline]"))
['。', '[outline]']

I tested with both fast and slow and it worked as expected 🤗
