Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for quant_storage and CPU offloading #1279

Conversation

matthewdouglas
Copy link
Member

This change ensures we don't lose track of a non-default quant_storage option or quantization state when moving between CPU and GPU.

cc: @Titus-von-Koeller @SunMarc

Copy link
Contributor

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice ! Left a comment to better understand !

Comment on lines 454 to +455
if not isinstance(self.weight, Params4bit):
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this related to cpu offload ? The weights are of instance Params4bit no ? Also, why do we need to hardcode bnb_quantized to True ? Is it to make sure that we don't quantize the offloaded weights when we dispatch from CPU to GPU ?

Copy link
Member Author

@matthewdouglas matthewdouglas Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is to make sure we don't try to quantize again. Without that, Params4bit.to() will try to quantize, and we see a "ValueError: Blockwise quantization only supports 16/32-bit floats" raised. Normally Params4bit._quantize() will set self.bnb_quantized = True when done (and likewise Params4bit.from_prequantized() does the same), so I am emulating that here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect ! Make sure to also add a comment. It's a bit hacky and hopefully, no one complains about this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthewdouglas do you have a minimal reproducible example to reproduce the "ValueError: Blockwise quantization only supports 16/32-bit floats"?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be useful also from the cpu offloading perspective.. The thing is that the serialization with different dtype is actually fully supported imo, I even just ran some tests verifying that.. Let me dig in some more..

Copy link
Collaborator

@Titus-von-Koeller Titus-von-Koeller Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the goal for me is to understand the mechanism of what's going wrong in detail under the hood in both the lit-gpt and cpu offloading cases, for me that's still very unclear. It would be great to reduce it to the essence to look at it purely on the BNB side, without the dependencies and with the simplest possible reproducible example.

Matthew told me he would look into this and we'll discuss / solve more together tmr. Thanks @matthewdouglas for taking the lead on this 🚀🙌🏻

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Titus-von-Koeller If you apply the PR on accelerate: huggingface/accelerate#2934 you can reproduce this on bitsandbytes main. Without the accelerate PR you'll see ValueError: Trying to set a tensor of shape torch.Size([2048, 2048]) in "weight" (which has shape torch.Size([2097152, 1])), this look incorrect. before we try to use quantize_4bit from bnb.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


MODEL_ID = "facebook/opt-1.3b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
    bnb_4bit_quant_storage=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    max_memory={0: "0.5GiB", "cpu": "8GiB"},
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

print(model, model.hf_device_map)

inputs = tokenizer("What is the meaning of life, the universe, and everything?", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=16, num_return_sequences=1)

print(f"{tokenizer.decode(output[0])}")

Output:

Some parameters are on the meta device device because they were offloaded to the cpu.
OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear4bit(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear4bit(in_features=8192, out_features=2048, bias=True)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=2048, out_features=50272, bias=False)
) {'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 'cpu', 'model.decoder.layers.10': 'cpu', 'model.decoder.layers.11': 'cpu', 'model.decoder.layers.12': 'cpu', 'model.decoder.layers.13': 'cpu', 'model.decoder.layers.14': 'cpu', 'model.decoder.layers.15': 'cpu', 'model.decoder.layers.16': 'cpu', 'model.decoder.layers.17': 'cpu', 'model.decoder.layers.18': 'cpu', 'model.decoder.layers.19': 'cpu', 'model.decoder.layers.20': 'cpu', 'model.decoder.layers.21': 'cpu', 'model.decoder.layers.22': 'cpu', 'model.decoder.layers.23': 'cpu'}
Traceback (most recent call last):
  File "/home/matt/code/accelerate/./sandbox/inference.py", line 30, in <module>
    output = model.generate(**inputs, max_new_tokens=16, num_return_sequences=1)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
              ^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 1118, in forward
    outputs = self.model.decoder(
              ^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 884, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 525, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 155, in forward
    query_states = self.q_proj(hidden_states) * self.scaling
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 164, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 354, in pre_forward
    set_module_tensor_to_device(
  File "/home/matt/code/accelerate/src/accelerate/utils/modeling.py", line 426, in set_module_tensor_to_device
    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/bitsandbytes/bitsandbytes/nn/modules.py", line 324, in to
    return self._quantize(device)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/bitsandbytes/bitsandbytes/nn/modules.py", line 289, in _quantize
    w_4bit, quant_state = bnb.functional.quantize_4bit(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/bitsandbytes/bitsandbytes/functional.py", line 1238, in quantize_4bit
    raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8

With both the accelerate and this PR, the output is as expected:

Some parameters are on the meta device device because they were offloaded to the cpu.
OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear4bit(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear4bit(in_features=8192, out_features=2048, bias=True)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=2048, out_features=50272, bias=False)
) {'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 'cpu', 'model.decoder.layers.10': 'cpu', 'model.decoder.layers.11': 'cpu', 'model.decoder.layers.12': 'cpu', 'model.decoder.layers.13': 'cpu', 'model.decoder.layers.14': 'cpu', 'model.decoder.layers.15': 'cpu', 'model.decoder.layers.16': 'cpu', 'model.decoder.layers.17': 'cpu', 'model.decoder.layers.18': 'cpu', 'model.decoder.layers.19': 'cpu', 'model.decoder.layers.20': 'cpu', 'model.decoder.layers.21': 'cpu', 'model.decoder.layers.22': 'cpu', 'model.decoder.layers.23': 'cpu'}
</s>What is the meaning of life, the universe, and everything?
I think it's a reference to the song "What is the meaning of

@Titus-von-Koeller
Copy link
Collaborator

Titus-von-Koeller commented Jul 16, 2024

I was just talking to @matthewdouglas about this via PM. I think this probably still needs another iteration. My understanding is that the quant_storage dtype is actually supported for serialization in BNB, so we gotta take this into account:

In [1]: import torch
   ...: import bitsandbytes as bnb
   ...: import io
   ...: 
   ...: def save_and_load_model(model):
   ...:     buffer = io.BytesIO()
   ...:     torch.save(model.state_dict(), buffer)
   ...:     buffer.seek(0)
   ...:     loaded_state_dict = torch.load(buffer)
   ...:     return loaded_state_dict
   ...: 
   ...: def create_linear4bit(in_features, out_features, quant_storage):
   ...:     layer = bnb.nn.Linear4bit(in_features, out_features, quant_storage=quant_storage)
   ...:     layer.weight.data.normal_(0, 1)
   ...:     return layer.cuda()  # Move to CUDA to trigger quantization
   ...: 
   ...: def check_quant_storage(state_dict):
   ...:     weight_state = {k: v for k, v in state_dict.items() if k.startswith('weight')}
   ...:     print("Quantization state keys:", weight_state.keys())
   ...:     for k, v in weight_state.items():
   ...:         if isinstance(v, torch.Tensor):
   ...:             print(f"{k} dtype: {v.dtype}")
   ...: 
In [2]: print("Testing with float16 quant_storage:")
   ...: model_float16 = create_linear4bit(10, 20, quant_storage=torch.float16)
   ...: loaded_state_dict = save_and_load_model(model_float16)
   ...: check_quant_storage(loaded_state_dict)
   ...: print()
   ...: 
   ...: print("Testing with bfloat16 quant_storage:")
   ...: model_bfloat16 = create_linear4bit(10, 20, quant_storage=torch.bfloat16)
   ...: loaded_state_dict = save_and_load_model(model_bfloat16)
   ...: check_quant_storage(loaded_state_dict)
   ...: print()
   ...: 
   ...: print("Testing with float32 quant_storage:")
   ...: model_float32 = create_linear4bit(10, 20, quant_storage=torch.float32)
   ...: loaded_state_dict = save_and_load_model(model_float32)
   ...: check_quant_storage(loaded_state_dict)
   ...: print()
   ...: 
   ...: print("Testing with uint8 quant_storage (default):")
   ...: model_uint8 = create_linear4bit(10, 20, quant_storage=torch.uint8)
   ...: loaded_state_dict = save_and_load_model(model_uint8)
   ...: check_quant_storage(loaded_state_dict)
Testing with float16 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.float16
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8
Testing with bfloat16 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.bfloat16
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8
Testing with float32 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.float32
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8
Testing with uint8 quant_storage (default):
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.uint8
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8

In this example, the weights retain the right dtype (the one that holds the packed quantized weights) despite serialization. quant_state.bitsandbytes__fp4 is just the packed representation of non-tensor quantization state information and shouldn't cause any issues in this context, imo.

Let me know if I misunderstood anything or if you have any further concers / questions.

I'll be sure to dig into this more tmr :)

@Titus-von-Koeller Titus-von-Koeller marked this pull request as draft July 17, 2024 15:53
@matthewdouglas
Copy link
Member Author

I was just talking to @matthewdouglas about this via PM. I think this probably still needs another iteration. My understanding is that the quant_storage dtype is actually supported for serialization in BNB, so we gotta take this into account:

You're right, so my comment about that is incorrect. What I notice is that uint8 is the default for Params4bit.__new__ so I think we should keep that the same.

@matthewdouglas matthewdouglas marked this pull request as ready for review July 19, 2024 18:01
@Titus-von-Koeller
Copy link
Collaborator

Ok, after very thorough review I have to say that this is great work. Thanks for cleaning this part of the code up with this more correct and complete logic.

Regarding the serialization it just need this small fix to pick up on the quant_storage dtype based on serialized tensor.

Test suite is all green despite the usual flakiness (which I double-checked for everything by hand).

Will merge this and then trigger the HF integration tests on main one more time and do the release right after. Great team work :D

Really helpful and good work. Thanks @matthewdouglas ❤️ 🤗

@Titus-von-Koeller Titus-von-Koeller merged commit 7fed393 into bitsandbytes-foundation:main Jul 23, 2024
26 checks passed
matthewdouglas added a commit to matthewdouglas/bitsandbytes that referenced this pull request Oct 28, 2024
…ndation#1279)

* Fix restoration of quant_storage for CPU offloading

* Clarify comment on default quant_storage in Params4bit.from_prequantized()

* fix to make quant_storage dynamic based on serialized dtype

* delete obsolete comment

---------

Co-authored-by: Titus von Koeller <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants