Commit

activation checkpointing for transformer engine with deepspeed (bigsc…
jomayeri authored Sep 22, 2023
1 parent 58737b3 commit b4fcd6c
Showing 1 changed file with 8 additions and 3 deletions.
megatron/model/transformer.py (8 additions, 3 deletions)
@@ -10,7 +10,7 @@

 from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
 from .module import MegatronModule
-from megatron.core import parallel_state, tensor_parallel
+from megatron.core import parallel_state, tensor_parallel, mpu
 from megatron.core.enums import ModelType
 from megatron.model import LayerNorm
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
@@ -1746,12 +1746,17 @@ def custom_forward(*args, **kwargs):
                 moe_losses = []
                 for index in range(start, end):
                     layer = self._get_layer(index)
-                    x_, moe_loss = layer(x_, *args, **kwargs)
+                    output = layer(x_, *args, **kwargs)
+                    if isinstance(output, tuple):
+                        x_, moe_loss = output
+                    else:
+                        x_ = output
+                        moe_loss = torch.tensor(0.0, device=x_.device, dtype=x_.dtype, requires_grad=True)
                     moe_losses.append(moe_loss)
                 return (x_, *moe_losses)
             return custom_forward

-        if args.deepspeed:
+        if args.deepspeed and args.deepspeed_activation_checkpointing:
             moe_losses = []
             # Make sure memory is freed.
             tensor_parallel.reset_checkpointed_activations_memory_buffer()
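The first change is the isinstance guard inside custom_forward: the checkpointed wrapper previously assumed every layer returns a (hidden_states, moe_loss) tuple, which does not hold for layers that return a bare tensor, as the Transformer Engine path named in the commit title apparently does. A minimal standalone sketch of the same normalization pattern; normalize_layer_output is an illustrative helper name, not code from the commit:

import torch

def normalize_layer_output(output):
    """Return (hidden_states, moe_loss) regardless of which form the layer produced."""
    if isinstance(output, tuple):
        hidden_states, moe_loss = output
    else:
        hidden_states = output
        # A zero placeholder keeps the tuple signature the checkpointing
        # wrapper expects when the layer reports no MoE loss.
        moe_loss = torch.tensor(0.0, device=hidden_states.device,
                                dtype=hidden_states.dtype, requires_grad=True)
    return hidden_states, moe_loss

x = torch.randn(4, 8)
normalize_layer_output(x)                        # bare tensor -> zero placeholder loss
normalize_layer_output((x, torch.tensor(0.5)))   # MoE-style tuple passes through unchanged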

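The second change narrows the guard so the checkpointed-activation memory buffer is reset only when DeepSpeed activation checkpointing is actually enabled (args.deepspeed_activation_checkpointing), not whenever DeepSpeed itself is. For context, a minimal sketch of how that mode is typically switched on by handing the model-parallel state to DeepSpeed's checkpointing module; the wrapper function and the particular knob values are illustrative assumptions, while deepspeed.checkpointing.configure and its keyword arguments are DeepSpeed's public API:

import deepspeed
from megatron.core import mpu  # model-parallel state; this commit adds mpu to the imports

def configure_deepspeed_activation_checkpointing(num_checkpoints):
    """Illustrative wrapper: route activation checkpointing through DeepSpeed."""
    deepspeed.checkpointing.configure(
        mpu,
        partition_activations=True,
        contiguous_checkpointing=False,
        num_checkpoints=num_checkpoints,
        checkpoint_in_cpu=False,
        synchronize=False,
        profile=False,
    )

The same options can instead come from the activation_checkpointing section of a DeepSpeed config file passed through the deepspeed_config argument; once configured, checkpointed regions are typically wrapped with deepspeed.checkpointing.checkpoint rather than Megatron's own tensor_parallel.checkpoint.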