Fix activation checkpoint unit tests for GPU systems (microsoft#421)
Shaden Smith authored Sep 18, 2020
1 parent a74a604 commit a825f99
Showing 2 changed files with 170 additions and 12 deletions.
25 changes: 13 additions & 12 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
100755 → 100644
@@ -481,13 +481,16 @@ def forward(ctx, run_function, *args):
         if SYNCHRONIZE:
             torch.cuda.synchronize()
 
-        # Tensors returned from forward() may not be differentiable, e.g., attention mask
-        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
+        # Tensors returned from forward() may not be differentiable.
+        if torch.is_tensor(outputs):
+            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
+        else:
+            non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
         ctx.mark_non_differentiable(*non_grad_outputs)
         return outputs
 
     @staticmethod
-    def backward(ctx, *args):
+    def backward(ctx, *grads):
         global timers
         #see_memory_usage("In backward", force=True)
         #removing pointers to the contiguous buffer memory
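The forward-side change is easiest to see in isolation. Below is a minimal, self-contained sketch of the same pattern: an illustrative torch.autograd.Function (the name ScaleByMask and its operations are assumptions for the example, not DeepSpeed's CheckpointFunction). A float output stays differentiable, while a non-floating-point output is filtered out and handed to ctx.mark_non_differentiable(), which is what lets forward() return an index or mask tensor without breaking backward.

import torch

class ScaleByMask(torch.autograd.Function):
    """Illustrative only -- not DeepSpeed's CheckpointFunction."""
    @staticmethod
    def forward(ctx, x, mask):
        ctx.save_for_backward(mask)
        out = x * mask.type_as(x)           # differentiable float output
        kept = mask.nonzero().flatten()     # integer tensor, not differentiable
        outputs = (out, kept)
        # Same filter as the patch: only non-floating-point outputs are marked.
        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
        ctx.mark_non_differentiable(*non_grad_outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_out, grad_kept):  # one gradient slot per output
        mask, = ctx.saved_tensors
        return grad_out * mask.type_as(grad_out), None

x = torch.randn(4, requires_grad=True)
mask = torch.tensor([True, False, True, True])
out, kept = ScaleByMask.apply(x, mask)
out.sum().backward()                        # succeeds; `kept` does not require grad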
@@ -553,17 +556,15 @@ def backward(ctx, *args):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs, )
 
-        # Go over args and build the list of gradient tensors. This is usually just args,
-        # but if the forward pass returns tensors that do not require_grad then we should
-        # adjust the arguments to autograd.backward() too. This happens when forward()
-        # returns indices or a mask (such as an attention mask).
-        # We skip the first needs_input_grad because it corresponds to run_function.
+        # Construct arguments to autograd.backward().
+        # This is usually just outputs and grads, but forward() can return tensors that
+        # are not differentiable.
         output_tensors = []
         grad_tensors = []
-        for idx, need_grad in enumerate(ctx.needs_input_grad[1:]):
-            if need_grad:
-                output_tensors.append(outputs[idx])
-                grad_tensors.append(args[idx])
+        for out, grad in zip(outputs, grads):
+            if out.requires_grad:
+                output_tensors.append(out)
+                grad_tensors.append(grad)
 
         torch.autograd.backward(output_tensors, grad_tensors)
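The backward change matters for modules whose outputs do not line up one-to-one with their inputs. The old loop walked ctx.needs_input_grad[1:], which describes the Function's inputs, and then used that position to index both the outputs and the incoming gradients; for a module like MaskedLinearSeqDup in the new test file (three outputs, two inputs) those indices no longer correspond. Pairing each output directly with its incoming gradient sidesteps the mismatch. A small stand-alone sketch of that pairing follows (the helper name _pair_backward_args is illustrative, not part of DeepSpeed):

import torch

def _pair_backward_args(outputs, grads):
    """Pair each differentiable output with its incoming gradient.

    Illustrative stand-alone version of the loop in the patch above;
    `outputs` and `grads` are assumed to be equal-length sequences, as
    they are inside torch.autograd.Function.backward().
    """
    output_tensors, grad_tensors = [], []
    for out, grad in zip(outputs, grads):
        if out.requires_grad:
            output_tensors.append(out)
            grad_tensors.append(grad)
    return output_tensors, grad_tensors

a = torch.randn(3, requires_grad=True) * 2                 # differentiable
idx = torch.tensor([0, 2])                                 # not differentiable
outs, gs = _pair_backward_args((a, idx), (torch.ones(3), None))
torch.autograd.backward(outs, gs)                          # only `a` participates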
157 changes: 157 additions & 0 deletions tests/unit/test_activation_checkpointing.py
@@ -0,0 +1,157 @@
# TODO: add tests with model parallelism for activation partitioning and other features.

from copy import deepcopy

import pytest

import torch

import deepspeed
ckpt = deepspeed.checkpointing.checkpoint

from common import distributed_test


def _compute(module, *inputs, do_checkpoint=False):
    if do_checkpoint:
        outputs = ckpt(module, *inputs)
    else:
        outputs = module(*inputs)

    if torch.is_tensor(outputs):
        outputs = (outputs, )

    sum(o.sum() for o in outputs if o.requires_grad).backward()
    grads = [p.grad for p in module.parameters()]
    input_grads = [inp.grad for inp in inputs]

    return {
        'outputs': outputs,
        'module_grads': grads,
        'input_grads': input_grads,
    }


# This is distributed because checkpoint() assumes that torch.distributed is initialized.
# torch.distributed is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint(module, *inputs):
    # Move to device
    module.cuda()

    # Get rid of dropouts until we fork the RNG between tests.
    module.eval()

    module_ = deepcopy(module)
    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
    base = _compute(module_, *inputs_, do_checkpoint=False)

    module_ = deepcopy(module)
    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
    test = _compute(module_, *inputs_, do_checkpoint=True)

    for group in base.keys():
        for b, t in zip(base[group], test[group]):
            # Catch grad `None`s, etc.
            if not torch.is_tensor(b):
                assert b == t
            elif b.is_floating_point():
                assert torch.allclose(b, t)
            else:
                assert torch.equal(b, t)


#
# Helpers
#


class MaskedLinear(torch.nn.Linear):
    def forward(self, x, mask):
        out = super().forward(x)
        if mask.is_floating_point():
            out = out * mask
        else:
            # must cast BoolTensor in older torch versions
            out = out * mask.type_as(out)
        return out


class MaskedLinearSeq(MaskedLinear):
    """Tests pipeline modules by also returning the mask."""
    def forward(self, x, mask):
        return super().forward(x, mask), mask


class MaskedLinearSeqDup(MaskedLinearSeq):
    """MaskedLinearSeq, but with more outputs than inputs and in a different order."""
    def forward(self, x, mask):
        dup = x.clone().detach() * 1.38  # just an arbitrary scaling
        x, mask = super().forward(x, mask)
        return dup, x, mask


HIDDEN_DIM = 20


def _mixed_mask(size=HIDDEN_DIM):
    entries = torch.randn(size)
    mask = torch.where(entries > 0, torch.ones(size), torch.zeros(size))
    mask = mask.bool()
    return mask


def _bool_to_float(btensor, dtype=torch.float32):
    """Converts a torch.BoolTensor to an equivalent dtype."""
    ones = torch.ones(size=btensor.size(), dtype=dtype)
    zeros = torch.zeros(size=btensor.size(), dtype=dtype)
    return torch.where(btensor, ones, zeros)


#
# Tests
#


def test_ckpt_inputs1_outputs1():
    module = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs)


# both bool and float are important, as bool is not differentiable
@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs1(mask):
    module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs2(mask):
    module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs3(mask):
    module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)

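For context on how the new tests exercise the fix: the comment above _test_activation_checkpoint notes that checkpoint() assumes torch.distributed is initialized, which is why the helper is wrapped in @distributed_test(world_size=1). Under that assumption (an initialized process group and a single CUDA device), the call pattern the tests rely on boils down to the sketch below; it reuses MaskedLinearSeq, _mixed_mask, and HIDDEN_DIM from the file above and is an illustrative usage sketch, not a prescribed API recipe.

import torch
import deepspeed

# Assumes torch.distributed.init_process_group(...) has already run,
# e.g. via the test suite's @distributed_test harness.
module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM).cuda()
x = torch.rand(HIDDEN_DIM, device='cuda', requires_grad=True)
mask = _mixed_mask().cuda()

out, returned_mask = deepspeed.checkpointing.checkpoint(module, x, mask)
out.sum().backward()                       # gradients flow through the float output
assert not returned_mask.requires_grad     # the bool mask stays non-differentiable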