-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix step shifting when accumulate gradient #33673
Fix step shifting when accumulate gradient #33673
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the detailed issue @kibitzing! Really nice job explaining what's happening! This looks like the right fix for gradient accumulation. In the past, we had the same behavior as this PR but without updating at the end of the epoch. See this PR for more information. Since we do update after the end of each epoch, it make sense to not use total_batched_samples
anymore. I see that in accelerate, we only turn the gradient sync when we are at the end of the dataloader or step+1 % args.gradient_accumulation_steps == 0
. Potentially, we can even remove the condition that we placed in transformers.
From
if (
(step + 1) % args.gradient_accumulation_steps == 0
or
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
# the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
# in accelerate. So, explicitly enable sync gradients to True in that case.
if is_last_step_and_steps_less_than_grad_acc:
self.accelerator.gradient_state._set_sync_gradients(True)
to simply
if self.accelerator.sync_gradients:
Can you check if we have the same behavior ? cc @muellerz
Hello @SunMarc , Thank you for taking the time to review my PR! I appreciate your suggestion, it's a great idea and would indeed simplify the code. However, after investigating and experimenting with this, I found that
While it's been interesting to dig deeper, I believe this issue might be outside the scope of the current PR. Hence, I suggest we stick with the previous approach for now. Additionally, during my testing, I found a bug in the condition + It is not modified now but to use transformers/src/transformers/trainer.py Line 4820 in f2c388e
_do_sync function check this as well.
|
Regarding the failing test (run_swag_no_trainer.py), I checked that it is not using the Trainer, which I have fixed. Do you have any suggestions or comments on it? |
After reading the code a bit more, I don't think we are performing a gradient update at the end of each epoch no (see your example) ? I said that because as you saw the condition using
Thanks for exploring ! Could you share a minimal reproducer ? This might indeed require a fix on accelerate side.
This is probably a flaky test ! Do you have any suggestions or comments on it? |
Thank you for reply @SunMarc,
Sure, I will! |
I mean that the original issue you had (too many updates) might not exist |
Yes, you are right, that issue does not exist. |
Not an issue ! I was confused also. Still, I will discuss with @muellerzr if it makes sense to switch to back to your idea with the update at each last step, just like how it is coded in accelerate. cc@mueller I'll keep you updated. Now, the issue is why self.gradient_state.end_of_dataloader doesn't work as expected. Feel free to open an issue on accelerate library with the reproducer ! Thanks a lot ! |
e4f9d89
to
e4cc360
Compare
Okay, I will create a new issue regarding the |
Hello @SunMarc, I revisited the issue I previously reported regarding To explain the flow in more detail:
Therefore, there is no issue with the Thus, as you suggested, using I’m going to change the current condition to this one and push the update. |
Hello, I have updated the if condition as per @SunMarc's suggestion. |
@SunMarc |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! I feel this is a great simplification all around :)
(Last bit is resolving the conflicts) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you fix the merge comflits. We did a lot of modification wrt to grad acc and it makes sense to have this now ! Feel free to ask us any question you have !
src/transformers/trainer.py
Outdated
is_last_step_and_steps_less_than_grad_acc = ( | ||
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch | ||
) | ||
|
||
if ( | ||
total_batched_samples % args.gradient_accumulation_steps == 0 | ||
or | ||
# last step in epoch but step is always smaller than gradient_accumulation_steps | ||
is_last_step_and_steps_less_than_grad_acc | ||
): | ||
# the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered | ||
# in accelerate. So, explicitly enable sync gradients to True in that case. | ||
if is_last_step_and_steps_less_than_grad_acc: | ||
self.accelerator.gradient_state._set_sync_gradients(True) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We decided to perform the grad acc at the end of the dataloader but in trainer, we will set self.accelerator.gradient_state._set_sync_gradients by ourselves and not rely on the values set by accelerate.accumulate
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, then I'll stick with the do_sync_step
in the main branch code, and just replace the total_batched_samples
with step
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's right !
src/transformers/trainer.py
Outdated
@@ -4786,8 +4769,6 @@ def create_accelerator_and_postprocess(self): | |||
# take the gradient_accumulation_steps setting from TrainingArguments. | |||
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps | |||
|
|||
grad_acc_kwargs["sync_with_dataloader"] = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I understand correctly, since we're setting self.accelerator.gradient_state._set_sync_gradients
by ourselves in the trainer, would it be safer to keep it set to False?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, let's keep it set to False !
Hello, @SunMarc @muellerzr Here is a summary of the new commits:
If there are any issues or additional modifications needed, please let me know! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A lot better ! Thanks for your patience @kibitzing. If you are happy with the changes, feel free to merge the PR @muellerzr
failing tests come from main, we are in limbo 🫠 |
cc @ydshieh |
Hello @muellerzr, I merged the main branch and it looks like it passed the tests! 😄 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
* replace total_batched_samples with step while counting grad accum step * remove unused variable * simplify condition for update step * fix format by ruff * simplify update step condition using accelerator.sync_gradients * simplify update condition using do_sync_step * remove print for test --------- Co-authored-by: Zach Mueller <[email protected]>
What does this PR do?
Fixes #33671
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@muellerzr and @SunMarc