diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py
index 1cc20cd3dfce..3950e7eced20 100755
--- a/deepspeed/runtime/activation_checkpointing/checkpointing.py
+++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -481,16 +481,13 @@ def forward(ctx, run_function, *args):
         if SYNCHRONIZE:
             torch.cuda.synchronize()
 
-        # Tensors returned from forward() may not be differentiable.
-        if torch.is_tensor(outputs):
-            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
-        else:
-            non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
+        # Tensors returned from forward() may not be differentiable, e.g., attention mask
+        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
         ctx.mark_non_differentiable(*non_grad_outputs)
         return outputs
 
     @staticmethod
-    def backward(ctx, *grads):
+    def backward(ctx, *args):
         global timers
         #see_memory_usage("In backward", force=True)
         #removing pointers to the contiguous buffer memory
@@ -556,15 +553,17 @@ def backward(ctx, *grads):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs, )
 
-        # Construct arguments to autograd.backward().
-        # This is usually just outputs and grads, but forward() can return tensors that
-        # are not differentiable.
+        # Go over args and build the list of gradient tensors. This is usually just args,
+        # but if the forward pass returns tensors that do not require_grad then we should
+        # adjust the arguments to autograd.backward() too. This happens when forward()
+        # returns indices or a mask (such as an attention mask).
+        # We skip the first needs_input_grad because it corresponds to run_function.
         output_tensors = []
         grad_tensors = []
-        for out, grad in zip(outputs, grads):
-            if out.requires_grad:
-                output_tensors.append(out)
-                grad_tensors.append(grad)
+        for idx, need_grad in enumerate(ctx.needs_input_grad[1:]):
+            if need_grad:
+                output_tensors.append(outputs[idx])
+                grad_tensors.append(args[idx])
 
         torch.autograd.backward(output_tensors, grad_tensors)
 
diff --git a/tests/unit/test_activation_checkpointing.py b/tests/unit/test_activation_checkpointing.py
deleted file mode 100644
index 78bb9309e82b..000000000000
--- a/tests/unit/test_activation_checkpointing.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# TODO: add tests with model parallelism for activation partitioning and other features.
-
-from copy import deepcopy
-
-import pytest
-
-import torch
-
-import deepspeed
-ckpt = deepspeed.checkpointing.checkpoint
-
-from common import distributed_test
-
-
-def _compute(module, *inputs, do_checkpoint=False):
-    inputs = deepcopy(inputs)
-    module = deepcopy(module)
-
-    if do_checkpoint:
-        outputs = ckpt(module, *inputs)
-    else:
-        outputs = module(*inputs)
-
-    if torch.is_tensor(outputs):
-        outputs = (outputs, )
-
-    sum(o.sum() for o in outputs if o.requires_grad).backward()
-    grads = [p.grad for p in module.parameters()]
-    input_grads = [inp.grad for inp in inputs]
-
-    return {
-        'outputs': outputs,
-        'module_grads': grads,
-        'input_grads': input_grads,
-    }
-
-
-# This is distributed because checkpoint() assumes that torch.distributed is initialized.
-# torch.distributed is used with activation partitioning, but not for these simple cases.
-@distributed_test(world_size=1)
-def _test_activation_checkpoint(module, *inputs):
-    # Get rid of dropouts until we fork the RNG between tests.
-    module.eval()
-
-    base = _compute(module, *inputs, do_checkpoint=False)
-    test = _compute(module, *inputs, do_checkpoint=True)
-
-    for group in base.keys():
-        for b, t in zip(base[group], test[group]):
-            # Catch grad `None`s, etc.
-            if not torch.is_tensor(b):
-                assert b == t
-            elif b.is_floating_point():
-                assert torch.allclose(b, t)
-            else:
-                assert torch.equal(b, t)
-
-
-#
-# Helpers
-#
-
-
-class MaskedLinear(torch.nn.Linear):
-    def forward(self, x, mask):
-        out = super().forward(x)
-        return out * mask
-
-
-class MaskedLinearSeq(MaskedLinear):
-    """Tests pipeline modules by also returning the mask."""
-    def forward(self, x, mask):
-        return super().forward(x, mask), mask
-
-
-class MaskedLinearSeqDup(MaskedLinearSeq):
-    """MaskedLinearSeq, but with more outputs than inputs and in a different order."""
-    def forward(self, x, mask):
-        dup = x.clone().detach() * 1000 * mask
-        x, mask = super().forward(x, mask)
-        return dup, x, mask
-
-
-HIDDEN_DIM = 20
-
-
-def _mixed_mask(size=HIDDEN_DIM):
-    entries = torch.randn(size)
-    mask = torch.where(entries > 0, torch.ones(size), torch.zeros(size))
-    mask = mask.bool()
-    return mask
-
-
-def _bool_to_float(btensor, dtype=torch.float32):
-    """Converts a torch.BoolTensor to an equivalent dtype. """
-    ones = torch.ones(size=btensor.size(), dtype=dtype)
-    zeros = torch.zeros(size=btensor.size(), dtype=dtype)
-    return torch.where(btensor, ones, zeros)
-
-
-#
-# Tests
-#
-
-
-def test_ckpt_inputs1_outputs1():
-    module = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
-    inputs = torch.rand(HIDDEN_DIM)
-    inputs.requires_grad = True
-    _test_activation_checkpoint(module, inputs)
-
-
-# both bool and float are important, as bool is not diffentiable
-@pytest.mark.parametrize('mask',
-                         [
-                             _mixed_mask(),
-                             _bool_to_float(_mixed_mask()),
-                         ])
-def test_ckpt_inputs2_outputs1(mask):
-    module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
-    inputs = torch.rand(HIDDEN_DIM)
-    inputs.requires_grad = True
-    _test_activation_checkpoint(module, inputs, mask)
-
-
-@pytest.mark.parametrize('mask',
-                         [
-                             _mixed_mask(),
-                             _bool_to_float(_mixed_mask()),
-                         ])
-def test_ckpt_inputs2_outputs2(mask):
-    module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
-    inputs = torch.rand(HIDDEN_DIM)
-    inputs.requires_grad = True
-    _test_activation_checkpoint(module, inputs, mask)
-
-
-@pytest.mark.parametrize('mask',
-                         [
-                             _mixed_mask(),
-                             _bool_to_float(_mixed_mask()),
-                         ])
-def test_ckpt_inputs2_outputs3(mask):
-    module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
-    inputs = torch.rand(HIDDEN_DIM)
-    inputs.requires_grad = True
-    _test_activation_checkpoint(module, inputs, mask)
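
Illustrative sketch (not part of the patch): the snippet below is a hypothetical
ToyCheckpointFunction that mirrors both halves of the change above. Its forward()
marks non-floating-point outputs (such as a bool attention mask) as
non-differentiable, and its backward() walks ctx.needs_input_grad[1:], skipping the
entry for run_function, to pair the differentiable outputs with the incoming
gradients. It is a simplified stand-in, not DeepSpeed's CheckpointFunction: it
omits RNG forking, activation partitioning, and timers, and it assumes the i-th
output of run_function lines up with the i-th tensor input, as in the masked-linear
cases from the deleted test.

import torch


class ToyCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        ctx.save_for_backward(*args)
        outputs = run_function(*args)
        if torch.is_tensor(outputs):
            outputs = (outputs, )
        # Non-floating-point outputs (indices, masks) cannot carry gradients.
        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
        ctx.mark_non_differentiable(*non_grad_outputs)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        # Recompute the forward pass with grad enabled (the checkpointing idea).
        detached = [
            t.detach().requires_grad_(t.requires_grad) for t in ctx.saved_tensors
        ]
        with torch.enable_grad():
            outputs = ctx.run_function(*detached)
        if torch.is_tensor(outputs):
            outputs = (outputs, )
        # needs_input_grad[0] corresponds to run_function, so skip it; the remaining
        # flags select which outputs are paired with the incoming gradients in args.
        output_tensors = []
        grad_tensors = []
        for idx, need_grad in enumerate(ctx.needs_input_grad[1:]):
            if need_grad:
                output_tensors.append(outputs[idx])
                grad_tensors.append(args[idx])
        torch.autograd.backward(output_tensors, grad_tensors)
        # One gradient per forward() input: None for run_function, then the inputs.
        return (None, ) + tuple(d.grad for d in detached)


# Usage in the spirit of the deleted MaskedLinearSeq test: the bool mask passes
# through untouched and is skipped in backward, while x still receives a gradient.
linear = torch.nn.Linear(20, 20)
x = torch.rand(20, requires_grad=True)
mask = torch.randn(20) > 0
y, m = ToyCheckpointFunction.apply(lambda a, b: (linear(a) * b, b), x, mask)
y.sum().backward()
assert x.grad is not None and m.dtype == torch.bool

Note the positional assumption shared by the sketch and the patch: outputs are
selected by the positions of the forward() inputs (via needs_input_grad[1:]),
whereas the removed zip(outputs, grads) variant filtered on each output's own
requires_grad; the two can differ when outputs and inputs do not line up
one-to-one, as in the deleted MaskedLinearSeqDup case.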