[gpt] Gpt2 fix half precision causal mask #23256

Merged: 4 commits into huggingface:main from gpt2-fix-inferencnce, May 11, 2023

Conversation

younesbelkada (Contributor)

What does this PR do?

Applies a fix similar to #23136, but for GPT2.

To reproduce:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Loading in 8-bit uses low_cpu_mem_usage, so parameters are force-cast to torch.float16
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", load_in_8bit=True)
inputs = torch.LongTensor([[1, 1, 1], [1, 2, 1]]).to(0)

print(model(inputs))

The explanation is the same as in the linked PR:

With low_cpu_mem_usage, each parameter is force-cast to the expected dtype, which is set to torch.float16 for 8-bit models.

Therefore, for 8-bit models (and half-precision models in general) the causal mask is always force-cast to float16, because it is part of the model's state dict and is therefore loaded from the Hub whenever the mask is present in the checkpoint.

The fix is to register the causal mask buffer with persistent=False and to add the field _keys_to_ignore_on_load_unexpected (to silence the resulting warnings), so that the causal mask is never loaded from the state dict and assigned to the buffer. All causal masks that are saved as buffers should follow the same pattern to avoid unexpected behavior.
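
Below is a minimal sketch of the pattern described above (a paraphrase for illustration, not the actual diff in this PR; the toy class names are hypothetical and the regex patterns are taken from the decision transformer hunk discussed below):

import torch
import torch.nn as nn

class ToyCausalAttention(nn.Module):
    def __init__(self, max_positions: int = 1024):
        super().__init__()
        # Non-persistent buffer: it is rebuilt at init time, never serialized into
        # the checkpoint, and never loaded from it, so it keeps its original dtype
        # even when the weights are force-cast to float16.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

class ToyPreTrainedModel:
    # Older checkpoints on the Hub may still contain these buffers; listing them here
    # silences the "unexpected keys" warnings when such a checkpoint is loaded.
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]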

Some users reported that they could also reproduce this on the PyTorch main branch without load_in_8bit; I didn't manage to reproduce it that way and will take a deeper look.

cc @amyeroberts

@younesbelkada younesbelkada changed the title from "[gpt] Gpt2 fix 8bit inference" to "[gpt] Gpt2 fix half precision causal mask" on May 10, 2023
@younesbelkada younesbelkada requested a review from amyeroberts May 10, 2023 11:20
@HuggingFaceDocBuilderDev commented May 10, 2023

The documentation is not available anymore as the PR was closed or merged.

@amyeroberts amyeroberts (Collaborator) left a comment

Thanks for fixing! I just have a question about _keys_to_ignore_on_load_missing for the decision transformer.

@@ -746,7 +747,8 @@ class DecisionTransformerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "decision_transformer"
     main_input_name = "states"
     supports_gradient_checkpointing = False
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
amyeroberts (Collaborator)

Should r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias" be in _keys_to_ignore_on_load_missing?

@younesbelkada younesbelkada (Contributor, Author) commented May 10, 2023

Ah, you are right, it should probably be in _keys_to_ignore_on_load_unexpected only; will modify that!
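
For reference, a hedged sketch of the distinction being discussed (toy class name, not the actual transformers code): _keys_to_ignore_on_load_missing suppresses warnings for keys the model expects but the checkpoint lacks, while _keys_to_ignore_on_load_unexpected suppresses warnings for keys the checkpoint contains but the model no longer loads. Since the attention buffers are now non-persistent, they can only show up as unexpected keys when loading an older checkpoint:

class ToyDecisionTransformerPreTrainedModel:
    # Keys the model defines but an old checkpoint may not contain.
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    # Keys an old checkpoint may contain but that are no longer loaded,
    # since the attention buffers are registered with persistent=False.
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]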

@younesbelkada younesbelkada requested a review from amyeroberts May 10, 2023 16:24
@amyeroberts amyeroberts (Collaborator) left a comment

LGTM! Thanks for fixing :)

@younesbelkada younesbelkada merged commit ca26699 into huggingface:main May 11, 2023
@younesbelkada younesbelkada deleted the gpt2-fix-inferencnce branch May 11, 2023 07:32
sheonhan pushed a commit to sheonhan/transformers that referenced this pull request May 15, 2023
* fix gpt2 inference

* fixup

* no need to be in `_keys_to_ignore_on_load_missing`
gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* fix gpt2 inference

* fixup

* no need to be in `_keys_to_ignore_on_load_missing`
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* fix gpt2 inference

* fixup

* no need to be in `_keys_to_ignore_on_load_missing`