-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add warning message for beta and gamma parameters #31654
Add warning message for beta and gamma parameters #31654
Conversation
src/transformers/modeling_utils.py
Outdated
if "beta" in loaded_keys or "gamma" in loaded_keys: | ||
logger.warning( | ||
f"Parameter names `gamma` or `beta` for {cls.__name__} will be renamed within the model. " | ||
f"Please use different names to suppress this warning." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't this is quite right, this assumes the weight is called "beta" in the state dict, but it could be called "layer.beta"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @OmarManzoor,
Thanks for addressing this! We want to make sure we catch any place where the renaming happens, so any place where if gamma in key
and if beta in key
are True (so key can be a longer string that contains beta or gamma). As you've added, this would be in _load_pretrained_model
but also in _load_state_dict_into_model
Hi @amyeroberts |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @OmarManzoor, thanks for iterating on this!
Given the diff, I'm slightly confused, were there no warnings being triggered before? It seems like they were from the tests and logging messages
warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally" | ||
model = TestModelGamma(config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More importantly, we should check that the parameter is renamed as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried this out and it seems that the parameter is not renamed at all. Basically when we load the model using from_pretrained it seems that the parameter is still present with the name gamma_param
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It shouldn't rename the value in the model, but will rename the value in the state_dict, I believe. Could you dive into the loading logic and verify what's happening?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried updating the tests. Could you kindly have a look?
I basically removed the warning code that I added in the post init method. Should that be kept? |
@OmarManzoor Ah, OK. I think the diff was rendering funny on github. Should be OK. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great - thanks for adding and iterating on this!
Thank you. |
* Add warning message for and parameters * Fix when the warning is raised * Formatting changes * Improve testing and remove duplicated warning from _fix_key
* Add warning message for and parameters * Fix when the warning is raised * Formatting changes * Improve testing and remove duplicated warning from _fix_key
* Add warning message for and parameters * Fix when the warning is raised * Formatting changes * Improve testing and remove duplicated warning from _fix_key
Why have you added warnings only for the initialization process and not for renaming during loading as well? The model I'm using is timm's convnext (which is even the companion framework to transformers), which would have the parameter gamma. When loading he just tells me that I didn't successfully load the gamma function without telling me why, and I think the user should be informed when renaming the state_dict, otherwise it will cause unnecessary confusion. |
What does this PR do?
This adds a warning message to notify about the renaming of
gamma
andbeta
parameters during initialisation and also during loading.Fixes #29554
Before submitting
Pull Request section?
to it if that's the case.
Who can review?
@amyeroberts