Should base model be dequantized when merging LoRA weights with base model? #254

Open
jinyongyoo opened this issue Aug 24, 2023 · 6 comments

@jinyongyoo

Hi, I have a question about merging LoRA weights into a quantized base model. When we want to merge the LoRA weights back into the original model for inference, we can use the merge_and_unload method. However, this isn't possible when the base model is quantized (as seen in #28).

A common workaround I've seen is to load the base model without quantization and then merge the LoRA weights into it. But shouldn't this cause a mismatch between training and inference? The LoRA weights were trained against the quantized model, which differs from the unquantized base model because quantization is inherently noisy. I'm wondering whether this workaround degrades the performance of the final model.
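To make the mismatch concrete, here is a minimal sketch using a toy symmetric blockwise absmax 4-bit quantizer. This is an assumption standing in for bitsandbytes' actual NF4/FP4 schemes, which differ in detail; the only point is that dequantize(quantize(W)) is not W, so an adapter trained on top of the quantized layer was fit against a slightly different base weight than the full-precision one.

```python
import torch

def fake_quant_dequant(w: torch.Tensor, bits: int = 4, blocksize: int = 64) -> torch.Tensor:
    """Round-trip `w` through a toy symmetric blockwise absmax quantizer.

    NOT bitsandbytes' NF4/FP4 scheme; it only illustrates that
    dequantize(quantize(W)) differs from W.
    """
    flat = w.reshape(-1, blocksize)                 # assumes numel is divisible by blocksize
    levels = 2 ** (bits - 1) - 1                    # 7 positive levels for signed 4-bit
    absmax = flat.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
    q = torch.round(flat / absmax * levels)         # integer codes in [-levels, levels]
    return (q / levels * absmax).reshape(w.shape)   # dequantize back to float

torch.manual_seed(0)
W = torch.randn(256, 256)          # stand-in for a full-precision base weight
W_deq = fake_quant_dequant(W)      # what the frozen quantized layer effectively computes with
err = (W - W_deq).abs().max().item()
print(f"max |W - dequant(quant(W))| = {err:.4f}")
```

Merging the adapter into W instead of W_deq therefore shifts every output by (W - W_deq) x, a term the adapter never saw during training.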

Another workaround I can think of is to dequantize the quantized base model and then add the LoRA weights. This would eliminate the training/inference mismatch. Has anyone tried dequantizing the base model and adding the LoRA weights?
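The dequantize-then-merge idea amounts to computing W_merged = dequant(W_q) + (alpha / r) * B A, using the usual LoRA shapes and scaling (assumed here; PEFT's own merge also handles details such as fan_in_fan_out, which this ignores, and dropout is inactive at inference). A small PyTorch sketch with random stand-in tensors, checking that a single merged layer reproduces the base-plus-adapter output the adapter saw during training:

```python
import torch

def merge_lora(w_deq: torch.Tensor, lora_A: torch.Tensor, lora_B: torch.Tensor,
               alpha: float, r: int) -> torch.Tensor:
    """Fold a LoRA update into an already-dequantized base weight.

    Assumed shapes (usual LoRA convention):
      w_deq:  (out_features, in_features)
      lora_A: (r, in_features)
      lora_B: (out_features, r)
    """
    scaling = alpha / r
    return w_deq + scaling * (lora_B @ lora_A)

torch.manual_seed(0)
out_f, in_f, r, alpha = 32, 64, 8, 16
w_deq = torch.randn(out_f, in_f)      # stand-in for dequantize(quantized base weight)
A = torch.randn(r, in_f) * 0.01
B = torch.randn(out_f, r) * 0.01
x = torch.randn(4, in_f)

# Output during training: frozen (dequantized) base path + scaled LoRA path.
y_train = x @ w_deq.T + (alpha / r) * (x @ A.T @ B.T)
# Output of one merged linear layer.
y_merged = x @ merge_lora(w_deq, A, B, alpha, r).T
print(torch.allclose(y_train, y_merged, atol=1e-5))
```

Because the merge starts from the dequantized weight rather than the original full-precision one, the merged layer matches the training-time computation up to floating-point rounding.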

@eugene-yh

There is a mathematical hack to dequantize the base model. See my post here: #28 (comment)

@jinyongyoo
Author

Thanks! I ended up using the dequantize_4bit function from bnb (bitsandbytes) to dequantize the linear weights, but I think your approach is neat.

@ChrisHayduk

@jinyongyoo Would you mind sharing the code you used to dequantize the model? How did you apply dequantize_4bit?

@jinyongyoo
Author

I'm not sure this is 100% the correct way to do it:

dequantize_4bit(module.weight.data, quant_state=module.weight.quant_state), where module is an instance of bnb.nn.Linear4bit. That should give you a weight tensor you can use to create a torch.nn.Linear.
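A sketch of that conversion, assuming dequantize_4bit is imported from bitsandbytes.functional and that module.weight.quant_state is populated (it is once the layer has actually been quantized, typically after moving it to a CUDA device). The helper names linear_from_weight and dequantize_linear4bit are hypothetical. The bitsandbytes import is deferred so the pure-PyTorch part works on its own:

```python
from typing import Optional

import torch
import torch.nn as nn

def linear_from_weight(weight: torch.Tensor, bias: Optional[torch.Tensor]) -> nn.Linear:
    """Build a plain nn.Linear whose parameters are copies of `weight` / `bias`."""
    out_features, in_features = weight.shape
    layer = nn.Linear(in_features, out_features, bias=bias is not None,
                      device=weight.device, dtype=weight.dtype)
    with torch.no_grad():
        layer.weight.copy_(weight)
        if bias is not None:
            layer.bias.copy_(bias.to(weight.dtype))
    return layer

def dequantize_linear4bit(module) -> nn.Linear:
    """Hypothetical helper: convert a bnb.nn.Linear4bit into an nn.Linear."""
    # Deferred import: bitsandbytes is only needed when this is actually called.
    from bitsandbytes.functional import dequantize_4bit
    w = dequantize_4bit(module.weight.data, quant_state=module.weight.quant_state)
    return linear_from_weight(w, module.bias)
```

The dequantized tensor comes back in the original (out_features, in_features) shape and dtype recorded in the quant state, which is why it can be copied straight into an nn.Linear.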

@ChrisHayduk

@jinyongyoo Awesome, thank you! So you just looped through every module of the model, checked whether it was a bnb.nn.Linear4bit, and, if so, replaced that module with its dequantized version?

@jinyongyoo
Author

Yes.
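The loop described above could look like the following sketch. replace_modules is a hypothetical helper, not a PEFT or bitsandbytes API: it collects matching submodules first, then swaps each one in its parent via setattr. It assumes matching modules are leaves, which holds for bnb.nn.Linear4bit.

```python
from typing import Callable

import torch.nn as nn

def replace_modules(model: nn.Module,
                    should_replace: Callable[[nn.Module], bool],
                    convert: Callable[[nn.Module], nn.Module]) -> int:
    """Swap every submodule matching `should_replace` for `convert(module)`.

    Returns the number of modules replaced. Assumes matches are leaf modules,
    so replacing one never invalidates the path to another.
    """
    # Collect names first: mutating the tree while iterating named_modules() is unsafe.
    targets = [name for name, m in model.named_modules() if name and should_replace(m)]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, convert(getattr(parent, child_name)))
    return len(targets)

# Hypothetical usage for the workflow in this thread:
#   import bitsandbytes as bnb
#   n = replace_modules(model, lambda m: isinstance(m, bnb.nn.Linear4bit), to_linear)
# where `to_linear` dequantizes one module with dequantize_4bit and builds an nn.Linear.
```

Collecting the names before mutating keeps the traversal stable, and routing the swap through the parent's setattr keeps the replacement registered as a proper submodule.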
