-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix sharding when no device_map is passed #8531
Fix sharding when no device_map is passed #8531
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There is still a path where sharding is not handled. It happens when |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank you!
very nice tests too:)
is it possible to explain device_map=None
in the doc string for device_map
too?
Done ! |
@@ -872,6 +872,39 @@ def test_model_parallelism(self): | |||
|
|||
@require_torch_gpu | |||
def test_sharded_checkpoints(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test is already here:
def test_sharded_checkpoints(self): |
Is it different?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
he renamed this test to test_sharded_checkpoints_device_map
because in that test it loads with device_map='auto'
flag; this is a new test testing default value for device_map
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for explaining.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I renamed the tests since it makes more sense this way
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much, Marc. I think there's some confusion in the tests as they are existing in the main
already. Am I missing out on something?
Alright then! Let’s merge this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
* Fix sharding when no device_map is passed * style * add tests * align * add docstring * format --------- Co-authored-by: Sayak Paul <[email protected]>
* Fix sharding when no device_map is passed * style * add tests * align * add docstring * format --------- Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
This PR fixes the loading for sharded checkpoint when no
device_map
is passed. Currently, the following doesn't work:You can have more details here.