Tokenizer.save_pretrained fails when add_special_tokens=True|False #28472

Closed
2 of 4 tasks
shuttie opened this issue Jan 12, 2024 · 4 comments

Comments

@shuttie

shuttie commented Jan 12, 2024

System Info

transformers-4.34
python-3.11

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", add_special_tokens=True)
tok.save_pretrained("out")

The snippet:

  • works on 4.33 and below whether add_special_tokens= is absent or present (True or False)
  • works on 4.34+ when add_special_tokens= is not passed in the tokenizer parameters
  • fails on 4.34+ when add_special_tokens= is passed (with either True or False), with the following error:
Traceback (most recent call last):
  File "/home/shutty/private/code/savepbug/test.py", line 4, in <module>
    tok.save_pretrained("tokenz")
  File "/home/shutty/private/code/savepbug/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 2435, in save_pretrained
    out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
          ^^^^^^^^^^^
  File "/usr/lib/python3.11/json/encoder.py", line 202, in encode
    chunks = list(chunks)
             ^^^^^^^^^^^^
  File "/usr/lib/python3.11/json/encoder.py", line 432, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/usr/lib/python3.11/json/encoder.py", line 406, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.11/json/encoder.py", line 439, in _iterencode
    o = _default(o)
        ^^^^^^^^^^^
  File "/usr/lib/python3.11/json/encoder.py", line 180, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type method is not JSON serializable

The issue happens with any tokenizer, not only the Llama one. I can confirm it fails the same way on bert-base-uncased.

If you go to tokenization_utils_base and dump tokenizer_config just before the json.dumps call, you can see that add_special_tokens has surprisingly become a method, not a bool:

{'clean_up_tokenization_spaces': False, 'unk_token': '<unk>', 'bos_token': '<s>', 'eos_token': '</s>', 'add_bos_token': True, 
'add_eos_token': False, 'use_default_system_prompt': False, 'additional_special_tokens': [], 'legacy': True, 
'model_max_length': 1000000000000000019884624838656, 'pad_token': None, 'sp_model_kwargs': {}, 
'spaces_between_special_tokens': False, 
'add_special_tokens': <bound method SpecialTokensMixin.add_special_tokens of LlamaTokenizerFast(name_or_path='mistralai/Mistral-7B-v0.1', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, 
padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
        0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}>, 'added_tokens_decoder': {0: {'content': '<unk>', 'single_word': False, 'lstrip': False, 'rstrip': False, 
'normalized': False, 'special': True}, 1: {'content': '<s>', 'single_word': False, 'lstrip': False, 'rstrip': False,
 'normalized': False, 'special': True}, 2: {'content': '</s>', 'single_word': False, 'lstrip': False, 'rstrip': False,
 'normalized': False, 'special': True}}, 'tokenizer_class': 'LlamaTokenizer'}
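
To spot this kind of entry without reading the whole dump by eye, a small helper can try to serialize each value on its own and report the keys that fail. This is a minimal sketch using only the standard library; find_unserializable, Dummy and sample_config are hypothetical names made up for illustration and are not part of transformers.

```python
import json


def find_unserializable(config):
    """Return {key: type name} for every value that json.dumps rejects."""
    bad = {}
    for key, value in config.items():
        try:
            json.dumps(value)
        except TypeError:
            bad[key] = type(value).__name__
    return bad


class Dummy:
    # Stand-in for a tokenizer class that defines a method with this name.
    def add_special_tokens(self):
        pass


# Hypothetical, shortened config imitating the shape of the dump above.
sample_config = {
    "unk_token": "<unk>",
    "add_bos_token": True,
    "add_special_tokens": Dummy().add_special_tokens,  # bound method, not a bool
}

offenders = find_unserializable(sample_config)
print(offenders)
```

Calling the same helper on the real tokenizer_config just before json.dumps should point at the offending key directly.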

My feeling is that the issue is related to PR #23909, which refactored a lot of tokenizer internals, so in the current version:

  • add_special_tokens is a part of kwargs passed to the tokenizer
  • there is also a method SpecialTokensMixin.add_special_tokens having the same name
  • when everything is merged together before json.dumps, the method is picked up for serialization instead of the kwargs value.
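
The name clash described in the list above can be reproduced with a toy class and no transformers install. This is only a sketch of the suspected mechanism: it assumes the config is built by looking up each init kwarg name as an attribute on the tokenizer, which I have not confirmed line by line against the library source. ToyTokenizer and build_config are made-up names.

```python
import json


class ToyTokenizer:
    """Toy stand-in; NOT the real transformers implementation."""

    def __init__(self, **kwargs):
        # Remember the kwargs the object was created with.
        self.init_kwargs = dict(kwargs)

    def add_special_tokens(self, special_tokens_dict):
        # Method that happens to share its name with a kwarg.
        return 0

    def build_config(self):
        # Assumed behaviour: refresh each saved kwarg from the attribute
        # of the same name, if such an attribute exists.
        config = dict(self.init_kwargs)
        for key in self.init_kwargs:
            if hasattr(self, key):
                config[key] = getattr(self, key)  # picks up the bound method
        return config


tok = ToyTokenizer(add_special_tokens=True, legacy=True)
config = tok.build_config()

try:
    json.dumps(config)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

In this toy, legacy keeps its bool value because no attribute shadows it, while add_special_tokens is replaced by the bound method and json.dumps raises the same kind of TypeError as in the traceback.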

Expected behavior

Not crashing with TypeError: Object of type method is not JSON serializable, as it was before #23909 in 4.33.
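
Until that is the case, one possible stopgap is to drop the clashing kwarg from the tokenizer's stored init kwargs before saving. The sketch below is an assumption, not a verified fix: it relies on the tokenizer exposing an init_kwargs dict and on save_pretrained reading from it, and it is only exercised here against a stand-in object. drop_shadowed_init_kwargs and FakeTokenizer are hypothetical names.

```python
def drop_shadowed_init_kwargs(tokenizer):
    """Remove init kwargs whose name resolves to a callable attribute.

    Returns the list of removed keys so the caller can log them.
    """
    removed = []
    for key in list(tokenizer.init_kwargs):
        if callable(getattr(tokenizer, key, None)):
            del tokenizer.init_kwargs[key]
            removed.append(key)
    return removed


class FakeTokenizer:
    # Stand-in object used only to exercise the helper.
    def __init__(self):
        self.init_kwargs = {"add_special_tokens": True, "legacy": True}

    def add_special_tokens(self, special_tokens_dict):
        return 0


fake = FakeTokenizer()
removed_keys = drop_shadowed_init_kwargs(fake)
print(removed_keys, fake.init_kwargs)
```

With a real tokenizer the idea would be to call the helper right before tok.save_pretrained("out"). Not passing add_special_tokens= to from_pretrained in the first place avoids the problem as well, as noted in the reproduction.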

@ArthurZucker
Collaborator

It is indeed related to that PR, but it is also related to the fact that add_special_tokens, even if saved, is not used when doing an encode pass. Thus it's better to error out than to save it IMO, as it won't be checked when encoding. I'll have a look at the PR.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Vasishta

Could you please let me know how to overcome this issue, or whether it is addressed in any of the latest releases of the transformers library? This also seems to affect save_checkpoints in the Trainer of the SFTTrainer, causing it to fail.

@ArthurZucker
Collaborator

Hey, a PR was opened, see #31233.
