forked from microsoft/DeepSpeed
Module replacement support (microsoft#586)
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 5ab1279 · commit 44bd538
Showing 11 changed files with 601 additions and 47 deletions.
@@ -0,0 +1,122 @@
import copy
import torch
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


def module_inject(layer_obj,
                  model,
                  config,
                  micro_batch_size,
                  max_seq_length,
                  seed,
                  preln,
                  fp16=True):
    for name, child in model.named_children():
        if isinstance(child, layer_obj):
            print('REPLACING BertLayer')

            cuda_config = DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                max_seq_length=max_seq_length,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=preln)

            new_module = DeepSpeedTransformerLayer(cuda_config)

            # Copy relevant state from child -> new module: the separate
            # query/key/value projections are fused into one QKV matrix by
            # concatenating along the output dimension (dim 0).
            qw = child.attention.self.query.weight
            qb = child.attention.self.query.bias
            kw = child.attention.self.key.weight
            kb = child.attention.self.key.bias
            vw = child.attention.self.value.weight
            vb = child.attention.self.value.bias

            qkvw = torch.cat((qw, kw, vw), 0)
            qkvb = torch.cat((qb, kb, vb), 0)

            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = child.attention.output.dense.weight
            new_module.attn_ob.data = child.attention.output.dense.bias
            if preln:
                attention_layernorm = child.PostAttentionLayerNorm
            else:
                attention_layernorm = child.attention.output.LayerNorm
            new_module.attn_nw.data = attention_layernorm.weight
            new_module.attn_nb.data = attention_layernorm.bias
            if preln:
                intermediate_ff = child.intermediate.dense_act
            else:
                intermediate_ff = child.intermediate.dense
            new_module.inter_w.data = intermediate_ff.weight
            new_module.inter_b.data = intermediate_ff.bias
            new_module.output_w.data = child.output.dense.weight
            new_module.output_b.data = child.output.dense.bias
            if preln:
                transformer_layernorm = child.PreAttentionLayerNorm
            else:
                transformer_layernorm = child.output.LayerNorm
            new_module.norm_w.data = transformer_layernorm.weight
            new_module.norm_b.data = transformer_layernorm.bias

            setattr(model, name, copy.deepcopy(new_module))

        else:
            # Recurse into submodules until a matching layer is found.
            module_inject(layer_obj,
                          child,
                          config,
                          micro_batch_size,
                          max_seq_length,
                          seed,
                          preln,
                          fp16)

    return model


def test_hi():
    from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN
    from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN
    from turing.nvidia_modelingpreln import BertLayer
    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 1,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02
    }
    bert_config = BertConfigPreLN(**bert_model_config)
    base_model = BertForQuestionAnsweringPreLN(bert_config, args=None)

    #base_model = LinearStack()

    test_model = copy.deepcopy(base_model)
    # The model under test is pre-LN, so preln=True must be passed.
    test_model = module_inject(BertLayer,
                               test_model,
                               bert_config,
                               micro_batch_size=4,
                               max_seq_length=384,
                               seed=1234,
                               preln=True)

    print('BASE', base_model)
    print('TEST', test_model)

    #base_model.eval()
    #test_model.eval()

    #test_input = torch.rand(1, base_model.input_dim)

    #base_output = base_model(test_input)
    #test_output = test_model(test_input)
    #
    #assert torch.allclose(base_output, test_output, atol=3e-8)
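The test above exercises module_inject end to end, but the core assumption it relies on is the q/k/v fusion: concatenating the three projection matrices along dim 0 behaves exactly like applying the three projections separately. A minimal standalone check of that identity, independent of DeepSpeed (the sizes here are arbitrary):

import torch
import torch.nn.functional as F

hidden = 8
q = torch.nn.Linear(hidden, hidden)
k = torch.nn.Linear(hidden, hidden)
v = torch.nn.Linear(hidden, hidden)

# Fuse the three projections the same way module_inject does.
qkvw = torch.cat((q.weight, k.weight, v.weight), 0)  # [3*hidden, hidden]
qkvb = torch.cat((q.bias, k.bias, v.bias), 0)        # [3*hidden]

x = torch.randn(2, hidden)
fused = F.linear(x, qkvw, qkvb)            # one matmul covers q, k and v
qo, ko, vo = torch.chunk(fused, 3, dim=1)  # split back into three outputs

assert torch.allclose(qo, q(x), atol=1e-6)
assert torch.allclose(ko, k(x), atol=1e-6)
assert torch.allclose(vo, v(x), atol=1e-6)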
@@ -0,0 +1,192 @@
import copy
import torch
import deepspeed


def replace_transformer_layer(orig_layer_impl,
                              model,
                              micro_batch_size,
                              bert_config,
                              seed,
                              max_seq_length,
                              preln=False,
                              fp16=True,
                              huggingface=False,
                              local_rank=-1):
    """ Replace BERT-style transformer layers with DeepSpeed's transformer layer.
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        micro_batch_size (int): micro batch size per GPU used during training/eval
        bert_config (dict): model config containing hidden size, attention heads, etc.
        seed (int): random seed value
        max_seq_length (int): max sequence length for training
        preln (bool): whether the original layer implementation uses pre- or post-layer-norm
        fp16 (bool): fp16 or fp32
        huggingface (bool): the HuggingFace implementation is unique (supports both encoder/decoder modes)
    Returns:
        Updated nn.Module with replaced transformer layers
    """
    def replace_fn(child):
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            max_seq_length=max_seq_length,
            hidden_size=bert_config.hidden_size,
            heads=bert_config.num_attention_heads,
            attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
            hidden_dropout_ratio=bert_config.hidden_dropout_prob,
            num_hidden_layers=bert_config.num_hidden_layers,
            initializer_range=bert_config.initializer_range,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=preln,
            huggingface=huggingface,
            local_rank=local_rank)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        # Copy relevant state from child -> new module: fuse the separate
        # query/key/value projections into one QKV matrix along dim 0.
        qw = child.attention.self.query.weight
        qb = child.attention.self.query.bias
        kw = child.attention.self.key.weight
        kb = child.attention.self.key.bias
        vw = child.attention.self.value.weight
        vb = child.attention.self.value.bias

        qkvw = torch.cat((qw, kw, vw), 0)
        qkvb = torch.cat((qb, kb, vb), 0)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = child.attention.output.dense.weight
        new_module.attn_ob.data = child.attention.output.dense.bias
        if preln:
            attention_layernorm = child.PostAttentionLayerNorm
        else:
            attention_layernorm = child.attention.output.LayerNorm
        new_module.attn_nw.data = attention_layernorm.weight
        new_module.attn_nb.data = attention_layernorm.bias
        if preln:
            intermediate_ff = child.intermediate.dense_act
        else:
            intermediate_ff = child.intermediate.dense
        new_module.inter_w.data = intermediate_ff.weight
        new_module.inter_b.data = intermediate_ff.bias
        new_module.output_w.data = child.output.dense.weight
        new_module.output_b.data = child.output.dense.bias
        if preln:
            transformer_layernorm = child.PreAttentionLayerNorm
        else:
            transformer_layernorm = child.output.LayerNorm
        new_module.norm_w.data = transformer_layernorm.weight
        new_module.norm_b.data = transformer_layernorm.bias
        return new_module

    return replace_module(model=model, orig_class=orig_layer_impl, replace_fn=replace_fn)
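
# Usage sketch (illustrative, not part of this commit): for a HuggingFace-style
# BERT whose encoder blocks are transformers.modeling_bert.BertLayer (post-LN),
# with a config exposing hidden_size, num_attention_heads, the dropout ratios,
# num_hidden_layers and initializer_range:
#
#   from transformers.modeling_bert import BertLayer
#   model = replace_transformer_layer(BertLayer,
#                                     model,
#                                     micro_batch_size=8,
#                                     bert_config=config,
#                                     seed=42,
#                                     max_seq_length=128,
#                                     preln=False,
#                                     huggingface=True)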

def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False):
    """ Revert DeepSpeed's transformer layer back to the original BERT-style transformer layer.
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        bert_config (dict): model config containing hidden size, attention heads, etc.
        preln (bool): whether the original layer implementation uses pre- or post-layer-norm
    Returns:
        Updated nn.Module with original BERT-style transformer layers
    """
    def replace_fn(child):
        orig_module = orig_layer_impl(bert_config)

        # Copy relevant state from child -> original module: split the fused
        # QKV matrix back into separate query/key/value projections.
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn)
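
# Round-trip sketch (illustrative, not part of this commit): replacing and
# then reverting should reproduce the original weights, since both directions
# copy the same tensors. `model`, `bert_config` and `BertLayer` are assumed
# from the surrounding training script:
#
#   base = copy.deepcopy(model)
#   ds_model = replace_transformer_layer(BertLayer, model, 4, bert_config,
#                                        1234, 384, preln=True)
#   reverted = revert_transformer_layer(BertLayer, ds_model, bert_config,
#                                       preln=True)
#   for key, tensor in base.state_dict().items():
#       assert torch.allclose(tensor, reverted.state_dict()[key])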

def replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module class to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type and return a new instance.
    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module(model, policy)


def _replace_module(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.
    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            # Build the replacement and swap it in on the parent module.
            setattr(model, name, policies[child.__class__](child))
        else:
            _replace_module(child, policies)

    return model
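For reference, the policy-dict mechanism above is not BERT-specific; the same traversal can swap any module class for any replacement. A minimal self-contained sketch using stock torch.nn modules (the ReLU-to-GELU policy is just an example, not something this commit does):

import torch.nn as nn


def _replace(model, policies):
    # Same traversal as _replace_module: swap any child whose class appears
    # in the policy dict, otherwise recurse into its children.
    for name, child in model.named_children():
        if child.__class__ in policies:
            setattr(model, name, policies[child.__class__](child))
        else:
            _replace(child, policies)
    return model


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
model = _replace(model, {nn.ReLU: lambda old: nn.GELU()})
print(model)  # the ReLU is now a GELU; the Linear layers are untouched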