WarmupDecayLR.params.total_num_steps - total or per gpu? #633
Comments
I find it easier to think of training steps in terms of effective batch size. So while this implies world size, it also incorporates gradient accumulation steps. I guess another way to say it is training steps == optimizer steps. Does this help? I will update the doc shortly.
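To make that concrete, here is a minimal sketch of the "training steps == optimizer steps" arithmetic described above; all names and numbers (per_gpu_batch_size, num_examples, etc.) are illustrative placeholders, not DeepSpeed API:

```python
import math

# Illustrative placeholder values, not from the thread.
per_gpu_batch_size = 8
gradient_accumulation_steps = 4
world_size = 2            # number of GPUs / processes
num_epochs = 3
num_examples = 100_000    # i.e. len(train_dataset)

# The effective batch consumed by one optimizer step spans all GPUs
# and all gradient-accumulation micro-batches.
effective_batch_size = per_gpu_batch_size * gradient_accumulation_steps * world_size

# Optimizer steps over the whole run, which is what a "total steps"
# count in this sense refers to.
total_num_steps = math.ceil(num_examples / effective_batch_size) * num_epochs
print(total_num_steps)
```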
I'm not sure it helps. While it can be overridden by users, the modified question with this context is: do we need to divide this by the number of gpus or not? My guess is that since you said "effective batch size" then yes, it should be divided by n_gpus, correct? I don't think we have that n_gpus/world_size at ds init time though, do we? It may change through the launch args. We are dealing with a chicken-and-egg problem here, as we need to know the world size to get this right, but it's not set up until ds has been initialized.
Besides the chicken-and-egg problem, I am a bit confused about that code snippet for computing training steps. It seems to be independent of effective batch size, or am I missing something? Also, is len(train_dataloader) per gpu or for the world size?
You're correct, I haven't supplied enough information to allow you to support me. My apologies, @tjruwase. Here is the full code:
and the batch is as above, but this gives me the full dataset and not the exact slice derived from the distributed sampler, so I had to remove it. The original had:
But I can't invoke the distributed sampler, since it's too early in the game; I get:
So it seems to be a design issue in DS: it should have set up the distributed environment first. We actually don't want to do that calculation I pasted above, since it's too early and we would have to re-do it once dist is set up, but we can't see another way. Actually, this won't be a problem at all if we could initialize the distributed backend ourselves before deepspeed.initialize.
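For context on why the sampler can't be created that early: torch's DistributedSampler, with its default num_replicas=None, asks torch.distributed for the world size, which fails if the process group has not been initialized yet. A minimal sketch (the train_dataset here is a hypothetical stand-in):

```python
import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

train_dataset = TensorDataset(torch.arange(100))  # hypothetical stand-in dataset

# With num_replicas=None, DistributedSampler calls dist.get_world_size(),
# which raises a RuntimeError before init_process_group /
# deepspeed.init_distributed has been run.
sampler = DistributedSampler(train_dataset)
```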
When I launch the program with:
It does know the world size early on according to the logs:
but it won't be available to the program at that point. I guess I could have hacked around this if I could get the world size: simply measuring the length of the whole dataset and dividing it by the world size would give a roughly correct number, but again I can't call anything dist-related that early.
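One possible workaround sketch, under the assumption (not stated in the thread) that the launcher exports the usual WORLD_SIZE environment variable before the script starts, as torch.distributed-style launchers do:

```python
import os

# Assumption: the launcher exported WORLD_SIZE; fall back to 1 for
# single-process runs. This avoids touching torch.distributed before the
# process group exists.
world_size = int(os.environ.get("WORLD_SIZE", "1"))

num_examples = 100_000                             # illustrative; len(train_dataset) in practice
per_process_examples = num_examples // world_size  # rough per-gpu share
```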
Perhaps the correct solution is that there should be 2 init functions:
All that the first one would need to do is set up the distributed environment; the rest of the current initialization would stay in the second. In order to remain back compatible, the newly proposed first-stage init could be optional.
Would it be possible to look at this issue, please? I can't complete the integration without this issue resolved. Thank you.
If you're needing the world size before the deepspeed.initialize call then yes, I think that's perfect motivation for calling deepspeed.init_distributed yourself. This call was introduced in #608 and is still pretty new. You are correct that if deepspeed.init_distributed is not called before deepspeed.initialize then we will call init_distributed on the user's behalf to make sure the distributed backend is started properly.
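A sketch of that pattern, with illustrative stand-ins for the model and config (neither comes from this thread, and older DeepSpeed releases spell the config keyword as config_params rather than config):

```python
import deepspeed
import torch
import torch.distributed as dist

# Illustrative stand-ins, not from the thread.
model = torch.nn.Linear(10, 10)
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 3e-5}},
}

# 1. Start the distributed backend early, before any DeepSpeed engine exists.
deepspeed.init_distributed()

# 2. The world size is now queryable, so schedule lengths like total_num_steps
#    can be computed before the config/scheduler is finalized.
world_size = dist.get_world_size()

# 3. Usual engine creation; had init_distributed not been called,
#    deepspeed.initialize would call it on our behalf.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```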
oh, you have recently added just that - fantastic - I will give it a try and report back. Thank you, @jeffra
It worked beautifully and greatly simplified our code - thank you, @jeffra. It'd be awesome to document it at https://www.deepspeed.ai/getting-started/#writing-deepspeed-models perhaps? I've proposed updating that doc to reflect this important function in #644.
We are having a bit of a hard time getting total_num_steps to pass to WarmupDecayLR at init time - it's a bit too early for the logic, as these points are configured once ddp/ds has started. We found a workaround, but it doesn't take into account the number of gpus.

I wanted to check that WarmupDecayLR.params.total_num_steps expects the total for the whole world and not per gpu. Currently the doc says:

so it'd be awesome to disambiguate the context.
Thank you!
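For reference, a hedged sketch of where this setting lives in the scheduler config; the numeric values are purely illustrative and, per the discussion above, total_num_steps counts optimizer steps for the whole run (effective batch over all gpus), not per-gpu steps:

```python
ds_config = {
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 500,
            # total optimizer steps across the whole run, not per gpu
            "total_num_steps": 10_000,
        },
    },
}
```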