[LLaMA3] 'add_bos_token=True, add_eos_token=True' seems not taking effect #30947
Comments
I'm having the same issue. Neither of these changes the encodings.
Hey! This is related to #30607: the tokenizer for Llama3 is a fast tokenizer, so that's something which should be handled on the post-processor side.
@ArthurZucker I think it can be set directly via the post-processor:

from tokenizers import processors

bos = "<|begin_of_text|>"
eos = "<|end_of_text|>"
tokenizer._tokenizer.post_processor = processors.Sequence(
    [
        processors.ByteLevel(trim_offsets=False),
        processors.TemplateProcessing(
            single=f"{bos}:0 $A:0 {eos}:0",
            pair=f"{bos}:0 $A:0 {bos}:1 $B:1 {eos}:1",
            special_tokens=[
                (bos, tokenizer.bos_token_id),
                (eos, tokenizer.eos_token_id),
            ],
        ),
    ]
)

Now I'm worried that the padding tokens won't get added properly, but that's a different issue...
The padding token is unrelated; it's added if you ask the tokenizer to pad the input!
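To illustrate the distinction the two comments above are making, here is a minimal pure-Python sketch (a mock, not the real transformers API): BOS/EOS insertion is a post-processing step applied to every encoding, while padding is a separate step that only runs when explicitly requested. The token ids are Llama-3-style values taken from this thread; the PAD id is an assumption for illustration.

```python
# Illustrative mock: BOS/EOS wrapping (post-processing) vs. padding are
# independent steps. Ids 128000/128001 come from the thread; PAD is assumed.
BOS, EOS, PAD = 128000, 128001, 128255

def post_process(ids, add_bos=True, add_eos=False):
    """Mimics what TemplateProcessing does: wrap the ids with BOS/EOS."""
    out = list(ids)
    if add_bos:
        out.insert(0, BOS)
    if add_eos:
        out.append(EOS)
    return out

def pad(ids, length):
    """Padding is applied after post-processing, and only on request."""
    return ids + [PAD] * (length - len(ids))

ids = post_process([6151, 11], add_bos=True, add_eos=True)
print(ids)          # [128000, 6151, 11, 128001]
print(pad(ids, 6))  # [128000, 6151, 11, 128001, 128255, 128255]
```

The point is that fixing the post-processor (as in the snippet above) does not interfere with padding, which is controlled by a separate code path.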
In case anyone else is blocked by this issue, I copied code from #31316 into a function which patches the tokenizer to support dynamically setting add_bos_token and add_eos_token. Running this script:

from transformers import AutoTokenizer

model_id = "yujiepan/llama-3.1-tiny-random"
text = "a b"

print("Load plain tokenizer\n")
tokenizer = AutoTokenizer.from_pretrained(model_id)
print("   Default:", tokenizer(text)["input_ids"])
tokenizer.add_eos_token = True
print("   Add EOS:", tokenizer(text)["input_ids"])

print("\nLoad and patch tokenizer\n")
tokenizer2 = AutoTokenizer.from_pretrained(model_id)
force_support(tokenizer2)
tokenizer2.add_eos_token = True
print("   Add EOS:", tokenizer2(text)["input_ids"])
tokenizer2.add_eos_token = False
print(" Don't add:", tokenizer2(text)["input_ids"])

shows that only the patched tokenizer responds to add_eos_token.
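The actual patch lives in #31316; as a rough illustration of the pattern such a patch follows, here is a hypothetical sketch on a stand-in class (not the real transformers API): turn add_eos_token into a property whose setter rebuilds the tokenizer's behavior, so assignment after loading actually takes effect. The class name, toy tokenization, and ids are all assumptions for the demo.

```python
# Hypothetical sketch of the "force_support" idea: make add_eos_token a
# property whose setter changes encoding behavior. FakeFastTokenizer is a
# stand-in; a real patch would rebuild the fast tokenizer's post-processor.
class FakeFastTokenizer:
    bos_token_id, eos_token_id = 128000, 128001

    def __init__(self):
        self._add_eos = False

    @property
    def add_eos_token(self):
        return self._add_eos

    @add_eos_token.setter
    def add_eos_token(self, value):
        # A real patch would call processors.TemplateProcessing(...) here.
        self._add_eos = bool(value)

    def __call__(self, text):
        ids = [ord(c) for c in text if c != " "]  # toy "tokenization"
        ids = [self.bos_token_id] + ids
        if self._add_eos:
            ids.append(self.eos_token_id)
        return {"input_ids": ids}

tok = FakeFastTokenizer()
print(tok("a b")["input_ids"])  # [128000, 97, 98]
tok.add_eos_token = True
print(tok("a b")["input_ids"])  # [128000, 97, 98, 128001]
```

Without such a property, assigning tokenizer.add_eos_token = True on a fast tokenizer merely sets a plain attribute that nothing reads, which is exactly the behavior the unpatched half of the script demonstrates.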
System Info
Platform = Windows
PyTorch = 2.3.0
Transformers = 4.41.0
Who can help?
No response
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)

Reproduction
All of the statements above produce
[128000, 6151, 11, 1268, 527, 499, 3432, 30]
Expected behavior
I think when using

tokenizer = AutoTokenizer.from_pretrained(LLaMAPath, add_bos_token=True, add_eos_token=True)

we should get [128000, 6151, 11, 1268, 527, 499, 3432, 30, 128001], and when using

tokenizer = AutoTokenizer.from_pretrained(LLaMAPath, add_bos_token=False, add_eos_token=False)

we should get [6151, 11, 1268, 527, 499, 3432, 30].
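The expected encodings above differ from the reported output only by the special token ids. A quick sanity check in plain Python, using only the ids reported in this issue (128000 is <|begin_of_text|>, 128001 is <|end_of_text|>):

```python
# Ids taken from the issue report: the base encoding without special tokens,
# plus the Llama-3 BOS/EOS ids.
BOS, EOS = 128000, 128001
base = [6151, 11, 1268, 527, 499, 3432, 30]

# Expected with add_bos_token=True, add_eos_token=True:
with_both = [BOS] + base + [EOS]
# Expected with add_bos_token=False, add_eos_token=False:
without = base

print(with_both)  # [128000, 6151, 11, 1268, 527, 499, 3432, 30, 128001]
print(without)    # [6151, 11, 1268, 527, 499, 3432, 30]
```

The bug report boils down to the fact that the actual output is always [BOS] + base, regardless of the add_bos_token/add_eos_token arguments.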