Correct loading of models with shared tensors when using accelerator.load_state() #2875
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks! Just one question
```diff
-        models[i].load_state_dict(state_dict, **load_model_func_kwargs)
+            model.load_state_dict(state_dict, **load_model_func_kwargs)
     logger.info("All model weights loaded successfully")
```
Any particular reason for this change? I'd expect only the previous line to be modified.
In the if statement, he's loading the safetensors model directly, whereas before we were only getting the state dict.
load_model does both: it loads the file and uses it to populate the state dict. Previously, each branch of the if-condition only loaded the file, and after the if-condition the model would load the state dict. Since load_model does both, I indented the statement on line 204 to become part of the else-clause. This becomes clearer when you look at the complete surroundings of the changes instead of only the affected lines.
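For context, here is a sketch of the restructured logic (the exact surrounding code in accelerate's checkpointing module may differ, and `input_model_file`/`map_location` are assumed names):

```python
import torch
from safetensors.torch import load_model

if input_model_file.endswith(".safetensors"):
    # load_model reads the file and populates the model in one step,
    # re-tying any shared tensors that safetensors omitted on save
    load_model(model, input_model_file, **load_model_func_kwargs)
else:
    # the non-safetensors branch still loads a raw state dict first
    state_dict = torch.load(input_model_file, map_location=map_location)
    model.load_state_dict(state_dict, **load_model_func_kwargs)
```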
The other change in this line (aside from the indent), namely using `model` instead of `models[i]`, is mostly cosmetic. My linter was complaining that the `enumerate` call defines `model` but it's never used.
Thanks for the change and for spotting the issue! Could you add a test with a model with tied weights? You can use the following test for reference: `test_save_load_model`.
Yes, I'll look into it.
You can verify that the shared weights are set up correctly by checking the output; safetensors warns you about that.
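A rough sketch of what such a test could look like (names are hypothetical; the actual test added to tests/test_accelerator.py may differ):

```python
import tempfile

import torch
from accelerate import Accelerator

class ModelWithTiedWeights(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear2.weight = self.linear1.weight  # tie the weights

def test_save_load_model_with_tied_weights():
    accelerator = Accelerator()
    model = accelerator.prepare(ModelWithTiedWeights())
    with tempfile.TemporaryDirectory() as tmpdir:
        accelerator.save_state(tmpdir)
        # without the fix, this raised a missing-key error for the
        # shared tensor that safetensors deduplicated on save
        accelerator.load_state(tmpdir)
        assert model.linear1.weight is model.linear2.weight  # still tied
```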
Thanks for iterating and adding the tests @jkuntzer! Could you do a final check and see if the test that you added fails when you remove the changes you made?
tests/test_accelerator.py (Outdated)
```python
# need to add this for compliance with other methods
self.weight = self.linear1.weight
self.bias = self.linear1.bias
```
Do we really need that? Where does it fail?
It used to fail at an earlier stage, but you're right: this part can now be safely removed.
I was only expecting `linear2.weight` and `linear2.bias` to be missing. Maybe this is due to `self.weight = self.linear1.weight` and `self.bias = self.linear1.bias`.
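For illustration, a small sketch (not the PR's actual code) of why more keys than expected can go missing: each alias registers another parameter backed by the same storage, and safetensors keeps only one key per shared group when saving.

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 2)
        # aliases register extra parameters sharing the layer's storage
        self.weight = self.linear1.weight
        self.bias = self.linear1.bias

# "weight"/"linear1.weight" and "bias"/"linear1.bias" point to the same
# storage, so safetensors deduplicates each pair down to a single key
print({k: v.data_ptr() for k, v in M().state_dict().items()})
```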
Nice! Could you just fix the quality issue?
Thanks!
What does this PR do?
I kept running into problems with PyTorch's load_state_dict complaining about missing keys. These keys belonged to shared tensors, which the safetensors library intentionally omits when saving. To load such a model correctly, one has to use safetensors' load_model function instead of the default load_state_dict function (described here). This was previously not done when using the load_state function of the Accelerator.
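A minimal reproduction of the failure mode (a sketch; the module and file names are illustrative):

```python
import torch
from safetensors.torch import load_file, load_model, save_model

class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear2.weight = self.linear1.weight  # shared tensor

save_model(TiedModel(), "model.safetensors")  # shared tensors deduplicated on save

# Plain load_state_dict on the saved file fails, because one of the
# tied keys was omitted from the file:
try:
    TiedModel().load_state_dict(load_file("model.safetensors"))
except RuntimeError as e:
    print(e)  # Missing key(s) in state_dict

# load_model handles the omitted shared tensors correctly:
load_model(TiedModel(), "model.safetensors")
```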
Fixes # (issue)
I think #2155 might be relevant, as it also reports problems when loading with accelerator.load_state.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.