
[PyTorch] Debug amax reductions in eval mode and async amax reductions #728

Closed
Wants to merge 6 commits.

Conversation

timmoon10 (Collaborator)

This PR fixes two bugs related to amax reductions:

  1. In the forward pass we only launch an amax reduction if a module is in training mode, but in the backward pass we always launch one. This causes runtime errors in the backward pass of modules that are in evaluation mode, e.g. frozen LoRA layers. Confusing no_grad and evaluation mode seems to be a common mistake in PyTorch; eval mode does not disable autograd, so an eval-mode module can still run a backward pass (see the snippet after this list). This PR fixes the issue by checking whether the module is in training mode in its backward pass, mirroring the existing forward-pass check.
  2. When asynchronous amax reductions are enabled, we currently synchronize the reduction after the first module's amax and scale update. I would appreciate a sanity check here, since I would expect this ordering to have caused non-deterministic numerical errors in the scale updates. This PR avoids the problem by making sure the async reduction has finished before any amax and scale updates (a sketch of the corrected ordering follows the next paragraph). See Async amax reduction #118 (comment).
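
To illustrate point 1: `eval()` and `torch.no_grad()` are independent in PyTorch, so a module in eval mode can still run a backward pass. The snippet below is a plain-PyTorch illustration of that distinction, not Transformer Engine code:

```python
import torch

lin = torch.nn.Linear(4, 4)

# eval() only flips module.training; autograd is still active, so this
# module still gets a backward pass (which previously triggered a
# backward amax reduction unconditionally).
lin.eval()
assert not lin.training
out = lin(torch.randn(2, 4, requires_grad=True))
assert out.requires_grad
out.sum().backward()  # backward runs even though lin is in eval mode

# torch.no_grad() disables autograd but leaves module.training untouched.
lin.train()
with torch.no_grad():
    assert lin.training
    out = lin(torch.randn(2, 4))
    assert not out.requires_grad
```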

I've attempted to keep this PR small since #575 touches a lot of the amax reduction logic. In the future I think it would be worthwhile to rework the async amax reduction, since it currently achieves little overlap: it is launched when entering fp8_autocast and synchronized before the first TE module's forward pass.
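
For reference, a minimal sketch of the ordering fix in point 2, assuming an initialized torch.distributed process group; the function name, buffer layout, and scale formula below are illustrative placeholders, not the actual Transformer Engine internals:

```python
import torch
import torch.distributed as dist

def reduce_amax_and_update_scales(amax: torch.Tensor, group=None) -> torch.Tensor:
    """Illustrative sketch only: launch the amax reduction asynchronously,
    then wait on it *before* any amax/scale update, rather than after the
    first module's update."""
    # Launched early (conceptually, when entering fp8_autocast) so it can
    # overlap with other host-side work.
    handle = dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group, async_op=True)

    # ... other work could overlap here ...

    # The fix: synchronize once, before the first scale update, so every
    # module's update sees the globally reduced amax values.
    handle.wait()

    # Placeholder scale computation (fp8_max/margin values are illustrative).
    fp8_max, margin = 448.0, 0.0
    scales = (fp8_max / amax.clamp(min=1e-12)) * (2.0 ** -margin)
    return scales
```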

Do not update backward FP8 scales when in eval mode. Make sure to finish async amax reductions before scale update.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 timmoon10 added the bug Something isn't working label Mar 21, 2024
@timmoon10 timmoon10 requested review from ptrendx and ksivaman March 21, 2024 04:18
@timmoon10 timmoon10 changed the title [PyTorch] Debug async amax reductions and amax reductions in eval mode [PyTorch] Debugamax reductions in eval mode and async amax reductions Mar 21, 2024
@timmoon10 timmoon10 changed the title [PyTorch] Debugamax reductions in eval mode and async amax reductions [PyTorch] Debug amax reductions in eval mode and async amax reductions Mar 21, 2024
timmoon10 (Collaborator, Author)

/te-ci pytorch

ksivaman (Member)

Note: #575 already removes the async amax reduction and addresses both of these fixes (it also overhauls the current system for amax reduction/update). Since both PRs would land in the same release, shall we close this one? @timmoon10

@timmoon10 timmoon10 marked this pull request as draft March 27, 2024 01:05
timmoon10 (Collaborator, Author)

/te-ci pytorch

@timmoon10 timmoon10 force-pushed the debug-lora-amax-reduction branch from 2235c17 to 05d7861 Compare March 28, 2024 02:59
@timmoon10 timmoon10 closed this Apr 5, 2024