diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py
old mode 100755
new mode 100644
index 3950e7eced20..1cc20cd3dfce
--- a/deepspeed/runtime/activation_checkpointing/checkpointing.py
+++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -481,13 +481,16 @@ def forward(ctx, run_function, *args):
         if SYNCHRONIZE:
             torch.cuda.synchronize()
 
-        # Tensors returned from forward() may not be differentiable, e.g., attention mask
-        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
+        # Tensors returned from forward() may not be differentiable.
+        if torch.is_tensor(outputs):
+            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
+        else:
+            non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
         ctx.mark_non_differentiable(*non_grad_outputs)
         return outputs
 
     @staticmethod
-    def backward(ctx, *args):
+    def backward(ctx, *grads):
         global timers
         #see_memory_usage("In backward", force=True)
         #removing pointers to the contiguous buffer memory
@@ -553,17 +556,15 @@ def backward(ctx, *args):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs, )
 
-        # Go over args and build the list of gradient tensors. This is usually just args,
-        # but if the forward pass returns tensors that do not require_grad then we should
-        # adjust the arguments to autograd.backward() too. This happens when forward()
-        # returns indices or a mask (such as an attention mask).
-        # We skip the first needs_input_grad because it corresponds to run_function.
+        # Construct arguments to autograd.backward().
+        # This is usually just outputs and grads, but forward() can return tensors that
+        # are not differentiable.
         output_tensors = []
         grad_tensors = []
-        for idx, need_grad in enumerate(ctx.needs_input_grad[1:]):
-            if need_grad:
-                output_tensors.append(outputs[idx])
-                grad_tensors.append(args[idx])
+        for out, grad in zip(outputs, grads):
+            if out.requires_grad:
+                output_tensors.append(out)
+                grad_tensors.append(grad)
 
         torch.autograd.backward(output_tensors, grad_tensors)
diff --git a/tests/unit/test_activation_checkpointing.py b/tests/unit/test_activation_checkpointing.py
new file mode 100644
index 000000000000..8bb8ce8be3dc
--- /dev/null
+++ b/tests/unit/test_activation_checkpointing.py
@@ -0,0 +1,157 @@
+# TODO: add tests with model parallelism for activation partitioning and other features.
+
+from copy import deepcopy
+
+import pytest
+
+import torch
+
+import deepspeed
+ckpt = deepspeed.checkpointing.checkpoint
+
+from common import distributed_test
+
+
+def _compute(module, *inputs, do_checkpoint=False):
+    if do_checkpoint:
+        outputs = ckpt(module, *inputs)
+    else:
+        outputs = module(*inputs)
+
+    if torch.is_tensor(outputs):
+        outputs = (outputs, )
+
+    sum(o.sum() for o in outputs if o.requires_grad).backward()
+    grads = [p.grad for p in module.parameters()]
+    input_grads = [inp.grad for inp in inputs]
+
+    return {
+        'outputs': outputs,
+        'module_grads': grads,
+        'input_grads': input_grads,
+    }
+
+
+# This is distributed because checkpoint() assumes that torch.distributed is initialized.
+# torch.distributed is used with activation partitioning, but not for these simple cases.
+@distributed_test(world_size=1)
+def _test_activation_checkpoint(module, *inputs):
+    # Move to device
+    module.cuda()
+
+    # Get rid of dropouts until we fork the RNG between tests.
+    module.eval()
+
+    module_ = deepcopy(module)
+    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
+    base = _compute(module_, *inputs_, do_checkpoint=False)
+
+    module_ = deepcopy(module)
+    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
+    test = _compute(module_, *inputs_, do_checkpoint=True)
+
+    for group in base.keys():
+        for b, t in zip(base[group], test[group]):
+            # Catch grad `None`s, etc.
+            if not torch.is_tensor(b):
+                assert b == t
+            elif b.is_floating_point():
+                assert torch.allclose(b, t)
+            else:
+                assert torch.equal(b, t)
+
+
+#
+# Helpers
+#
+
+
+class MaskedLinear(torch.nn.Linear):
+    def forward(self, x, mask):
+        out = super().forward(x)
+        if mask.is_floating_point():
+            out = out * mask
+        else:
+            # must cast BoolTensor in older torch versions
+            out = out * mask.type_as(out)
+        return out
+
+
+class MaskedLinearSeq(MaskedLinear):
+    """Tests pipeline modules by also returning the mask."""
+    def forward(self, x, mask):
+        return super().forward(x, mask), mask
+
+
+class MaskedLinearSeqDup(MaskedLinearSeq):
+    """MaskedLinearSeq, but with more outputs than inputs and in a different order."""
+    def forward(self, x, mask):
+        dup = x.clone().detach() * 1.38  # just an arbitrary scaling
+        x, mask = super().forward(x, mask)
+        return dup, x, mask
+
+
+HIDDEN_DIM = 20
+
+
+def _mixed_mask(size=HIDDEN_DIM):
+    entries = torch.randn(size)
+    mask = torch.where(entries > 0, torch.ones(size), torch.zeros(size))
+    mask = mask.bool()
+    return mask
+
+
+def _bool_to_float(btensor, dtype=torch.float32):
+    """Converts a torch.BoolTensor to an equivalent float tensor of the given dtype."""
+    ones = torch.ones(size=btensor.size(), dtype=dtype)
+    zeros = torch.zeros(size=btensor.size(), dtype=dtype)
+    return torch.where(btensor, ones, zeros)
+
+
+#
+# Tests
+#
+
+
+def test_ckpt_inputs1_outputs1():
+    module = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
+    inputs = torch.rand(HIDDEN_DIM)
+    inputs.requires_grad = True
+    _test_activation_checkpoint(module, inputs)
+
+
+# both bool and float are important, as bool is not differentiable
+@pytest.mark.parametrize('mask',
+                         [
+                             _mixed_mask(),
+                             _bool_to_float(_mixed_mask()),
+                         ])
+def test_ckpt_inputs2_outputs1(mask):
+    module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
+    inputs = torch.rand(HIDDEN_DIM)
+    inputs.requires_grad = True
+    _test_activation_checkpoint(module, inputs, mask)
+
+
+@pytest.mark.parametrize('mask',
+                         [
+                             _mixed_mask(),
+                             _bool_to_float(_mixed_mask()),
+                         ])
+def test_ckpt_inputs2_outputs2(mask):
+    module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
+    inputs = torch.rand(HIDDEN_DIM)
+    inputs.requires_grad = True
+    _test_activation_checkpoint(module, inputs, mask)
+
+
+@pytest.mark.parametrize('mask',
+                         [
+                             _mixed_mask(),
+                             _bool_to_float(_mixed_mask()),
+                         ])
+def test_ckpt_inputs2_outputs3(mask):
+    module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
+    inputs = torch.rand(HIDDEN_DIM)
+    inputs.requires_grad = True
+    _test_activation_checkpoint(module, inputs, mask)
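
For illustration only, not part of the patch: a minimal standalone sketch of the case the new zip(outputs, grads) logic and the MaskedLinearSeq tests exercise, namely checkpointing a function that returns a differentiable activation together with a non-differentiable bool mask. It assumes a single CUDA device and sets up a one-process torch.distributed group, since checkpoint() expects torch.distributed to be initialized (see the comment on _test_activation_checkpoint). The helper name masked_relu, the tensor sizes, and the port number are made up for the example.

import os

import torch
import deepspeed

# checkpoint() assumes torch.distributed is initialized; a single-process
# group is enough here (the tests use the distributed_test harness instead).
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
torch.distributed.init_process_group(backend='nccl', rank=0, world_size=1)


def masked_relu(x, mask):
    # Returns a differentiable activation plus the bool (non-differentiable)
    # mask, mirroring the MaskedLinearSeq pattern used in the tests above.
    return torch.relu(x) * mask.type_as(x), mask


x = torch.rand(20, device='cuda', requires_grad=True)
mask = torch.randn(20, device='cuda') > 0  # a torch.BoolTensor

out, mask_out = deepspeed.checkpointing.checkpoint(masked_relu, x, mask)

# Only the floating-point output participates in the backward pass; the
# returned mask is marked non-differentiable by CheckpointFunction.forward().
out.sum().backward()
assert x.grad is not None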