diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index 911b43bfc9bd..d1ecd9288878 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -173,7 +173,7 @@ def backward(ctx, *args):
 class PostBackwardFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, module, pre_backward_function, output):
+    def forward(ctx, module, post_backward_function, output):
         ctx.module = module
         if output.requires_grad:
             #TODO SOME TIMES post backward does not seem to be triggered debug in detail
@@ -185,7 +185,7 @@ def forward(ctx, module, pre_backward_function, output):
             #if module.ds_grads_remaining == 0:
             #    print(f"Before Forward: {ctx.module.__class__.__name__}")
             module.ds_grads_remaining += 1
-            ctx.pre_backward_function = pre_backward_function
+            ctx.post_backward_function = post_backward_function
         output = output.detach()
         return output
 
@@ -193,7 +193,7 @@ def forward(ctx, module, pre_backward_function, output):
     def backward(ctx, *args):
         ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
         if ctx.module.ds_grads_remaining == 0:
-            ctx.pre_backward_function(ctx.module)
+            ctx.post_backward_function(ctx.module)
             #print(f"After Backward: {ctx.module.__class__.__name__}")
         return (None, None) + args
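
Note: the patch is a pure rename. PostBackwardFunction had inherited the pre_backward_function parameter name from its PreBackwardFunction sibling, even though the callback it stores fires after the module's backward pass. For context, here is a minimal standalone sketch of the mechanism the rename clarifies: the Function detaches the module's output so that its backward() runs when the gradient for that output arrives, and a per-module counter (ds_grads_remaining) fires the callback once every wrapped output has been back-propagated. The _post_backward callback, the manual counter initialization, and the direct .apply(...) call below are illustrative assumptions for the demo, not DeepSpeed's actual hook-registration code.

# Illustrative sketch only -- the Function mirrors the patched code; the
# surrounding setup is assumed, not DeepSpeed API.
import torch
import torch.nn as nn


class PostBackwardFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, module, post_backward_function, output):
        ctx.module = module
        if output.requires_grad:
            # One pending gradient per wrapped output of this module.
            module.ds_grads_remaining += 1
            ctx.post_backward_function = post_backward_function
        # Detaching splices this Function into the graph: the gradient for
        # `output` now flows through backward() below before continuing.
        return output.detach()

    @staticmethod
    def backward(ctx, *args):
        ctx.module.ds_grads_remaining -= 1
        if ctx.module.ds_grads_remaining == 0:
            # Every output has received its gradient, so the module's
            # backward pass is complete and the callback can run.
            ctx.post_backward_function(ctx.module)
        # None for the two non-tensor inputs (module, callback); pass the
        # incoming gradient through to the original, undetached output.
        return (None, None) + args


def _post_backward(module):  # hypothetical callback, for the demo only
    print(f"post-backward fired for {module.__class__.__name__}")


layer = nn.Linear(4, 4)
layer.ds_grads_remaining = 0  # assumed here; DeepSpeed's hook setup resets it
x = torch.randn(2, 4, requires_grad=True)
y = PostBackwardFunction.apply(layer, _post_backward, layer(x))
y.sum().backward()  # prints once, after the gradient reaches the layer

Because backward() returns (None, None) + args, the gradient that arrived at the detached output is handed back to the original output tensor, so the wrapper is transparent to the rest of the autograd graph.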