[RWKV] Final fix RWKV 4bit #26134

Merged: 4 commits, Sep 13, 2023
30 changes: 24 additions & 6 deletions src/transformers/models/rwkv/modeling_rwkv.py
@@ -31,6 +31,7 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_bitsandbytes_available,
is_ninja_available,
is_torch_cuda_available,
logging,
@@ -735,18 +736,35 @@ def _rescale_layers(self):
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
elif hasattr(block.attention.output.weight, "quant_state"):
block.attention.output.weight.quant_state[0].div_(
2 ** int(block_id // self.config.rescale_every)
)
block.feed_forward.value.weight.quant_state[0].div_(
2 ** int(block_id // self.config.rescale_every)
)
self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
else:
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))

self.layers_are_rescaled = not self.training

def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
r"""
Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
be quantized again.
"""
if not is_bitsandbytes_available():
raise ImportError("Please install bitsandbytes to use this method.")
import bitsandbytes as bnb

dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)

dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))

# re-quantize the model:
# we need to put it first on CPU then back to the device
# this will create an overhead :/
# We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
# bugs with bnb
quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
Collaborator:
Why are we setting requires_grad=False here?

Contributor Author:
This seems to be a requirement from bnb: all quantized parameters need to have requires_grad set to False, whereas the default is True. I can open an issue on bnb if this turns out to be a bug.

Otherwise you get:

  File "/home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/bitsandbytes/nn/modules.py", line 179, in to
    return self.cuda(device)
  File "/home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/bitsandbytes/nn/modules.py", line 158, in cuda
    self.data = w_4bit
RuntimeError: data set to a tensor that requires gradients must be floating point or complex dtype

Collaborator:
Yes, even if it's not a bug, it would be good to get clarification about this.

Will this effectively set these layers to non-trainable even if they were trainable before?

Contributor Author:
Hm, I am not sure here. In any case, quantized layers cannot be trained, as this is not supported.
I have added more clarifications here: ba1b10f

setattr(target_layer, "weight", quant_weight)


@add_start_docstrings(
"""
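On the requires_grad=False question in the thread above: the traceback boils down to a PyTorch rule rather than anything bnb-specific. A tensor that requires gradients cannot have non-floating-point storage assigned to its .data, and Params4bit assigns the packed uint8 weights to .data when moved to the GPU. A minimal sketch reproducing the same error without bitsandbytes (shapes and names are illustrative only):

```python
import torch

# torch.nn.Parameter defaults to requires_grad=True, which Params4bit inherits
# unless requires_grad=False is passed explicitly.
weight = torch.nn.Parameter(torch.randn(4, 4), requires_grad=True)

try:
    # bitsandbytes moves the packed 4-bit weights (a uint8 tensor) into .data
    # when the parameter is sent to the GPU; the same assignment fails here
    # because the parameter still requires gradients.
    weight.data = torch.zeros(8, dtype=torch.uint8)
except RuntimeError as err:
    print(err)
    # data set to a tensor that requires gradients must be floating point or complex dtype

# With requires_grad=False the assignment is accepted, which is why the helper
# constructs Params4bit(..., requires_grad=False).
weight = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)
weight.data = torch.zeros(8, dtype=torch.uint8)
```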
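For context on the division that both branches of _rescale_layers apply: each block's attention output and feed-forward value weights are divided by 2 ** (block_id // rescale_every), so deeper blocks are scaled down more strongly to keep half-precision activations in range at inference time; the 4-bit path dequantizes, divides, and re-quantizes instead of touching the packed storage directly. A quick illustration of the factor schedule, assuming rescale_every = 6 (the real value comes from the model's RwkvConfig):

```python
# Illustrative only: rescale_every = 6 is an assumed value here; the modeling
# code reads it from config.rescale_every.
rescale_every = 6

for block_id in range(18):
    factor = 2 ** (block_id // rescale_every)
    print(f"block {block_id:2d}: weight.div_({factor})")

# blocks 0-5 are left unchanged (divided by 1), blocks 6-11 are divided by 2,
# blocks 12-17 by 4, and so on.
```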
16 changes: 16 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -161,6 +161,22 @@ def test_linear_are_4bit(self):
# 4-bit parameters are packed in uint8 variables
self.assertTrue(module.weight.dtype == torch.uint8)

def test_rwkv_4bit(self):
r"""
A simple test to check if 4-bit RWKV inference works as expected.
"""
model_id = "RWKV/rwkv-4-169m-pile"

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tok = AutoTokenizer.from_pretrained(model_id)

text = "Hello my name is"
input_ids = tok.encode(text, return_tensors="pt").to(0)

_ = model.generate(input_ids, max_new_tokens=30)

def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
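To round out the diff above: the new test_rwkv_4bit only checks that generation runs; end-user 4-bit inference looks essentially the same, just with the output decoded. A sketch mirroring the test's settings (the printed text is whatever the model generates, not an expected output):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "RWKV/rwkv-4-169m-pile"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)

# Loading in 4-bit requires a CUDA device and bitsandbytes installed.
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tok = AutoTokenizer.from_pretrained(model_id)

input_ids = tok.encode("Hello my name is", return_tensors="pt").to(model.device)
output = model.generate(input_ids, max_new_tokens=30)
print(tok.decode(output[0], skip_special_tokens=True))
```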