[RWKV] Final fix RWKV 4bit #26134

Merged: 4 commits, Sep 13, 2023
Changes from 2 commits
28 changes: 22 additions & 6 deletions src/transformers/models/rwkv/modeling_rwkv.py
@@ -31,6 +31,7 @@
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_bitsandbytes_available,
     is_ninja_available,
     is_torch_cuda_available,
     logging,
@@ -735,18 +736,33 @@ def _rescale_layers(self):
                             block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
                             block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
                         elif hasattr(block.attention.output.weight, "quant_state"):
-                            block.attention.output.weight.quant_state[0].div_(
-                                2 ** int(block_id // self.config.rescale_every)
-                            )
-                            block.feed_forward.value.weight.quant_state[0].div_(
-                                2 ** int(block_id // self.config.rescale_every)
-                            )
+                            self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
+                            self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
                         else:
                             block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
                             block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))

         self.layers_are_rescaled = not self.training

+    def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
+        r"""
+        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
+        be quantized again.
+        """
+        if not is_bitsandbytes_available():
+            raise ImportError("Please install bitsandbytes to use this method.")
+        import bitsandbytes as bnb
+
+        dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
+
+        dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
+
+        # re-quantize the model:
+        # we need to put it first on CPU then back to the device
+        # this will create an overhead :/
+        quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
Collaborator:
Why are we setting requires_grad=False here?

Contributor Author:
This seems to be a requirement from bnb: all quantized parameters need to have that value set to False, whereas the default is True. I can open an issue on bnb if this is a bug.

Otherwise you get:

  File "/home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/bitsandbytes/nn/modules.py", line 179, in to
    return self.cuda(device)
  File "/home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/bitsandbytes/nn/modules.py", line 158, in cuda
    self.data = w_4bit
RuntimeError: data set to a tensor that requires gradients must be floating point or complex dtype

Collaborator:
Yes, even if it's not a bug, it would be good to get clarification about this.

Will this effectively set these layers to non-trainable even if they were trainable before?

Contributor Author:
Hmm, I am not sure here. In any case, quantized layers cannot be trained, as this is not supported.
I have added more clarifications here: ba1b10f
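Side note (not part of the PR): the constraint discussed in this thread can be reproduced with a minimal sketch, assuming a CUDA device and bitsandbytes installed. Params4bit only quantizes its data when moved to the GPU, and PyTorch refuses to attach the packed uint8 storage to a parameter that still requires gradients, hence requires_grad=False.

import torch
import bitsandbytes as bnb

weights = torch.randn(64, 64, dtype=torch.float16)

# Default requires_grad=True: moving to CUDA packs the weights into uint8
# storage and raises
# "RuntimeError: data set to a tensor that requires gradients must be
#  floating point or complex dtype"
# bad = bnb.nn.Params4bit(weights).to("cuda")

# With requires_grad=False the quantization round-trip succeeds and the
# parameter now holds packed 4-bit data in uint8 storage.
good = bnb.nn.Params4bit(weights, requires_grad=False).to("cuda")
print(good.dtype, good.shape)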

+        setattr(target_layer, "weight", quant_weight)


@add_start_docstrings(
"""
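For context, here is a minimal usage sketch (not taken from the PR) of the scenario this fix targets, assuming a CUDA device, bitsandbytes installed, and the public "RWKV/rwkv-4-169m-pile" checkpoint: loading RWKV in 4 bits and generating, which calls _rescale_layers() at inference time and, for 4-bit weights, goes through the new _bnb_4bit_dequantize_and_rescale path.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RWKV/rwkv-4-169m-pile"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load_in_4bit quantizes the linear layers with bitsandbytes 4-bit kernels
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# the model is in eval mode after from_pretrained, so forward() rescales the layers once
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))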