Handle activation checkpointing args that are None or non-tensors #660

Merged 2 commits on Jan 12, 2021
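
With this change, non-tensor arguments such as a None mask can flow through DeepSpeed's activation checkpointing instead of raising. A minimal usage sketch, not taken from this PR: the MaskedLinear module and the run wrapper are illustrative names, and checkpoint is assumed to be the function defined in the file this PR modifies (per the test file, it expects torch.distributed to be initialized).

import torch
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint


class MaskedLinear(torch.nn.Linear):
    """A layer whose forward takes an optional, possibly-None mask."""

    def forward(self, x, mask=None):
        out = super().forward(x)
        if mask is not None:
            out = out * mask
        return out


def run(module, x, mask):
    # Before this PR every checkpointed arg was assumed to be a tensor,
    # so passing mask=None failed; now non-tensor args are passed through
    # unchanged and receive a None slot in the returned gradients.
    return checkpoint(module, x, mask)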
deepspeed/runtime/activation_checkpointing/checkpointing.py (22 additions & 2 deletions)
@@ -373,6 +373,10 @@ def forward(ctx, run_function, *args):

             inputs = []
             for i, item in enumerate(args[:-1]):
+                if not torch.is_tensor(item):
+                    inputs.append(item)
+                    continue
+
                 partition_size = get_partition_size(item)
                 partition = item.detach().contiguous().view(-1).narrow(
                     0,
@@ -413,7 +417,12 @@ def forward(ctx, run_function, *args):
             inputs.append(args[-1])

         #just in case something funky is happening such as reuse of inputs
-        inputs_cuda = [item.to(cuda_device) for item in args]
+        inputs_cuda = []
+        for item in args:
+            if torch.is_tensor(item):
+                inputs_cuda.append(item.to(cuda_device))
+            else:
+                inputs_cuda.append(item)

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -439,6 +448,10 @@ def forward(ctx, run_function, *args):
         if PARTITION_ACTIVATIONS:
             new_args = []
             for i, (arg, inp) in enumerate(zip(args, inputs)):
+                if not torch.is_tensor(arg):
+                    new_args.append(arg)
+                    continue
+
                 size = torch.tensor(arg.size())

                 arg.data = inp.data
@@ -573,7 +586,14 @@ def backward(ctx, *grads):
             timers.log(['backward'])
         if SYNCHRONIZE:
             torch.cuda.synchronize()
-        return (None, ) + tuple(inp.grad for inp in detached_inputs)
+        ret_list = [None]  # first None for ctx
+        for inp in detached_inputs:
+            if torch.is_tensor(inp):
+                ret_list.append(inp.grad)
+            else:
+                ret_list.append(None)
+
+        return tuple(ret_list)


 def checkpoint(function, *args):
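
The backward change follows the torch.autograd.Function contract: backward must return one value per argument that forward received, and non-tensor (or otherwise non-differentiable) arguments map to None. A standalone sketch of that contract, not part of DeepSpeed's code:

import torch


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        # `factor` is a plain Python float, not a tensor.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward input: a gradient for `x`
        # and None for the non-tensor `factor`.
        return grad_out * ctx.factor, None


x = torch.ones(3, requires_grad=True)
y = Scale.apply(x, 2.0)
y.sum().backward()  # x.grad is populated; the float gets no gradient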
tests/unit/test_activation_checkpointing.py (27 additions & 3 deletions)
@@ -23,7 +23,7 @@ def _compute(module, *inputs, do_checkpoint=False):

     sum(o.sum() for o in outputs if o.requires_grad).backward()
     grads = [p.grad for p in module.parameters()]
-    input_grads = [inp.grad for inp in inputs]
+    input_grads = [inp.grad for inp in inputs if torch.is_tensor(inp)]

     return {
         'outputs': outputs,
@@ -32,6 +32,18 @@ def _compute(module, *inputs, do_checkpoint=False):
     }


+def _prep_inputs(*inputs):
+    _inputs = []
+
+    for inp in inputs:
+        inp = deepcopy(inp)
+        if torch.is_tensor(inp):
+            inp = inp.cuda()
+        _inputs.append(inp)
+
+    return tuple(_inputs)
+
+
 # This is distributed because checkpoint() assumes that torch.distributed is initialized.
 # torch.distributed is used with activation partitioning, but not for these simple cases.
 @distributed_test(world_size=1)
@@ -43,11 +55,11 @@ def _test_activation_checkpoint(module, *inputs):
     module.eval()

     module_ = deepcopy(module)
-    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
+    inputs_ = _prep_inputs(*inputs)
     base = _compute(module_, *inputs_, do_checkpoint=False)

     module_ = deepcopy(module)
-    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
+    inputs_ = _prep_inputs(*inputs)
     test = _compute(module_, *inputs_, do_checkpoint=True)

     for group in base.keys():
@@ -155,3 +167,15 @@ def test_ckpt_inputs2_outputs3(mask):
     inputs = torch.rand(HIDDEN_DIM)
     inputs.requires_grad = True
     _test_activation_checkpoint(module, inputs, mask)
+
+
+class DropMaskLinear(torch.nn.Linear):
+    def forward(self, x, mask):
+        return super().forward(x)
+
+
+def test_ckpt_arg_none():
+    module = DropMaskLinear(HIDDEN_DIM, HIDDEN_DIM)
+    inputs = (torch.rand(HIDDEN_DIM), None)
+    inputs[0].requires_grad = True
+    _test_activation_checkpoint(module, *inputs)