Skip to content

Commit

Permalink
Fix gpt-j inference issue (#3639)
Browse files Browse the repository at this point in the history
* fix gpt-j inference issue for mlp_gemm_func call

* bring back the gpt-j inference-test

* fix formatting

* fix the neox and pythia injection issue
  • Loading branch information
RezaYazdaniAminabadi authored Jun 7, 2023
1 parent 7e59ef1 commit 34a9fbf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
24 changes: 14 additions & 10 deletions deepspeed/module_inject/containers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,21 @@ def mlp_output_mp(self, mp_replace, reversed_dim=False):
allocate_tensor=reversed_dim)

def copy_data_to_new_module(self):
    """Copy the container's layer-norm weights onto the newly injected module.

    Sets ``attn_nw``/``attn_nb`` on ``self.module.mlp`` and ``norm_w``/``norm_b``
    on ``self.module``.  Each source tensor is moved to the current accelerator
    device and wrapped in a ``torch.nn.parameter.Parameter``; a ``None`` source
    is assigned through as ``None`` so models without that norm (e.g. GPT-J
    style parallel blocks) keep the attribute explicitly unset rather than
    crashing on a ``.data.copy_`` of a missing tensor.
    """

    def _assign(target, tensors):
        # One loop for both destinations: setattr a fresh Parameter (on the
        # accelerator device) or propagate None when the source is absent.
        for name, tensor in tensors.items():
            if tensor is None:
                setattr(target, name, None)
            else:
                setattr(
                    target, name,
                    torch.nn.parameter.Parameter(tensor.to(get_accelerator().current_device_name())))

    _assign(self.module.mlp, {'attn_nw': self.attn_nw, 'attn_nb': self.attn_nb})
    _assign(self.module, {'norm_w': self.input_nw, 'norm_b': self.input_nb})

def transpose(self):
self.transpose_attention()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/inference/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"gpt2",
"distilgpt2",
"Norod78/hebrew-bad_wiki-gpt_neo-tiny",
#"EleutherAI/gpt-j-6B", # Removed as this is causing OOM errors randomly
"EleutherAI/gpt-j-6B", # bring back this model as we did not catch an error before by merging some changes! TODO: we need to fix the OOM issue later!
"bigscience/bloom-560m",
]
_opt_models = [
Expand Down

0 comments on commit 34a9fbf

Please sign in to comment.