Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The PostBackwardFunction class should be more clearly named to distinguish it from the PreBackwardFunction class. #2548

Merged
merged 10 commits into from
Sep 6, 2023
6 changes: 3 additions & 3 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def backward(ctx, *args):
class PostBackwardFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, module, pre_backward_function, output):
def forward(ctx, module, post_backward_function, output):
ctx.module = module
if output.requires_grad:
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
Expand All @@ -185,15 +185,15 @@ def forward(ctx, module, pre_backward_function, output):
#if module.ds_grads_remaining == 0:
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.pre_backward_function = pre_backward_function
ctx.post_backward_function = post_backward_function
output = output.detach()
return output

@staticmethod
def backward(ctx, *args):
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
if ctx.module.ds_grads_remaining == 0:
ctx.pre_backward_function(ctx.module)
ctx.post_backward_function(ctx.module)
#print(f"After Backward: {ctx.module.__class__.__name__}")
return (None, None) + args

Expand Down