I ran into this error as well. Actually, get_all_params() returns the compressed QKV matrices, which have shape (6144, 128), whereas lora_params contains the individual LoRA parameters. See the following:
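To make the mismatch concrete, here is a minimal sketch of what goes wrong inside the hybrid engine's fuse step. The shapes and the LoRA rank are assumptions taken from the error message below, not values read from the real model:

```python
import torch

# Assumed shapes for illustration only: 6144 rows for the compressed QKV weight
# (as returned by get_all_params()), 8192 rows for the LoRA delta, rank 8.
rank = 8
weight = torch.zeros(6144, 128)               # compressed QKV matrix
lora_left_weight = torch.zeros(rank, 8192)    # LoRA factors sized for a different projection
lora_right_weight = torch.zeros(rank, 128)
lora_scaling = 1.0

# Same expression as in hybrid_engine._fuse_lora: the delta has shape (8192, 128).
delta = lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight)

try:
    weight.data += delta
except RuntimeError as e:
    print(e)  # The size of tensor a (6144) must match the size of tensor b (8192)
              # at non-singleton dimension 0
```

Because the compressed QKV weight and the LoRA product disagree in dimension 0, the in-place add in _fuse_lora raises exactly the RuntimeError shown in the traceback below.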
I understand the cause you describe, but I still don't see how to fix this specifically. Could you explain in more detail or share the corresponding code changes? Thank you very much!
Describe the bug
Step 3 (RLHF fine-tuning) crashes during experience generation: fusing the LoRA weights in DeepSpeed's hybrid engine fails with a tensor size mismatch.
/data/code/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py:523 in <module>

    520
    521
    522   if __name__ == "__main__":
❱   523       main()
    524

/data/code/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py:430 in main

    427               # prompts = prompts[:, length - args.max_prompt_seq_len:]
    428               # raise ValueError("Prompt length is too long")
    429
❱   430               out = trainer.generate_experience(batch_prompt['prompt'],
    431                                                  batch_prompt['prompt_att_mask'])
    432
    433               exp_dataset = exp_mini_dataset.add(out)

/data/code/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py:99 in generate_experience

     96
     97   def generate_experience(self, prompts, mask):
     98       self.eval()
❱    99       seq = self._generate_sequence(prompts, mask)
    100       self.train()
    101
    102       pad_token_id = self.tokenizer.pad_token_id

/data/code/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py:74 in _generate_sequence

     71       print(f'prompts: {prompts.shape}, mask: {mask.shape}')
     72       print(f'max_min_length: {max_min_length}')
     73       with torch.no_grad():
❱    74           seq = self.actor_model.module.generate(prompts,
     75                                                  attention_mask=mask,
     76                                                  max_length=max_min_length,
     77                                                  min_length=max_min_length)

/root/anaconda3/envs/chinese_vicuna_env/lib/python3.8/site-packages/deepspeed/runtime/hybrid_engine.py:263 in generate

    260                       self.unfuse_lora_weight()
    261           else:
    262               if len(self.all_lora_params) > 0 and (not self.Z3_enabled):
❱   263                   self.fuse_lora_weight()
    264
    265               self.retake_inference_cache()
    266               generate_ret_vals = self._generate(*inputs, **kwargs)

/root/anaconda3/envs/chinese_vicuna_env/lib/python3.8/site-packages/deepspeed/runtime/hybrid_engine.py:141 in fuse_lora_weight

    138
    139   def fuse_lora_weight(self):
    140       for layer_id in range(len(self.layer_params)):
❱   141           self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
    142
    143   def _unfuse_lora(self, params, lora_params):
    144       maybe_has_lora_params = [p for p in params if len(p.shape) > 1]

/root/anaconda3/envs/chinese_vicuna_env/lib/python3.8/site-packages/deepspeed/runtime/hybrid_engine.py:137 in _fuse_lora

    134                   lora_right_weight, \
    135                   lora_left_weight, \
    136                   lora_scaling = lora_param
❱   137                   weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight)
    138
    139   def fuse_lora_weight(self):
    140       for layer_id in range(len(self.layer_params)):

RuntimeError: The size of tensor a (6144) must match the size of tensor b (8192) at non-singleton dimension 0
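To locate which layers are affected before generate() is called, one could mirror the fuse_lora_weight() loop from the traceback and only print shapes. This is a hypothetical diagnostic helper, not a fix and not part of DeepSpeed; the `engine` argument and the layout of lora_params (a 3-tuple of right weight, left weight, scaling per 2-D parameter) are assumptions based on the code excerpts above:

```python
def report_fuse_shapes(engine):
    """Print, per layer, the weight shape and the LoRA delta shape that
    _fuse_lora would try to add. `engine` is assumed to be the hybrid-engine
    object from the traceback (the one exposing layer_params and lora_params)."""
    for layer_id, (params, lora_params) in enumerate(
            zip(engine.layer_params, engine.lora_params)):
        # Same filter as _unfuse_lora above: only 2-D parameters can carry LoRA.
        weights = [p for p in params if len(p.shape) > 1]
        for weight, lora_param in zip(weights, lora_params):
            if len(lora_param) != 3:
                continue  # assumption: no LoRA attached to this weight
            lora_right_weight, lora_left_weight, _ = lora_param  # scaling not needed for a shape check
            delta_shape = (lora_left_weight.shape[1], lora_right_weight.shape[1])
            status = "ok" if delta_shape == tuple(weight.shape) else "MISMATCH"
            print(f"layer {layer_id}: weight {tuple(weight.shape)}, "
                  f"lora delta {delta_shape} -> {status}")
```

Calling this on the actor's hybrid engine right before trainer.generate_experience() should show which compressed weights (6144 rows here) disagree with their LoRA factors (8192 rows).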