
[RWKV] Final fix RWKV 4bit #26134

Merged
4 commits merged into huggingface:main from fix-rwkv-4bit on Sep 13, 2023
Conversation

younesbelkada (Contributor)

What does this PR do?

Fixes #23848

Double quantization was not working properly for RWKV models, as stated in the issue above, leading to an error. This PR proposes a global fix for RWKV models so that they can be run in 4-bit bitsandbytes without any problem.

The approach here is the following (a minimal sketch of the round-trip is shown after the list):

  • For each target layer, de-quantize the 4-bit weights using bnb.functional.dequantize_4bit
  • Perform the weight scaling
  • Re-quantize the weights

That way it is possible to cover both double quantization and classic 4-bit quantization and make their results match.
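
For illustration, here is a minimal sketch of that round-trip. The helper name rescale_4bit_weight and the scale argument are illustrative, not the exact code added in this PR:

import bitsandbytes as bnb

def rescale_4bit_weight(param, scale):
    # 1. De-quantize the 4-bit weight back to full precision, using the
    #    quantization state stored on the Params4bit parameter.
    dequant = bnb.functional.dequantize_4bit(param.data, param.quant_state)
    # 2. Apply the RWKV rescaling in full precision.
    dequant.div_(scale)
    # 3. Re-quantize: Params4bit quantizes the data on the CPU -> GPU transfer,
    #    so we round-trip through the CPU (this adds some overhead).
    return bnb.nn.Params4bit(dequant.to("cpu"), requires_grad=False).to(param.device)

The reproduction snippet below first generates with double quantization enabled, then with classic 4-bit quantization for comparison: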

import torch
from transformers import RwkvForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "RWKV/rwkv-4-169m-pile"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True
)

model = RwkvForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tok = AutoTokenizer.from_pretrained(model_id)

text = "Hello my name is"
input_ids = tok.encode(text, return_tensors="pt").to(0)

out = model.generate(input_ids, max_new_tokens=30)
print(tok.decode(out[0], skip_special_tokens=True))

# Classic 4-bit quantization (no double quant) for comparison
model_non_dequant = RwkvForCausalLM.from_pretrained(model_id, load_in_4bit=True)

text = "Hello my name is"
input_ids = tok.encode(text, return_tensors="pt").to(0)

out = model_non_dequant.generate(input_ids, max_new_tokens=30)
print(tok.decode(out[0], skip_special_tokens=True))

cc @amyeroberts and @SunMarc for your information!

HuggingFaceDocBuilderDev commented Sep 13, 2023

The documentation is not available anymore as the PR was closed or merged.

amyeroberts (Collaborator) left a comment


Thanks for adding this fix!

Could you add a test that would have failed before this change and passes now? Ideally it should be applied to all eligible models.

# re-quantize the model:
# we need to put it first on CPU then back to the device
# this will create an overhead :/
quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
amyeroberts (Collaborator)

Why are we setting requires_grad=False here?

younesbelkada (Contributor, Author)

This seems to be a requirement from bnb: all quantized parameters need to have that value set to False, whereas the default is True. I can open an issue on bnb if this is a bug.

Otherwise you get:

  File "/home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/bitsandbytes/nn/modules.py", line 179, in to
    return self.cuda(device)
  File "/home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/bitsandbytes/nn/modules.py", line 158, in cuda
    self.data = w_4bit
RuntimeError: data set to a tensor that requires gradients must be floating point or complex dtype
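
For context, a minimal sketch of the constraint (assuming a CUDA device is available; the tensor shape is arbitrary):

import torch
import bitsandbytes as bnb

w = torch.randn(16, 16)

# With the default requires_grad=True, moving to the GPU replaces the parameter's
# data with the packed, non-floating-point 4-bit tensor, which triggers the
# RuntimeError shown above:
# bnb.nn.Params4bit(w).to("cuda")

# With requires_grad=False, the quantization on .to("cuda") succeeds.
quantized = bnb.nn.Params4bit(w, requires_grad=False).to("cuda")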

amyeroberts (Collaborator)

Yes, even if it's not a bug, it would be good to get clarification about this.

Will this effectively set these layers to non-trainable even if they were trainable before?

younesbelkada (Contributor, Author)

Hm, I am not sure here; in any case, quantized layers cannot be trained, as this is not supported.
I have added more clarifications here: ba1b10f

younesbelkada (Contributor, Author)

Thanks for the review! I added a test that should be applicable to other checkpoints as well, as they use the same architecture. A sketch of what such a test might look like is shown below.
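
What follows is only an illustrative sketch of such a test, not the exact test added in this PR; the test name, checkpoint, and assertion are assumptions, and a CUDA GPU is required:

from transformers import RwkvForCausalLM, AutoTokenizer, BitsAndBytesConfig

def test_rwkv_generation_with_double_quant():
    model_id = "RWKV/rwkv-4-169m-pile"
    tok = AutoTokenizer.from_pretrained(model_id)
    input_ids = tok.encode("Hello my name is", return_tensors="pt").to(0)

    config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
    model = RwkvForCausalLM.from_pretrained(model_id, quantization_config=config)

    # Before this fix, running a double-quantized RWKV model raised an error;
    # now generation should simply produce new tokens.
    output = model.generate(input_ids, max_new_tokens=10)
    assert output.shape[-1] > input_ids.shape[-1]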

amyeroberts (Collaborator) left a comment

Thanks for fixing and iterating!

younesbelkada merged commit 7ccac73 into huggingface:main on Sep 13, 2023
younesbelkada deleted the fix-rwkv-4bit branch on September 13, 2023 at 14:30
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
* Final fix RWMV 4bit

* fixup

* add a test

* add more clarifications
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
* Final fix RWMV 4bit

* fixup

* add a test

* add more clarifications
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023
* Final fix RWMV 4bit

* fixup

* add a test

* add more clarifications
Development

Successfully merging this pull request may close these issues.

RWKV - Inference NF4 quantization broken, also Int8 quantization weirdness.
3 participants