diff --git a/deepspeed/module_inject/containers/base.py b/deepspeed/module_inject/containers/base.py
index a0dc3e98de35..a520664793ca 100644
--- a/deepspeed/module_inject/containers/base.py
+++ b/deepspeed/module_inject/containers/base.py
@@ -249,17 +249,21 @@ def mlp_output_mp(self, mp_replace, reversed_dim=False):
                                                    allocate_tensor=reversed_dim)
 
     def copy_data_to_new_module(self):
-        params = {
-            self.module.mlp.attn_nw: self.attn_nw,
-            self.module.mlp.attn_nb: self.attn_nb,
-            self.module.norm_w: self.input_nw,
-            self.module.norm_b: self.input_nb
-        }
-        for dst, src in params.items():
-            if src is None:
-                dst = src
+        params = {'attn_nw': self.attn_nw, 'attn_nb': self.attn_nb}
+        for key in params:
+            if params[key] is None:
+                setattr(self.module.mlp, key, None)
             else:
-                dst.data.copy_(src.to(get_accelerator().current_device_name()))
+                setattr(self.module.mlp, key,
+                        torch.nn.parameter.Parameter(params[key].to(get_accelerator().current_device_name())))
+
+        params = {'norm_w': self.input_nw, 'norm_b': self.input_nb}
+        for key in params:
+            if params[key] is None:
+                setattr(self.module, key, None)
+            else:
+                setattr(self.module, key,
+                        torch.nn.parameter.Parameter(params[key].to(get_accelerator().current_device_name())))
 
     def transpose(self):
         self.transpose_attention()
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 65365837a0b8..c42deb3dd6d7 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -49,7 +49,7 @@
     "gpt2",
     "distilgpt2",
     "Norod78/hebrew-bad_wiki-gpt_neo-tiny",
-    #"EleutherAI/gpt-j-6B", # Removed as this is causing OOM errors randomly
+    "EleutherAI/gpt-j-6B",  # bring back this model as we did not catch an error before by merging some changes! TODO: we need to fix the OOM issue later!
     "bigscience/bloom-560m",
 ]
 _opt_models = [
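
Note on the base.py hunk: the old loop's `dst = src` only rebound a local name, so a None source never actually cleared the corresponding attribute on the module; the rewrite uses setattr so both the None case and the copy case update the module itself. A minimal standalone sketch of that distinction follows (ToyModule and its norm_w attribute are hypothetical, not DeepSpeed code):

    import torch

    class ToyModule:
        def __init__(self):
            self.norm_w = torch.nn.parameter.Parameter(torch.ones(4))

    m = ToyModule()

    # Old pattern: rebinding the loop-local name leaves m.norm_w untouched.
    dst, src = m.norm_w, None
    if src is None:
        dst = src
    print(m.norm_w is None)  # False -- the attribute was never cleared

    # New pattern: setattr replaces the attribute on the module itself.
    if src is None:
        setattr(m, 'norm_w', None)
    print(m.norm_w is None)  # True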