Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix step shifting when accumulate gradient #33673

Merged

Conversation

kibitzing
Copy link
Contributor

@kibitzing kibitzing commented Sep 24, 2024

What does this PR do?

Fixes #33671

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr and @SunMarc

@SunMarc SunMarc self-requested a review September 26, 2024 01:38
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed issue @kibitzing! Really nice job explaining what's happening! This looks like the right fix for gradient accumulation. In the past, we had the same behavior as this PR but without updating at the end of the epoch. See this PR for more information. Since we do update after the end of each epoch, it make sense to not use total_batched_samples anymore. I see that in accelerate, we only turn the gradient sync when we are at the end of the dataloader or step+1 % args.gradient_accumulation_steps == 0. Potentially, we can even remove the condition that we placed in transformers.
From

                if (
                    (step + 1) % args.gradient_accumulation_steps == 0
                    or
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    is_last_step_and_steps_less_than_grad_acc
                ):
                    # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
                    # in accelerate. So, explicitly enable sync gradients to True in that case.
                    if is_last_step_and_steps_less_than_grad_acc:
                        self.accelerator.gradient_state._set_sync_gradients(True)

to simply

if self.accelerator.sync_gradients:

Can you check if we have the same behavior ? cc @muellerz

@kibitzing
Copy link
Contributor Author

kibitzing commented Sep 26, 2024

Hello @SunMarc ,

Thank you for taking the time to review my PR! I appreciate your suggestion, it's a great idea and would indeed simplify the code.

However, after investigating and experimenting with this, I found that self.gradient_state.end_of_dataloader is always False in my case, preventing the step of self.accelerator from being reset.
This happens because:

  • self.gradient_state.end_of_dataloader depends on self.gradient_state.active_dataloader
  • and active_dataloader is generally set to None, unless explicitly configured otherwise.

While it's been interesting to dig deeper, I believe this issue might be outside the scope of the current PR. Hence, I suggest we stick with the previous approach for now.

Additionally, during my testing, I found a bug in the condition is_last_step_and_steps_less_than_grad_acc. It should use or instead of and for the expected behavior. If and is used, the last steps are always dropped if args.gradient_accumulation_steps is larger than step_in_epoch, which is highly likely to happen. I just removed the comparing step_in_epoch and grad_accum_step part because updating at the last step covers that case as well. I’ve fixed this in the PR because I believe it is directly related to gradient accumulation.


+ It is not modified now but to use self.accelerator.sync_gradients, we should also remove

grad_acc_kwargs["sync_with_dataloader"] = False
because _do_sync function check this as well.

@kibitzing
Copy link
Contributor Author

kibitzing commented Sep 27, 2024

Regarding the failing test (run_swag_no_trainer.py), I checked that it is not using the Trainer, which I have fixed.

Do you have any suggestions or comments on it?

@SunMarc
Copy link
Member

SunMarc commented Sep 27, 2024

After reading the code a bit more, I don't think we are performing a gradient update at the end of each epoch no (see your example) ? I said that because as you saw the condition using and in is_last_step_and_steps_less_than_grad_acc + we have set grad_acc_kwargs["sync_with_dataloader"] = False. So, I feel like we are doing the right number of update across epoch. The only issue if that we indeed have an issue since there is no N sub-steps between the on_step_begin and on_step_end callbacks.

Epoch 1

Steps 1, 2, 3, 4 (update parameter because (4 % 4 = 0)
Steps 5, 6, 7 (update because it's the last step)
Epoch 2

Step 8 (update parameter because 8 % 4 =0)
Steps 9, 10, 11, 12 (update parameter because 12 % 4 = 0)
Steps 13, 14 (update because it's the last step)
Epoch 3

Step 15, 16 (update parameter because 16 % 4 = 0)
Steps 17, 18, 19, 20 (update parameter because 16 % 4 = 0)
Steps 21 (update because it's the last step)

However, after investigating and experimenting with this, I found that self.gradient_state.end_of_dataloader is always False in my case, preventing the step of self.accelerator from being reset.
This happens because:

self.gradient_state.end_of_dataloader depends on self.gradient_state.active_dataloader
end_of_dataloader depends on in_dataloader
in_dataloader depends on active_dataloader
and active_dataloader is generally set to None, unless explicitly configured otherwise.

Thanks for exploring ! Could you share a minimal reproducer ? This might indeed require a fix on accelerate side.

Regarding the failing test (run_swag_no_trainer.py), I checked that it is not using the Trainer, which I have fixed.

This is probably a flaky test !

Do you have any suggestions or comments on it?

@kibitzing
Copy link
Contributor Author

kibitzing commented Sep 27, 2024

Thank you for reply @SunMarc,
I may have misinterpreted the code initially. I agree that there's no issue with the update logic. I will go ahead and revert my changes to keep the condition is_last_step_and_steps_less_than_grad_acc.

Thanks for exploring ! Could you share a minimal reproducer ? This might indeed require a fix on accelerate side.

Sure, I will!

@SunMarc
Copy link
Member

SunMarc commented Sep 27, 2024

Thank you for reply @SunMarc,
I may have misinterpreted the code initially. I agree that there's no issue with the update logic. I will go ahead and revert my changes to keep the condition is_last_step_and_steps_less_than_grad_acc.

I mean that the original issue you had (too many updates) might not exist

@kibitzing
Copy link
Contributor Author

Yes, you are right, that issue does not exist.
I thought it always update at last, but it doesn't because it is generally blocked by "steps_less_than_than_grad_acc" condition.
Sorry for the confusion.

@SunMarc
Copy link
Member

SunMarc commented Sep 27, 2024

Not an issue ! I was confused also. Still, I will discuss with @muellerzr if it makes sense to switch to back to your idea with the update at each last step, just like how it is coded in accelerate. cc@mueller I'll keep you updated.

Now, the issue is why self.gradient_state.end_of_dataloader doesn't work as expected. Feel free to open an issue on accelerate library with the reproducer ! Thanks a lot !

@kibitzing kibitzing force-pushed the fix-step-shifting-when-accum-grad branch from e4f9d89 to e4cc360 Compare September 27, 2024 14:34
@kibitzing
Copy link
Contributor Author

Okay, I will create a new issue regarding the self.gradient_state.end_of_dataloader and wait for updates on this PR.
Thanks!

@kibitzing
Copy link
Contributor Author

Hello @SunMarc,

I revisited the issue I previously reported regarding accelerate and, embarrassingly, it turned out to be a problem with my own codebase. I was using accelerate together with a custom Trainer but had overridden get_train_dataloader without calling accelerate.prepare(dataloader).

To explain the flow in more detail:

  1. active_dataloader is set to None at first.
  2. When prepare(dataloader) is called, it goes into prepare_data_loader, where it returns either a DataLoaderShard or DataLoaderShard that inherits from DataLoaderStateMixin as a new dataloader.
  3. These prepared dataloaders add an active_dataloader at every begin().
  • Without calling prepare, it is clear that the GradientState doesn't have an active dataloader

Therefore, there is no issue with the accelerate code itself. Everything works correctly when prepare is called, and I confirmed that at the last batch, self.gradient_state.end_of_dataloader=True is set as expected. I also checked this with the pytorch run_glue.py example.

Thus, as you suggested, using if self.accelerator.sync_gradients: is perfectly fine and results in simpler code.

I’m going to change the current condition to this one and push the update.

@kibitzing
Copy link
Contributor Author

Hello, I have updated the if condition as per @SunMarc's suggestion.
Do we have any updates regarding the update at each last step ?

@kibitzing
Copy link
Contributor Author

@SunMarc
Is there any update on this PR?
If any additional work is needed, please let me know!

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I feel this is a great simplification all around :)

@muellerzr muellerzr requested a review from SunMarc October 24, 2024 13:39
@muellerzr
Copy link
Contributor

(Last bit is resolving the conflicts)

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you fix the merge comflits. We did a lot of modification wrt to grad acc and it makes sense to have this now ! Feel free to ask us any question you have !

Comment on lines 2375 to 2462
is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)

if (
total_batched_samples % args.gradient_accumulation_steps == 0
or
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
# the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
# in accelerate. So, explicitly enable sync gradients to True in that case.
if is_last_step_and_steps_less_than_grad_acc:
self.accelerator.gradient_state._set_sync_gradients(True)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We decided to perform the grad acc at the end of the dataloader but in trainer, we will set self.accelerator.gradient_state._set_sync_gradients by ourselves and not rely on the values set by accelerate.accumulate

Copy link
Contributor Author

@kibitzing kibitzing Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, then I'll stick with the do_sync_step in the main branch code, and just replace the total_batched_samples with step.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's right !

@@ -4786,8 +4769,6 @@ def create_accelerator_and_postprocess(self):
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

grad_acc_kwargs["sync_with_dataloader"] = False
Copy link
Contributor Author

@kibitzing kibitzing Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, since we're setting self.accelerator.gradient_state._set_sync_gradients by ourselves in the trainer, would it be safer to keep it set to False?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let's keep it set to False !

@kibitzing
Copy link
Contributor Author

kibitzing commented Oct 25, 2024

Hello, @SunMarc @muellerzr

Here is a summary of the new commits:

  1. I merged the main branch and resolved conflicts.
  2. I replaced total_batched_samples with (step + 1) and removed steps_in_epoch <= args.gradient_accumulation_steps condition to simplify, as we discussed before

I see that in accelerate, we only turn the gradient sync when we are at the end of the dataloader or step+1 % args.gradient_accumulation_steps == 0. Potentially, we can even remove the condition that we placed in transformers.

  1. I reverted the previous changes and set grad_acc_kwargs["sync_with_dataloader"] = False as SunMarc mentioned,

we will set self.accelerator.gradient_state._set_sync_gradients by ourselves and not rely on the values set by accelerate.accumulate

If there are any issues or additional modifications needed, please let me know!

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot better ! Thanks for your patience @kibitzing. If you are happy with the changes, feel free to merge the PR @muellerzr

@muellerzr
Copy link
Contributor

failing tests come from main, we are in limbo 🫠

@muellerzr
Copy link
Contributor

cc @ydshieh

@kibitzing
Copy link
Contributor Author

Hello @muellerzr, I merged the main branch and it looks like it passed the tests! 😄

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@muellerzr muellerzr merged commit dca93ca into huggingface:main Oct 31, 2024
24 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* replace total_batched_samples with step while counting grad accum step

* remove unused variable

* simplify condition for update step

* fix format by ruff

* simplify update step condition using accelerator.sync_gradients

* simplify update condition using do_sync_step

* remove print for test

---------

Co-authored-by: Zach Mueller <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Step shifting using total_batched_samples for gradient_accumulation_steps counting
5 participants