Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly handle Params4bit in set_module_tensor_to_device #2934

Merged
merged 2 commits into from
Jul 22, 2024

Conversation

matthewdouglas
Copy link
Member

@matthewdouglas matthewdouglas commented Jul 15, 2024

What does this PR do?

This PR fixes compatibility for bitsandbytes Params4bit and set_module_tensor_to_device.

Related: bitsandbytes-foundation/bitsandbytes#1279

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr @SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this @matthewdouglas !

Comment on lines +364 to 367
if old_value.shape != value.shape and param_cls.__name__ != "Params4bit":
raise ValueError(
f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this look incorrect.'
f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.'
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understood correctly here, the shape of the Param4bit is different from the actual weight that we are trying to set in the offload case. That happens because in with offload, the weight is not quantized.
Could you add a short comment to explain what happens here ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's correct! The shape is changed and weights are packed (e.g. two nf4/fp4 values in uint8) with Params4bit. I will add a comment to explain that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SunMarc cc @muellerzr @Titus-von-Koeller

I've added a comment to explain. Here's a small repro example for the shape mismatch.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


MODEL_ID = "facebook/opt-1.3b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    max_memory={0: "0.5GiB", "cpu": "8GiB"},
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

print(model, model.hf_device_map)

inputs = tokenizer("What is the meaning of life, the universe, and everything?", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=16, num_return_sequences=1)

print(f"{tokenizer.decode(output[0])}")

Output:

Some parameters are on the meta device device because they were offloaded to the cpu.
OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear4bit(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear4bit(in_features=8192, out_features=2048, bias=True)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=2048, out_features=50272, bias=False)
) {'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 'cpu', 'model.decoder.layers.10': 'cpu', 'model.decoder.layers.11': 'cpu', 'model.decoder.layers.12': 'cpu', 'model.decoder.layers.13': 'cpu', 'model.decoder.layers.14': 'cpu', 'model.decoder.layers.15': 'cpu', 'model.decoder.layers.16': 'cpu', 'model.decoder.layers.17': 'cpu', 'model.decoder.layers.18': 'cpu', 'model.decoder.layers.19': 'cpu', 'model.decoder.layers.20': 'cpu', 'model.decoder.layers.21': 'cpu', 'model.decoder.layers.22': 'cpu', 'model.decoder.layers.23': 'cpu'}
Traceback (most recent call last):
  File "/home/matt/code/accelerate/./sandbox/inference.py", line 35, in <module>
    output = model.generate(**inputs, max_new_tokens=16, num_return_sequences=1)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
              ^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 1118, in forward
    outputs = self.model.decoder(
              ^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 884, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 525, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 155, in forward
    query_states = self.q_proj(hidden_states) * self.scaling
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 164, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 354, in pre_forward
    set_module_tensor_to_device(
  File "/home/matt/code/accelerate/src/accelerate/utils/modeling.py", line 366, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([2048, 2048]) in "weight" (which has shape torch.Size([2097152, 1])), this look incorrect.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice ! In a follow PR in transformers, I think we can finally remove the need for llm_int8_enable_fp32_cpu_offload in the 4bit case. In the 8bit case, this will still be required since we would still need to not convert the offloaded layers.

@SunMarc SunMarc requested a review from muellerzr July 16, 2024 11:33
Copy link
Collaborator

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! LG2M as long as Marc's logic is indeed why this is a thing, and a comment explaining that would be necessary.

@muellerzr muellerzr merged commit 2308576 into main Jul 22, 2024
28 checks passed
@muellerzr muellerzr deleted the modeling-4bit-fixes branch July 22, 2024 12:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants