[RWKV] Final fix RWKV 4bit (#26134)
* Final fix RWKV 4bit

* fixup

* add a test

* add more clarifications
younesbelkada authored Sep 13, 2023
1 parent 32ec734 commit 7ccac73
Showing 2 changed files with 40 additions and 6 deletions.
30 changes: 24 additions & 6 deletions src/transformers/models/rwkv/modeling_rwkv.py
@@ -31,6 +31,7 @@
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_bitsandbytes_available,
     is_ninja_available,
     is_torch_cuda_available,
     logging,
@@ -735,18 +736,35 @@ def _rescale_layers(self):
                            block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
                            block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
                        elif hasattr(block.attention.output.weight, "quant_state"):
-                            block.attention.output.weight.quant_state[0].div_(
-                                2 ** int(block_id // self.config.rescale_every)
-                            )
-                            block.feed_forward.value.weight.quant_state[0].div_(
-                                2 ** int(block_id // self.config.rescale_every)
-                            )
+                            self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
+                            self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
                        else:
                            block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
                            block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))

        self.layers_are_rescaled = not self.training

+    def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
+        r"""
+        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
+        be quantized again.
+        """
+        if not is_bitsandbytes_available():
+            raise ImportError("Please install bitsandbytes to use this method.")
+        import bitsandbytes as bnb
+
+        dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
+
+        dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
+
+        # re-quantize the model:
+        # we need to put it first on CPU then back to the device
+        # this will create an overhead :/
+        # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
+        # bugs with bnb
+        quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
+        setattr(target_layer, "weight", quant_weight)


@add_start_docstrings(
"""
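For context, the sketch below illustrates the round trip the new `_bnb_4bit_dequantize_and_rescale` helper performs on a bitsandbytes 4-bit linear layer: dequantize the packed weight, divide it by the RWKV rescale factor, then re-quantize it through a CPU round trip. This is an illustrative, standalone sketch rather than the committed code; the `rescale_4bit_linear` name and its arguments are hypothetical, and it assumes bitsandbytes and a CUDA device are available.

# Illustrative sketch (not the committed code): rescale one 4-bit quantized linear layer
# by dequantizing its weight, dividing it, and re-quantizing it via a CPU round trip.
import bitsandbytes as bnb
import torch


def rescale_4bit_linear(linear: torch.nn.Linear, block_id: int, rescale_every: int) -> None:
    # Recover a full-precision copy of the weight from the packed 4-bit storage.
    dequant = bnb.functional.dequantize_4bit(linear.weight.data, linear.weight.quant_state)

    # Apply the same per-block rescaling RWKV uses for its output projections.
    dequant.div_(2 ** int(block_id // rescale_every))

    # Params4bit re-quantizes when moved from CPU to a CUDA device, hence the CPU detour;
    # requires_grad=False because gradients cannot flow through 4-bit parameters.
    linear.weight = bnb.nn.Params4bit(dequant.to("cpu"), requires_grad=False).to(dequant.device)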
16 changes: 16 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -172,6 +172,22 @@ def test_linear_are_4bit(self):
                    # 4-bit parameters are packed in uint8 variables
                    self.assertTrue(module.weight.dtype == torch.uint8)

+    def test_rwkv_4bit(self):
+        r"""
+        A simple test to check if 4-bit RWKV inference works as expected.
+        """
+        model_id = "RWKV/rwkv-4-169m-pile"
+
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
+        tok = AutoTokenizer.from_pretrained(model_id)
+
+        text = "Hello my name is"
+        input_ids = tok.encode(text, return_tensors="pt").to(0)
+
+        _ = model.generate(input_ids, max_new_tokens=30)

    def test_generate_quality(self):
        r"""
        Test the generation quality of the quantized model and see that we are matching the expected output.
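As a usage note, the new test corresponds roughly to the following end-to-end flow; this is a minimal sketch, assuming a CUDA device at index 0 and bitsandbytes installed:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "RWKV/rwkv-4-169m-pile"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tok = AutoTokenizer.from_pretrained(model_id)

input_ids = tok.encode("Hello my name is", return_tensors="pt").to(0)
output = model.generate(input_ids, max_new_tokens=30)
print(tok.decode(output[0], skip_special_tokens=True))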
