From 44bd538b110ce0e8fc69626854631c3aee0dc094 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 6 Jan 2021 11:03:35 -0800 Subject: [PATCH] Module replacement support (#586) Co-authored-by: Reza Yazdani Co-authored-by: Olatunji Ruwase --- deepspeed/module_inject/__init__.py | 0 deepspeed/module_inject/inject.py | 122 ++++++++++++ deepspeed/module_inject/replace_module.py | 192 +++++++++++++++++++ deepspeed/ops/__init__.py | 3 + deepspeed/ops/module_inject.py | 216 ++++++++++++++++++++++ deepspeed/ops/transformer/transformer.py | 64 +++++-- deepspeed/runtime/utils.py | 19 +- docs/_tutorials/bert-pretraining.md | 4 +- docs/_tutorials/transformer_kernel.md | 4 +- tests/unit/test_cuda_backward.py | 9 +- tests/unit/test_cuda_forward.py | 15 +- 11 files changed, 601 insertions(+), 47 deletions(-) create mode 100755 deepspeed/module_inject/__init__.py create mode 100755 deepspeed/module_inject/inject.py create mode 100755 deepspeed/module_inject/replace_module.py mode change 100644 => 100755 deepspeed/ops/__init__.py create mode 100755 deepspeed/ops/module_inject.py diff --git a/deepspeed/module_inject/__init__.py b/deepspeed/module_inject/__init__.py new file mode 100755 index 000000000000..e69de29bb2d1 diff --git a/deepspeed/module_inject/inject.py b/deepspeed/module_inject/inject.py new file mode 100755 index 000000000000..a601ef10e1d2 --- /dev/null +++ b/deepspeed/module_inject/inject.py @@ -0,0 +1,122 @@ +import copy +import torch +from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig + + +def module_inject(layer_obj, + model, + config, + micro_batch_size, + max_seq_length, + seed, + preln, + fp16=True): + for name, child in model.named_children(): + if isinstance(child, layer_obj): + print('REPLACING BertLayer') + + cuda_config = DeepSpeedTransformerConfig( + batch_size=micro_batch_size, + max_seq_length=max_seq_length, + hidden_size=config.hidden_size, + heads=config.num_attention_heads, + attn_dropout_ratio=config.attention_probs_dropout_prob, + hidden_dropout_ratio=config.hidden_dropout_prob, + num_hidden_layers=config.num_hidden_layers, + initializer_range=config.initializer_range, + seed=seed, + fp16=fp16, + pre_layer_norm=preln) + + new_module = DeepSpeedTransformerLayer(cuda_config) + + # copy relevant state from child -> new module + qw = child.attention.self.query.weight + qb = child.attention.self.query.bias + kw = child.attention.self.key.weight + kb = child.attention.self.key.bias + vw = child.attention.self.value.weight + vb = child.attention.self.value.bias + + qkvw = torch.cat((qw, kw, vw), 0) + qkvb = torch.cat((qb, kb, vb), 0) + + new_module.attn_qkvw.data = qkvw + new_module.attn_qkvb.data = qkvb + new_module.attn_ow.data = child.attention.output.dense.weight + new_module.attn_ob.data = child.attention.output.dense.bias + if preln: + attention_layerNorm = child.PostAttentionLayerNorm + else: + attention_layerNorm = child.attention.output.LayerNorm + new_module.attn_nw.data = attention_layerNorm.weight + new_module.attn_nb.data = attention_layerNorm.bias + if preln: + intermediate_FF = child.intermediate.dense_act + else: + intermediate_FF = child.intermediate.dense + new_module.inter_w.data = intermediate_FF.weight + new_module.inter_b.data = intermediate_FF.bias + new_module.output_w.data = child.output.dense.weight + new_module.output_b.data = child.output.dense.bias + if preln: + transformer_LayerNorm = child.PreAttentionLayerNorm + else: + transformer_LayerNorm = child.output.LayerNorm + new_module.norm_w.data = 
transformer_LayerNorm.weight + new_module.norm_b.data = transformer_LayerNorm.bias + + setattr(model, name, copy.deepcopy(new_module)) + + else: + module_inject(layer_obj, + child, + config, + micro_batch_size, + max_seq_length, + seed, + preln, + fp16) + + return model + + +def test_hi(): + from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN + from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN + from turing.nvidia_modelingpreln import BertLayer + bert_model_config = { + "vocab_size_or_config_json_file": 119547, + "hidden_size": 1024, + "num_hidden_layers": 1, + "num_attention_heads": 16, + "intermediate_size": 4096, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 2, + "initializer_range": 0.02 + } + bert_config = BertConfigPreLN(**bert_model_config) + base_model = BertForQuestionAnsweringPreLN(bert_config, args=None) + + #base_model = LinearStack() + + test_model = copy.deepcopy(base_model) + test_model = module_inject(BertLayer, test_model, bert_config, 4, 384, 1234, preln=True) + + print('BASE', base_model) + print('TEST', test_model) + + #base_model.eval() + #test_model.eval() + + #test_input = torch.rand(1, base_model.input_dim) + + #base_output = base_model(test_input) + #test_output = test_model(test_input) + # + #assert torch.allclose(base_output, test_output, atol=3e-8) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py new file mode 100755 index 000000000000..5274d3c77f84 --- /dev/null +++ b/deepspeed/module_inject/replace_module.py @@ -0,0 +1,192 @@ +import copy +import torch +import deepspeed + + +def replace_transformer_layer(orig_layer_impl, + model, + micro_batch_size, + bert_config, + seed, + max_seq_length, + preln=False, + fp16=True, + huggingface=False, + local_rank=-1): + """ Replace bert-style transformer layers with DeepSpeed's transformer layer + Arguments: + orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, + e.g., transformers.modeling_bert.BertLayer. + model (torch.nn.Module): user's nn.module representing their model + micro_batch_size (int): micro batch size per gpu used during training/eval + bert_config (dict): model config containing hidden size, attention heads, etc. + seed (int): random seed value + max_seq_length (int): max sequence length for training + preln (bool): does the original layer implementation do pre or post layer norm?
+ fp16 (bool): fp16 or fp32 + huggingface (bool): huggingface implementation is unique (supports both encoder/decoder modes) + + Returns: + Updated nn.module with replaced transformer layers + """ + def replace_fn(child): + transformer_config = deepspeed.DeepSpeedTransformerConfig( + batch_size=micro_batch_size, + max_seq_length=max_seq_length, + hidden_size=bert_config.hidden_size, + heads=bert_config.num_attention_heads, + attn_dropout_ratio=bert_config.attention_probs_dropout_prob, + hidden_dropout_ratio=bert_config.hidden_dropout_prob, + num_hidden_layers=bert_config.num_hidden_layers, + initializer_range=bert_config.initializer_range, + seed=seed, + fp16=fp16, + pre_layer_norm=preln, + huggingface=huggingface, + local_rank=local_rank) + new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config) + + # copy relevant state from child -> new module + qw = child.attention.self.query.weight + qb = child.attention.self.query.bias + kw = child.attention.self.key.weight + kb = child.attention.self.key.bias + vw = child.attention.self.value.weight + vb = child.attention.self.value.bias + + qkvw = torch.cat((qw, kw, vw), 0) + qkvb = torch.cat((qb, kb, vb), 0) + + #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, axis=0) + #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, axis=0) + + new_module.attn_qkvw.data = qkvw + new_module.attn_qkvb.data = qkvb + new_module.attn_ow.data = child.attention.output.dense.weight + new_module.attn_ob.data = child.attention.output.dense.bias + if preln: + attention_layernorm = child.PostAttentionLayerNorm + else: + attention_layernorm = child.attention.output.LayerNorm + new_module.attn_nw.data = attention_layernorm.weight + new_module.attn_nb.data = attention_layernorm.bias + if preln: + intermediate_ff = child.intermediate.dense_act + else: + intermediate_ff = child.intermediate.dense + new_module.inter_w.data = intermediate_ff.weight + new_module.inter_b.data = intermediate_ff.bias + new_module.output_w.data = child.output.dense.weight + new_module.output_b.data = child.output.dense.bias + if preln: + transformer_layernorm = child.PreAttentionLayerNorm + else: + transformer_layernorm = child.output.LayerNorm + new_module.norm_w.data = transformer_layernorm.weight + new_module.norm_b.data = transformer_layernorm.bias + return new_module + + return replace_module(model=model, orig_class=orig_layer_impl, replace_fn=replace_fn) + + +def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False): + """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer + Arguments: + orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced, + e.g., transformers.modeling_bert.BertLayer. + model (torch.nn.Module): user's nn.module representing their model + bert_config (dict): model config containing hidden size, attention heads, etc. 
+ + Returns: + Updated nn.module with original bert-style transformer layers + """ + def replace_fn(child): + #from turing.nvidia_modelingpreln import BertLayer + orig_module = orig_layer_impl(bert_config) + + # copy relevant state from child -> original module + qkvw = child.attn_qkvw.data + qkvb = child.attn_qkvb.data + + qw, kw, vw = torch.chunk(qkvw, 3, axis=0) + qb, kb, vb = torch.chunk(qkvb, 3, axis=0) + + orig_module.attention.self.query.weight.data = qw + orig_module.attention.self.query.bias.data = qb + orig_module.attention.self.key.weight.data = kw + orig_module.attention.self.key.bias.data = kb + orig_module.attention.self.value.weight.data = vw + orig_module.attention.self.value.bias.data = vb + + orig_module.attention.output.dense.weight.data = child.attn_ow.data + orig_module.attention.output.dense.bias.data = child.attn_ob.data + + attn_ln_w = child.attn_nw.data + attn_ln_b = child.attn_nb.data + if preln: + orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w + orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b + else: + orig_module.attention.output.LayerNorm.weight.data = attn_ln_w + orig_module.attention.output.LayerNorm.bias.data = attn_ln_b + + inter_ff_w = child.inter_w.data + inter_ff_b = child.inter_b.data + if preln: + orig_module.intermediate.dense_act.weight.data = inter_ff_w + orig_module.intermediate.dense_act.bias.data = inter_ff_b + else: + orig_module.intermediate.dense.weight.data = inter_ff_w + orig_module.intermediate.dense.bias.data = inter_ff_b + + orig_module.output.dense.weight.data = child.output_w.data + orig_module.output.dense.bias.data = child.output_b.data + + transformer_ln_w = child.norm_w.data + transformer_ln_b = child.norm_b.data + if preln: + orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w + orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b + else: + orig_module.output.LayerNorm.weight.data = transformer_ln_w + orig_module.output.LayerNorm.bias.data = transformer_ln_b + return orig_module + + return replace_module(model=model, + orig_class=deepspeed.DeepSpeedTransformerLayer, + replace_fn=replace_fn) + + +def replace_module(model, orig_class, replace_fn): + """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``. + Arguments: + model (torch.nn.Module): the model to augment + orig_class (torch.nn.Module): the module to search for + replace_fn (method): a method to convert instances of ``orig_class`` to the + desired type and return a new instance. + + Returns: + A modified ``model``. + """ + policy = {orig_class: replace_fn} + return _replace_module(model, policy) + + +def _replace_module(model, policies): + """ Traverse model's children recursively and apply any transformations in ``policies``. + Arguments: + model (torch.nn.Module): model to augment + policies (dict): Mapping of source class to replacement function. + + Returns: + Modified ``model``. + """ + for name, child in model.named_children(): + if child.__class__ in policies: + orig = repr(child) + setattr(model, name, policies[child.__class__](child)) + new = getattr(model, name) + else: + _replace_module(child, policies) + + return model diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py old mode 100644 new mode 100755 index 8aec76267ed3..e6fd81fb5a13 --- a/deepspeed/ops/__init__.py +++ b/deepspeed/ops/__init__.py @@ -3,4 +3,7 @@ from . import sparse_attention from . 
import transformer +from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig +from .module_inject import replace_module + from ..git_version_info import compatible_ops as __compatible_ops__ diff --git a/deepspeed/ops/module_inject.py b/deepspeed/ops/module_inject.py new file mode 100755 index 000000000000..6b0d47cb6733 --- /dev/null +++ b/deepspeed/ops/module_inject.py @@ -0,0 +1,216 @@ +import copy +import torch +import deepspeed + +from deepspeed.ops import DeepSpeedTransformerConfig + + +def _copy_child_transformer_state(new_module, orig_child, pre_layer_norm): + # copy relevant state from original child -> new module + qw = orig_child.attention.self.query.weight + qb = orig_child.attention.self.query.bias + kw = orig_child.attention.self.key.weight + kb = orig_child.attention.self.key.bias + vw = orig_child.attention.self.value.weight + vb = orig_child.attention.self.value.bias + + qkvw = torch.cat((qw, kw, vw), 0) + qkvb = torch.cat((qb, kb, vb), 0) + + #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, axis=0) + #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, axis=0) + + new_module.attn_qkvw.data = qkvw + new_module.attn_qkvb.data = qkvb + new_module.attn_ow.data = orig_child.attention.output.dense.weight + new_module.attn_ob.data = orig_child.attention.output.dense.bias + if pre_layer_norm: + attention_layernorm = orig_child.PostAttentionLayerNorm + else: + attention_layernorm = orig_child.attention.output.LayerNorm + new_module.attn_nw.data = attention_layernorm.weight + new_module.attn_nb.data = attention_layernorm.bias + if pre_layer_norm: + intermediate_ff = orig_child.intermediate.dense_act + else: + intermediate_ff = orig_child.intermediate.dense + new_module.inter_w.data = intermediate_ff.weight + new_module.inter_b.data = intermediate_ff.bias + new_module.output_w.data = orig_child.output.dense.weight + new_module.output_b.data = orig_child.output.dense.bias + if pre_layer_norm: + transformer_layernorm = orig_child.PreAttentionLayerNorm + else: + transformer_layernorm = orig_child.output.LayerNorm + new_module.norm_w.data = transformer_layernorm.weight + new_module.norm_b.data = transformer_layernorm.bias + + +def _replace_transformer_layer(orig_layer_impl, model, transformer_config): + """ Replace bert-style transformer layers with DeepSpeed's transformer layer + Arguments: + orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, + e.g., transformers.modeling_bert.BertLayer. + model (torch.nn.Module): user's nn.module representing their model + transformer_config (dict): deepspeed transformer layer config containing hidden size, attention heads, etc. + Returns: + Updated nn.module with replaced transformer layers + """ + def replace_fn(child): + new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config) + _copy_child_transformer_state(new_module, + child, + transformer_config.pre_layer_norm) + + return new_module + + return _replace_module(model=model, + orig_class=orig_layer_impl, + replace_fn=replace_fn) + + +def replace_module(orig_module_impl, model, replacement_module_config): + """ Replace client module + Arguments: + orig_module_impl (torch.nn.Module): original module implementation to replace, + e.g., transformers.modeling_bert.BertLayer. + model (torch.nn.Module): user's nn.module representing their model + replacement_module_config (dict): deepspeed replacement module config (e.g., DeepSpeedTransformerConfig) . 
+ + Returns: + Updated nn.module with replaced modules + """ + assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \ + 'Only DeepSpeedTransformerConfig is currently supported as replacement config' + + return _replace_transformer_layer(orig_layer_impl=orig_module_impl, + model=model, + transformer_config=replacement_module_config) + + +def _revert_transformer_layer(orig_layer_impl, model, bert_config, transformer_config): + """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer + Arguments: + orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced, + e.g., transformers.modeling_bert.BertLayer. + model (torch.nn.Module): user's nn.module representing their model + bert_config (dict): model config containing hidden size, attention heads, etc. + transformer_config (dict): deepspeed transformer config used for replacement + + Returns: + Updated nn.module with original bert-style transformer layers + """ + def replace_fn(child): + #from turing.nvidia_modelingpreln import BertLayer + orig_module = orig_layer_impl(bert_config) + + # copy relevant state from child -> original module + qkvw = child.attn_qkvw.data + qkvb = child.attn_qkvb.data + + qw, kw, vw = torch.chunk(qkvw, 3, axis=0) + qb, kb, vb = torch.chunk(qkvb, 3, axis=0) + + orig_module.attention.self.query.weight.data = qw + orig_module.attention.self.query.bias.data = qb + orig_module.attention.self.key.weight.data = kw + orig_module.attention.self.key.bias.data = kb + orig_module.attention.self.value.weight.data = vw + orig_module.attention.self.value.bias.data = vb + + orig_module.attention.output.dense.weight.data = child.attn_ow.data + orig_module.attention.output.dense.bias.data = child.attn_ob.data + + attn_ln_w = child.attn_nw.data + attn_ln_b = child.attn_nb.data + if transformer_config.pre_layer_norm: + orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w + orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b + else: + orig_module.attention.output.LayerNorm.weight.data = attn_ln_w + orig_module.attention.output.LayerNorm.bias.data = attn_ln_b + + inter_ff_w = child.inter_w.data + inter_ff_b = child.inter_b.data + if transformer_config.pre_layer_norm: + orig_module.intermediate.dense_act.weight.data = inter_ff_w + orig_module.intermediate.dense_act.bias.data = inter_ff_b + else: + orig_module.intermediate.dense.weight.data = inter_ff_w + orig_module.intermediate.dense.bias.data = inter_ff_b + + orig_module.output.dense.weight.data = child.output_w.data + orig_module.output.dense.bias.data = child.output_b.data + + transformer_ln_w = child.norm_w.data + transformer_ln_b = child.norm_b.data + if transformer_config.pre_layer_norm: + orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w + orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b + else: + orig_module.output.LayerNorm.weight.data = transformer_ln_w + orig_module.output.LayerNorm.bias.data = transformer_ln_b + return orig_module + + return _replace_module(model=model, + orig_class=deepspeed.DeepSpeedTransformerLayer, + replace_fn=replace_fn) + + +def revert_module(orig_module_impl, + model, + orig_module_config, + replacement_module_config): + """ Revert DeepSpeed's module back to original client module + Arguments: + orig_module_impl (torch.nn.Module): the original module that was replaced, + e.g., transformers.modeling_bert.BertLayer. 
+ model (torch.nn.Module): user's nn.module representing their model + orig_module_config (dict): original module configuration + replacement_module_config (dict): replacement deepspeed module configuration + + Returns: + Updated nn.module with original bert-style transformer layers + """ + assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \ + 'Only DeepSpeedTransformerConfig is currently supported as replacement config' + + return _revert_transformer_layer(orig_layer_impl=orig_module_impl, + model=model, + bert_config=orig_module_config, + transformer_config=replacement_module_config) + + +def _replace_module(model, orig_class, replace_fn): + """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``. + Arguments: + model (torch.nn.Module): the model to augment + orig_class (torch.nn.Module): the module to search for + replace_fn (method): a method to convert instances of ``orig_class`` to the + desired type and return a new instance. + + Returns: + A modified ``model``. + """ + policy = {orig_class: replace_fn} + return _replace_module_using_policies(model, policy) + + +def _replace_module_using_policies(model, policies): + """ Traverse model's children recursively and apply any transformations in ``policies``. + Arguments: + model (torch.nn.Module): model to augment + policies (dict): Mapping of source class to replacement function. + + Returns: + Modified ``model``. + """ + for name, child in model.named_children(): + if child.__class__ in policies: + orig = repr(child) + setattr(model, name, policies[child.__class__](child)) + new = getattr(model, name) + else: + _replace_module_using_policies(child, policies) + + return model diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index ea4b98848d3c..f0979f2e3f2a 100755 --- a/deepspeed/ops/transformer/transformer.py +++ b/deepspeed/ops/transformer/transformer.py @@ -87,6 +87,10 @@ class DeepSpeedTransformerConfig(TransformerConfig): that by enabling it, the pretraining tasks such as BERT are not affected and can obtain a high accuracy level. On the other hand, for the downstream tasks, such as fine-tuning, we recommend to turn it off in order to be able to reproduce the same result through the regular kernel execution. + + huggingface: Enable when using the HuggingFace interface style for returning the forward results. + + training: Enable for training rather than inference. 
""" def __init__(self, batch_size=-1, @@ -105,7 +109,9 @@ def __init__(self, gelu_checkpoint=False, adjust_init_range=True, attn_dropout_checkpoint=False, - stochastic_mode=False): + stochastic_mode=False, + huggingface=False, + training=True): super(DeepSpeedTransformerConfig, self).__init__( batch_size, @@ -124,10 +130,11 @@ def __init__(self, self.gelu_checkpoint = gelu_checkpoint # True: if higher batch size is required self.adjust_init_range = adjust_init_range self.test_gemm = False - self.training = True + self.training = training self.is_grad_enabled = True self.attn_dropout_checkpoint = attn_dropout_checkpoint self.stochastic_mode = stochastic_mode + self.huggingface = huggingface @classmethod def from_dict(cls, json_object): @@ -252,7 +259,7 @@ def forward(ctx, norm_w.register_hook(lambda x, self=self: grads.append([x, "norm_W"])) norm_b.register_hook(lambda x, self=self: grads.append([x, "norm_B"])) - if config.is_grad_enabled: + if config.is_grad_enabled and config.training: if (config.pre_layer_norm and config.normalize_invertible): ctx.save_for_backward(input_mask, attn_qkvw, @@ -313,7 +320,11 @@ def forward(ctx, if inp_size[1] % 16 != 0: output = torch.narrow(output, 1, 0, inp_size[1]) - return output + + if config.huggingface: + return (output, ) # outputs -> (output) : outputs[0] = output + else: + return output @staticmethod def backward(ctx, grad_output): @@ -412,6 +423,25 @@ def backward(ctx, grad_output): norm_w, norm_b) + # This appears to be an effective way to release context memory + ctx.qkv_tf = None + ctx.soft_inp = None + ctx.ctx_bufB = None + ctx.gelu_inp = None + ctx.ff2_inp = None + ctx.attn_o_inp = None + ctx.ff1_inp = None + ctx.add_res = None + ctx.inp_norm = None + ctx.config = None + ctx.attn_layer_norm_mean = None + ctx.layer_norm_mean = None + ctx.attn_prob_dropout_mask = None + ctx.attn_output_dropout_mask = None + ctx.layer_output_dropout_mask = None + ctx.attn_layer_norm_var = None + ctx.layer_norm_var = None + if grad_output_shape[1] % 16 != 0: grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1]) @@ -438,21 +468,24 @@ def backward(ctx, grad_output): class DeepSpeedTransformerLayer(nn.Module): """Initialize the DeepSpeed Transformer Layer. + Static variable: + layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated, + e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23. Arguments: - layer_id: The layer index starting from 0, e.g. 
if model has 24 transformer layers, - layer_id will be 0,1,2...23 when each layer object is instantiated - config: An object of DeepSpeedTransformerConfig initial_weights: Optional: Only used for unit test initial_biases: Optional: Only used for unit test """ - def __init__(self, layer_id, config, initial_weights=None, initial_biases=None): + layer_id = 0 + + def __init__(self, config, initial_weights=None, initial_biases=None): super(DeepSpeedTransformerLayer, self).__init__() self.config = config - self.config.layer_id = layer_id + self.config.layer_id = DeepSpeedTransformerLayer.layer_id + DeepSpeedTransformerLayer.layer_id = DeepSpeedTransformerLayer.layer_id + 1 print("DeepSpeed Transformer config is ", self.config.__dict__) @@ -548,11 +581,18 @@ def init_transformer_weights(self, adjust_init_range=False): self.norm_w.data.fill_(1.0) self.norm_b.data.zero_() - def forward(self, input, input_mask, grads=None): + def forward(self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=False, + grads=None): self.config.training = self.training self.config.is_grad_enabled = torch.is_grad_enabled() - return DeepSpeedTransformerFunction.apply(input, - input_mask, + return DeepSpeedTransformerFunction.apply(hidden_states, + attention_mask, self, grads, self.config.layer_id, diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 7a7eb77f87fb..05285b328851 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -530,21 +530,10 @@ def see_memory_usage(message): # Print message except when distributed but not rank 0 logger.info(message) logger.info( - "Memory Allocated %s GigaBytes ", - torch.cuda.memory_allocated() / (1024 * 1024 * 1024), - ) - logger.info( - "Max Memory Allocated %s GigaBytes", - torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), - ) - logger.info( - "Cache Allocated %s GigaBytes", - torch.cuda.memory_cached() / (1024 * 1024 * 1024), - ) - logger.info( - "Max cache Allocated %s GigaBytes", - torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), - ) + f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ + Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \ + CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024),2)} GB \ + Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB ") def call_to_str(base, *args, **kwargs): diff --git a/docs/_tutorials/bert-pretraining.md b/docs/_tutorials/bert-pretraining.md index 03462e893b07..0791fb3308fe 100755 --- a/docs/_tutorials/bert-pretraining.md +++ b/docs/_tutorials/bert-pretraining.md @@ -284,10 +284,10 @@ transformer layers using DeepSpeed transformer kernel as below. gelu_checkpoint=args.gelu_checkpoint, stochastic_mode=True) - self.layer = nn.ModuleList([copy.deepcopy(DeepSpeedTransformerLayer(i, cuda_config)) for i in range(config.num_hidden_layers)]) + layer = DeepSpeedTransformerLayer(cuda_config) else: layer = BertLayer(config) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) ``` All configuration settings come from the DeepSpeed configuration file and command arguments and thus we must pass the `args` variable to here in this model. 
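The bert-pretraining hunk above (and the transformer_kernel hunk below) drop the explicit layer index: `DeepSpeedTransformerLayer` now assigns `config.layer_id` from a class-level counter that advances on every construction. A minimal construction sketch; the hyperparameters are illustrative assumptions, and a CUDA build of the DeepSpeed transformer kernel is required.

```python
import torch.nn as nn
from deepspeed.ops.transformer import DeepSpeedTransformerConfig, DeepSpeedTransformerLayer

num_layers = 12  # illustrative value

cuda_config = DeepSpeedTransformerConfig(batch_size=8,
                                         max_seq_length=128,
                                         hidden_size=768,
                                         heads=12,
                                         attn_dropout_ratio=0.1,
                                         hidden_dropout_ratio=0.1,
                                         num_hidden_layers=num_layers,
                                         initializer_range=0.02,
                                         seed=1234,
                                         fp16=False,
                                         pre_layer_norm=True)

# No layer index is passed anymore; each instantiation advances the
# class-level DeepSpeedTransformerLayer.layer_id counter internally.
layers = nn.ModuleList(
    [DeepSpeedTransformerLayer(cuda_config) for _ in range(num_layers)])
```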
diff --git a/docs/_tutorials/transformer_kernel.md b/docs/_tutorials/transformer_kernel.md index 26e88406920e..9dbcf26e2a12 100755 --- a/docs/_tutorials/transformer_kernel.md +++ b/docs/_tutorials/transformer_kernel.md @@ -43,8 +43,8 @@ config = DeepSpeedTransformerConfig(batch_size = 64, normalize_invertible=False, gelu_checkpoint=False) self.layer = nn.ModuleList([ - copy.deepcopy(DeepSpeedTransformerLayer(i, cuda_config)) - for i in range(config.num_hidden_layers) + copy.deepcopy(DeepSpeedTransformerLayer(cuda_config)) + for _ in range(config.num_hidden_layers) ]) ``` ### Transformer kernel Parameters diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py index fd3f9887ad42..eca853abf569 100755 --- a/tests/unit/test_cuda_backward.py +++ b/tests/unit/test_cuda_backward.py @@ -83,11 +83,10 @@ def __init__(self, config, weights, biases): super(DSEncoder, self).__init__() self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.layer = nn.ModuleList([ - copy.deepcopy(DeepSpeedTransformerLayer(i, - config, + copy.deepcopy(DeepSpeedTransformerLayer(config, weights, biases)) - for i in range(config.num_hidden_layers) + for _ in range(config.num_hidden_layers) ]) self.grads = [] self.pre_or_post = config.pre_layer_norm @@ -122,7 +121,9 @@ def custom_forward(*inputs): # decoder layers else: for i, layer_module in enumerate(self.layer): - hidden_states = layer_module(hidden_states, attention_mask, self.grads) + hidden_states = layer_module(hidden_states, + attention_mask, + grads=self.grads) hidden_states.register_hook( lambda x, self=self: self.grads.append([x, diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py index 88cb90848603..5add5e152a91 100755 --- a/tests/unit/test_cuda_forward.py +++ b/tests/unit/test_cuda_forward.py @@ -48,11 +48,10 @@ def __init__(self, config, weights, biases): super(DSEncoder, self).__init__() self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.layer = nn.ModuleList([ - copy.deepcopy(DeepSpeedTransformerLayer(i, - config, + copy.deepcopy(DeepSpeedTransformerLayer(config, weights, biases)) - for i in range(config.num_hidden_layers) + for _ in range(config.num_hidden_layers) ]) self.grads = [] self.pre_or_post = config.pre_layer_norm @@ -88,11 +87,6 @@ def custom_forward(*inputs): else: for i, layer_module in enumerate(self.layer): hidden_states = layer_module(hidden_states, attention_mask) - hidden_states.register_hook( - lambda x, - i=i, - self=self: self.grads.append([x, - "hidden_state"])) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) @@ -103,9 +97,6 @@ def custom_forward(*inputs): all_encoder_layers.append(hidden_states) return all_encoder_layers - def get_grads(self): - return self.grads - def create_models(ds_config): bert_config = BertConfig(vocab_size_or_config_json_file=119547, @@ -201,7 +192,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): output_all_encoded_layers=False, checkpoint_activations=False) - # check grads + # check forward evaluation check_equal(base_results, ds_results, atol=atol, verbose=verbose)
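A usage sketch of the replacement helpers added in `deepspeed/module_inject/replace_module.py`. The HuggingFace import path follows the `transformers.modeling_bert.BertLayer` example cited in the docstrings above (it matches transformers releases contemporary with this patch); the model choice and hyperparameters are illustrative assumptions, and a CUDA build of the DeepSpeed transformer kernel is assumed.

```python
from transformers import BertConfig, BertForSequenceClassification
from transformers.modeling_bert import BertLayer  # import path used in the docstrings above

from deepspeed.module_inject.replace_module import (replace_transformer_layer,
                                                    revert_transformer_layer)

bert_config = BertConfig()  # hidden size, heads, dropout, etc. are read from this object
model = BertForSequenceClassification(bert_config)

# Swap every BertLayer for a fused DeepSpeedTransformerLayer, copying the QKV, attention
# output, intermediate/output FF and LayerNorm weights into the kernel's parameters.
# huggingface=True makes the kernel return a tuple, matching the HuggingFace layer API.
model = replace_transformer_layer(orig_layer_impl=BertLayer,
                                  model=model,
                                  micro_batch_size=8,
                                  bert_config=bert_config,
                                  seed=1234,
                                  max_seq_length=128,
                                  preln=False,
                                  fp16=False,
                                  huggingface=True)

# ... train or evaluate with the fused kernels ...

# Rebuild stock BertLayer modules from the kernel's parameters, e.g. to export a
# checkpoint loadable without DeepSpeed.
model = revert_transformer_layer(orig_layer_impl=BertLayer,
                                 model=model,
                                 bert_config=bert_config,
                                 preln=False)
```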
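The same replacement can also be driven through the config-based wrapper exported from `deepspeed.ops` (see the additions to `deepspeed/ops/__init__.py` and `deepspeed/ops/module_inject.py`); only `DeepSpeedTransformerConfig` is accepted as a replacement config for now. Again a sketch under the same illustrative assumptions as above.

```python
from transformers import BertConfig, BertForSequenceClassification
from transformers.modeling_bert import BertLayer

from deepspeed.ops import DeepSpeedTransformerConfig, replace_module

bert_config = BertConfig()
model = BertForSequenceClassification(bert_config)

ds_config = DeepSpeedTransformerConfig(batch_size=8,
                                       max_seq_length=128,
                                       hidden_size=bert_config.hidden_size,
                                       heads=bert_config.num_attention_heads,
                                       attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
                                       hidden_dropout_ratio=bert_config.hidden_dropout_prob,
                                       num_hidden_layers=bert_config.num_hidden_layers,
                                       initializer_range=bert_config.initializer_range,
                                       seed=1234,
                                       fp16=False,
                                       pre_layer_norm=False,
                                       huggingface=True)

# Dispatches to the transformer-layer replacement and copies each child's weights over.
model = replace_module(orig_module_impl=BertLayer,
                       model=model,
                       replacement_module_config=ds_config)
```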