Revert "Activation checkpointing bugfix and unit tests (microsoft#420)…
Browse files Browse the repository at this point in the history
…" (microsoft#422)

This reverts commit 01b6e27.

Co-authored-by: Shaden Smith <[email protected]>
jeffra and ShadenSmith authored Sep 18, 2020
1 parent 01b6e27 commit a74a604
Showing 2 changed files with 12 additions and 160 deletions.
25 changes: 12 additions & 13 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -481,16 +481,13 @@ def forward(ctx, run_function, *args):
         if SYNCHRONIZE:
             torch.cuda.synchronize()

-        # Tensors returned from forward() may not be differentiable.
-        if torch.is_tensor(outputs):
-            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
-        else:
-            non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
+        # Tensors returned from forward() may not be differentiable, e.g., attention mask
+        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
         ctx.mark_non_differentiable(*non_grad_outputs)
         return outputs

     @staticmethod
-    def backward(ctx, *grads):
+    def backward(ctx, *args):
         global timers
         #see_memory_usage("In backward", force=True)
         #removing pointers to the contiguous buffer memory
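For readers less familiar with ctx.mark_non_differentiable(), here is a minimal, self-contained sketch of the pattern both sides of the hunk above implement: filter non-floating-point outputs (such as an attention mask) and mark them non-differentiable before returning. The MaskedScale class, the 2 * x computation, and the integer mask are invented for illustration and are not DeepSpeed code; the removed lines additionally guarded the case where forward() returns a bare tensor rather than a tuple, which this sketch does not exercise.

import torch

class MaskedScale(torch.autograd.Function):
    """Toy Function that returns a float output plus an integer mask."""

    @staticmethod
    def forward(ctx, x):
        scaled = 2.0 * x
        mask = (x > 0).to(torch.int64)      # integer output, cannot carry a gradient
        outputs = (scaled, mask)
        # Same filter as the restored line above: non-floating-point outputs
        # are marked non-differentiable so autograd never asks for their grads.
        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
        ctx.mark_non_differentiable(*non_grad_outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_scaled, grad_mask):
        # One incoming gradient slot per output; the mask's slot is ignored
        # because the mask was marked non-differentiable in forward().
        return 2.0 * grad_scaled

x = torch.randn(4, requires_grad=True)
scaled, mask = MaskedScale.apply(x)
assert not mask.requires_grad            # the mask is excluded from autograd
scaled.sum().backward()                  # x.grad == 2 everywhere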
@@ -556,15 +553,17 @@ def backward(ctx, *grads):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs, )

-        # Construct arguments to autograd.backward().
-        # This is usually just outputs and grads, but forward() can return tensors that
-        # are not differentiable.
+        # Go over args and build the list of gradient tensors. This is usually just args,
+        # but if the forward pass returns tensors that do not require_grad then we should
+        # adjust the arguments to autograd.backward() too. This happens when forward()
+        # returns indices or a mask (such as an attention mask).
+        # We skip the first needs_input_grad because it corresponds to run_function.
         output_tensors = []
         grad_tensors = []
-        for out, grad in zip(outputs, grads):
-            if out.requires_grad:
-                output_tensors.append(out)
-                grad_tensors.append(grad)
+        for idx, need_grad in enumerate(ctx.needs_input_grad[1:]):
+            if need_grad:
+                output_tensors.append(outputs[idx])
+                grad_tensors.append(args[idx])

         torch.autograd.backward(output_tensors, grad_tensors)

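Both loop variants serve the same constraint: torch.autograd.backward() must receive matched lists of differentiable outputs and their gradients, with non-differentiable outputs (indices, masks) left out. The standalone sketch below illustrates that constraint with toy tensors; it uses the requires_grad-based filter from the removed lines only because ctx.needs_input_grad is available solely inside a Function's backward(). All names and values are illustrative, not DeepSpeed code.

import torch

x = torch.randn(3, requires_grad=True)

# Stand-ins for the recomputed outputs of a checkpointed block:
float_out = x * 3.0                     # differentiable
mask_out = (x > 0).to(torch.int64)      # non-differentiable integer mask
outputs = (float_out, mask_out)

# Incoming gradients, one slot per output; the mask's slot is None because
# nothing backpropagates into it.
grads = (torch.ones_like(float_out), None)

output_tensors, grad_tensors = [], []
for out, grad in zip(outputs, grads):
    if out.requires_grad:               # keep only outputs that can receive a gradient
        output_tensors.append(out)
        grad_tensors.append(grad)

torch.autograd.backward(output_tensors, grad_tensors)
print(x.grad)                           # tensor([3., 3., 3.])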
147 changes: 0 additions & 147 deletions tests/unit/test_activation_checkpointing.py

This file was deleted.
