
[BUG] RuntimeError: The size of tensor a (6144) must match the size of tensor b (8192) at non-singleton dimension 0 #3543

Closed
hulonghua-devin opened this issue May 15, 2023 · 3 comments · Fixed by #3563

@hulonghua-devin

Describe the bug
Step 3 (RLHF fine-tuning) fails with the traceback below:

/data/code/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py:523

   520
   521
   522 if __name__ == "__main__":
 ❱ 523     main()
   524

/data/code/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py:430 in main

   427             # prompts = prompts[:, length - args.max_prompt_seq_len:]
   428             # raise ValueError("Prompt length is too long")
   429
 ❱ 430             out = trainer.generate_experience(batch_prompt['prompt'],
   431                                                batch_prompt['prompt_att_mask'])
   432
   433             exp_dataset = exp_mini_dataset.add(out)

/data/code/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py:99 in generate_experience

    96
    97     def generate_experience(self, prompts, mask):
    98         self.eval()
 ❱  99         seq = self._generate_sequence(prompts, mask)
   100         self.train()
   101
   102         pad_token_id = self.tokenizer.pad_token_id

/data/code/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py:74 in _generate_sequence

    71         print(f'prompts: {prompts.shape}, mask: {mask.shape}')
    72         print(f'max_min_length: {max_min_length}')
    73         with torch.no_grad():
 ❱  74             seq = self.actor_model.module.generate(prompts,
    75                                                     attention_mask=mask,
    76                                                     max_length=max_min_length,
    77                                                     min_length=max_min_length)

/root/anaconda3/envs/chinese_vicuna_env/lib/python3.8/site-packages/deepspeed/runtime/hybrid_engine.py:263 in generate

   260                         self.unfuse_lora_weight()
   261         else:
   262             if len(self.all_lora_params) > 0 and (not self.Z3_enabled):
 ❱ 263                 self.fuse_lora_weight()
   264
   265             self.retake_inference_cache()
   266             generate_ret_vals = self._generate(*inputs, **kwargs)

/root/anaconda3/envs/chinese_vicuna_env/lib/python3.8/site-packages/deepspeed/runtime/hybrid_engine.py:141 in fuse_lora_weight

   138
   139     def fuse_lora_weight(self):
   140         for layer_id in range(len(self.layer_params)):
 ❱ 141             self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
   142
   143     def _unfuse_lora(self, params, lora_params):
   144         maybe_has_lora_params = [p for p in params if len(p.shape) > 1]

/root/anaconda3/envs/chinese_vicuna_env/lib/python3.8/site-packages/deepspeed/runtime/hybrid_engine.py:137 in _fuse_lora

   134                 lora_right_weight, \
   135                 lora_left_weight, \
   136                 lora_scaling = lora_param
 ❱ 137                 weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_ri…
   138
   139     def fuse_lora_weight(self):
   140         for layer_id in range(len(self.layer_params)):
RuntimeError: The size of tensor a (6144) must match the size of tensor b (8192) at non-singleton dimension 0

@hulonghua-devin added the bug (Something isn't working) and deepspeed-chat (Related to DeepSpeed-Chat) labels on May 15, 2023
@sxjscience
Contributor

I also ran into this error. get_all_params() returns the compressed qkv matrices, which have shape (6144, 128), but lora_params contains the individual LoRA parameters.

The way to fix it seems to be to implement fuse_lora and unfuse_lora in each of the container policies, e.g., by adding fuse_lora and unfuse_lora methods in https://github.com/microsoft/DeepSpeed/blob/9685eb92ab98ea4534fbfce21e303f396575f7e4/deepspeed/module_inject/containers/opt.py#L19
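For anyone trying to follow the shape argument, here is a minimal standalone sketch of the mismatch. The 6144/8192/128 values are taken from the error message and the comment above, and the LoRA tensor layout is assumed to follow DeepSpeed-Chat's lora_right_weight/lora_left_weight convention, so treat it as illustrative rather than the exact engine code:

    import torch

    # Assumed shapes: the fused/compressed QKV shard returned by get_all_params()
    # versus a LoRA pair that was created for an individual (un-fused) projection.
    weight = torch.zeros(6144, 128)            # fused qkv weight shard
    lora_right_weight = torch.zeros(128, 8)    # (in_features, r)
    lora_left_weight = torch.zeros(8, 8192)    # (r, out_features of the un-fused projection)
    lora_scaling = 1.0

    # Same expression as hybrid_engine._fuse_lora: the delta comes out as (8192, 128)
    # while the weight is (6144, 128), so the broadcasted in-place add raises
    # "The size of tensor a (6144) must match the size of tensor b (8192) at non-singleton dimension 0".
    weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())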

@hulonghua-devin
Author

> I also ran into this error. get_all_params() returns the compressed qkv matrices, which have shape (6144, 128), but lora_params contains the individual LoRA parameters.
>
> The way to fix it seems to be to implement fuse_lora and unfuse_lora in each of the container policies, e.g., by adding fuse_lora and unfuse_lora methods in
> https://github.com/microsoft/DeepSpeed/blob/9685eb92ab98ea4534fbfce21e303f396575f7e4/deepspeed/module_inject/containers/opt.py#L19

I understand the reasons you mentioned, but I still don't grasp how to fix this bug specifically. Could you please explain in more detail or provide the corresponding code modifications? Thank you very much!

@sxjscience
Contributor

I have a branch for the potential fix: master...sxjscience:DeepSpeed:fix_lora_hybrid_engine. However, I'm still verifying it.
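Purely as an illustration of the direction described earlier in the thread (this is not the code in that branch and not DeepSpeed's actual container API; every name below is hypothetical), a per-container fuse_lora could split a row-wise fused QKV weight into its q/k/v slices and add each projection's LoRA delta to its own slice:

    import torch

    def fuse_lora_into_fused_qkv(qkv_weight, lora_pairs):
        """Hypothetical helper: add per-projection LoRA deltas into a row-wise fused
        QKV weight of shape (3 * out_features, in_features). lora_pairs holds one
        (lora_right_weight, lora_left_weight, lora_scaling) tuple for q, k and v,
        using the same (in_features, r) / (r, out_features) layout as in the sketch above."""
        out_features = qkv_weight.shape[0] // 3
        for idx, (right, left, scaling) in enumerate(lora_pairs):
            delta = scaling * torch.matmul(left.t(), right.t())   # (out_features, in_features)
            qkv_weight.data[idx * out_features:(idx + 1) * out_features] += delta

The real fix has to know how each container lays out its fused parameters, which is why it belongs in the container policies rather than in the generic hybrid-engine fuse loop.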
