Fixes for quant_storage and CPU offloading #1279
Conversation
Nice! Left a comment to better understand!
 if not isinstance(self.weight, Params4bit):
-    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
+    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
How is this related to CPU offload? The weights are of instance Params4bit, no? Also, why do we need to hardcode bnb_quantized to True? Is it to make sure that we don't quantize the offloaded weights when we dispatch from CPU to GPU?
Yes, this is to make sure we don't try to quantize again. Without that, Params4bit.to() will try to quantize, and we see a "ValueError: Blockwise quantization only supports 16/32-bit floats" raised. Normally Params4bit._quantize() will set self.bnb_quantized = True when done (and likewise Params4bit.from_prequantized() does the same), so I am emulating that here.
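
For context, here is a simplified paraphrase of that control flow, pieced together from the description above and the traceback further down; it is a sketch, not the verbatim bitsandbytes source:

import torch
import bitsandbytes as bnb

# Simplified paraphrase (not the verbatim bitsandbytes source) of the behavior described above.
def params4bit_to(param, device):
    """Sketch of Params4bit.to(): moving to CUDA quantizes unless already marked quantized."""
    device = torch.device(device)
    if device.type == "cuda" and not param.bnb_quantized:
        return params4bit_quantize(param, device)
    return param.data.to(device)  # already packed: just move the storage

def params4bit_quantize(param, device):
    """Sketch of Params4bit._quantize(): packed (e.g. uint8) storage triggers the ValueError above."""
    w_4bit, quant_state = bnb.functional.quantize_4bit(
        param.data.to(device), quant_storage=param.quant_storage
    )
    param.bnb_quantized = True  # also set by Params4bit.from_prequantized()
    return w_4bit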
Perfect! Make sure to also add a comment. It's a bit hacky, and hopefully no one complains about this.
@matthewdouglas do you have a minimal reproducible example for the "ValueError: Blockwise quantization only supports 16/32-bit floats"?
It would also be useful from the CPU offloading perspective. The thing is that serialization with a different dtype is actually fully supported, imo; I even just ran some tests verifying that. Let me dig in some more.
The goal for me is to understand in detail what's going wrong under the hood in both the lit-gpt and CPU offloading cases; for me that's still very unclear. It would be great to reduce it to its essence and look at it purely on the BNB side, without the dependencies and with the simplest possible reproducible example.
Matthew told me he would look into this and we'll discuss / solve more together tomorrow. Thanks @matthewdouglas for taking the lead on this 🚀🙌🏻
@Titus-von-Koeller If you apply the accelerate PR huggingface/accelerate#2934 you can reproduce this on bitsandbytes main. Without the accelerate PR you'll instead see ValueError: Trying to set a tensor of shape torch.Size([2048, 2048]) in "weight" (which has shape torch.Size([2097152, 1])), this look incorrect. raised before we ever get to quantize_4bit from bnb.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
MODEL_ID = "facebook/opt-1.3b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,  # allow part of the model to be offloaded to CPU
    bnb_4bit_quant_storage=torch.float16    # non-default quant_storage (the default is uint8)
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    max_memory={0: "0.5GiB", "cpu": "8GiB"},
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
print(model, model.hf_device_map)
inputs = tokenizer("What is the meaning of life, the universe, and everything?", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=16, num_return_sequences=1)
print(f"{tokenizer.decode(output[0])}")
Output:
Some parameters are on the meta device device because they were offloaded to the cpu.
OPTForCausalLM(
(model): OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(50272, 2048, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(layers): ModuleList(
(0-23): 24 x OPTDecoderLayer(
(self_attn): OPTAttention(
(k_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
(v_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
(q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
(out_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
)
(activation_fn): ReLU()
(self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(fc1): Linear4bit(in_features=2048, out_features=8192, bias=True)
(fc2): Linear4bit(in_features=8192, out_features=2048, bias=True)
(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(lm_head): Linear(in_features=2048, out_features=50272, bias=False)
) {'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 'cpu', 'model.decoder.layers.10': 'cpu', 'model.decoder.layers.11': 'cpu', 'model.decoder.layers.12': 'cpu', 'model.decoder.layers.13': 'cpu', 'model.decoder.layers.14': 'cpu', 'model.decoder.layers.15': 'cpu', 'model.decoder.layers.16': 'cpu', 'model.decoder.layers.17': 'cpu', 'model.decoder.layers.18': 'cpu', 'model.decoder.layers.19': 'cpu', 'model.decoder.layers.20': 'cpu', 'model.decoder.layers.21': 'cpu', 'model.decoder.layers.22': 'cpu', 'model.decoder.layers.23': 'cpu'}
Traceback (most recent call last):
File "/home/matt/code/accelerate/./sandbox/inference.py", line 30, in <module>
output = model.generate(**inputs, max_new_tokens=16, num_return_sequences=1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 1914, in generate
result = self._sample(
^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2651, in _sample
outputs = self(
^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 1118, in forward
outputs = self.model.decoder(
^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 884, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 525, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 155, in forward
query_states = self.q_proj(hidden_states) * self.scaling
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 164, in new_forward
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 354, in pre_forward
set_module_tensor_to_device(
File "/home/matt/code/accelerate/src/accelerate/utils/modeling.py", line 426, in set_module_tensor_to_device
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/bitsandbytes/bitsandbytes/nn/modules.py", line 324, in to
return self._quantize(device)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/bitsandbytes/bitsandbytes/nn/modules.py", line 289, in _quantize
w_4bit, quant_state = bnb.functional.quantize_4bit(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/matt/code/bitsandbytes/bitsandbytes/functional.py", line 1238, in quantize_4bit
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8
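
For the minimal pure-BNB reproduction asked about above, the bottom of this traceback suggests the following untested sketch; it emulates what accelerate's set_module_tensor_to_device() does when it re-creates the parameter class (names and shapes are illustrative, and a CUDA device is assumed):

import torch
from bitsandbytes.nn import Params4bit

# Untested sketch of the same failure without transformers/accelerate involved.
w = torch.randn(2048, 2048, dtype=torch.float16)
p = Params4bit(w, requires_grad=False, quant_type="nf4").to("cuda")  # quantizes once; packed storage is uint8 by default

# Emulate set_module_tensor_to_device(): re-create the parameter class around the
# already-packed data, without bnb_quantized=True and without the original quant_storage.
rewrapped = Params4bit(p.data.cpu(), requires_grad=False, quant_type="nf4")

# The new parameter believes its data is unquantized, so moving it to CUDA calls
# _quantize() on uint8 storage and should raise:
# ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8
rewrapped = rewrapped.to("cuda")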
With both the accelerate PR and this PR applied, the output is as expected:
Some parameters are on the meta device device because they were offloaded to the cpu.
OPTForCausalLM(
(model): OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(50272, 2048, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(layers): ModuleList(
(0-23): 24 x OPTDecoderLayer(
(self_attn): OPTAttention(
(k_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
(v_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
(q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
(out_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
)
(activation_fn): ReLU()
(self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(fc1): Linear4bit(in_features=2048, out_features=8192, bias=True)
(fc2): Linear4bit(in_features=8192, out_features=2048, bias=True)
(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(lm_head): Linear(in_features=2048, out_features=50272, bias=False)
) {'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 'cpu', 'model.decoder.layers.10': 'cpu', 'model.decoder.layers.11': 'cpu', 'model.decoder.layers.12': 'cpu', 'model.decoder.layers.13': 'cpu', 'model.decoder.layers.14': 'cpu', 'model.decoder.layers.15': 'cpu', 'model.decoder.layers.16': 'cpu', 'model.decoder.layers.17': 'cpu', 'model.decoder.layers.18': 'cpu', 'model.decoder.layers.19': 'cpu', 'model.decoder.layers.20': 'cpu', 'model.decoder.layers.21': 'cpu', 'model.decoder.layers.22': 'cpu', 'model.decoder.layers.23': 'cpu'}
</s>What is the meaning of life, the universe, and everything?
I think it's a reference to the song "What is the meaning of
I was just talking to @matthewdouglas about this via PM. I think this probably still needs another iteration.

In this example, the weights retain the right dtype (the one that holds the packed quantized weights) despite serialization. Let me know if I misunderstood anything or if you have any further concerns / questions. I'll be sure to dig into this more tomorrow :)
You're right, so my comment about that is incorrect. What I notice is that uint8 is the default for quant_storage.
Ok, after very thorough review I have to say that this is great work. Thanks for cleaning this part of the code up with this more correct and complete logic. Regarding the serialization, it just needed this small fix to pick up the quant_storage dtype based on the serialized tensor. The test suite is all green despite the usual flakiness (which I double-checked for everything by hand). Will merge this and then trigger the HF integration tests.

Really helpful and good work. Thanks @matthewdouglas ❤️ 🤗
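
For readers following along, here is a rough paraphrase of what "pick up the quant_storage dtype based on the serialized tensor" could look like inside Params4bit.from_prequantized(); this is an assumption-laden sketch, not the actual diff:

import torch

# Paraphrase, not the actual diff: when restoring a prequantized parameter from a
# checkpoint, derive quant_storage from the serialized packed tensor's dtype rather
# than assuming the uint8 default.
def from_prequantized_sketch(data, device="cuda"):
    param = torch.nn.Parameter(data.to(device), requires_grad=False)
    param.quant_storage = data.dtype   # dynamic, e.g. torch.float16 instead of a hardcoded torch.uint8
    param.bnb_quantized = True         # packed data must never be re-quantized on .to()
    return param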
Merged commit 7fed393 into bitsandbytes-foundation:main
Fixes for quant_storage and CPU offloading (bitsandbytes-foundation#1279)

* Fix restoration of quant_storage for CPU offloading
* Clarify comment on default quant_storage in Params4bit.from_prequantized()
* fix to make quant_storage dynamic based on serialized dtype
* delete obsolete comment

---------

Co-authored-by: Titus von Koeller <[email protected]>
This change ensures we don't lose track of a non-default quant_storage option or quantization state when moving between CPU and GPU.

cc: @Titus-von-Koeller @SunMarc
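
As a closing illustration of the invariant described above, here is an untested snippet, assuming a CUDA device and a bitsandbytes build where Linear4bit accepts quant_storage:

import torch
import bitsandbytes as bnb

# Untested illustration: a non-default quant_storage dtype and the quantization state
# should survive moving between CPU and GPU without re-quantization.
layer = bnb.nn.Linear4bit(2048, 2048, quant_type="nf4", quant_storage=torch.float16)
layer = layer.to("cuda")        # quantizes once; packed weights stored as float16
print(layer.weight.dtype)       # expected: torch.float16 (the requested quant_storage)

layer = layer.to("cpu")         # offload to CPU
layer = layer.to("cuda")        # bring back; should not attempt to quantize again
print(layer.weight.dtype)       # expected to remain torch.float16 with this PR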