
Properly handle Params4bit in set_module_tensor_to_device #2934

Merged 2 commits on Jul 22, 2024
Changes from 1 commit
12 changes: 6 additions & 6 deletions src/accelerate/utils/modeling.py
@@ -357,10 +357,13 @@ def set_module_tensor_to_device(
     if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
         raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
 
+    param = module._parameters[tensor_name] if tensor_name in module._parameters else None
+    param_cls = type(param)
+
     if value is not None:
-        if old_value.shape != value.shape:
+        if old_value.shape != value.shape and param_cls.__name__ != "Params4bit":
             raise ValueError(
-                f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this look incorrect.'
+                f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.'
             )
Comment on lines +366 to 369

Member:
If I understood correctly, the shape of the Params4bit is different from the shape of the actual weight that we are trying to set in the offload case. That happens because, with offload, the weight is not quantized.
Could you add a short comment to explain what happens here?

Member Author:

Yes, that's correct! With Params4bit the shape changes because the weights are packed (e.g. two nf4/fp4 values per uint8). I will add a comment to explain that.
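
As a back-of-the-envelope check, here is why that packing produces the shapes seen in the traceback below. This is a plain-PyTorch sketch of the arithmetic only; the buffer is a hypothetical stand-in, not the real bitsandbytes packed storage:

import torch

# A 2048 x 2048 fp16 weight holds 4,194,304 values. Packing two 4-bit
# (nf4/fp4) codes into each uint8 byte halves the element count, and the
# packed data is stored as a flat column vector.
out_features, in_features = 2048, 2048
num_values = out_features * in_features              # 4,194,304 4-bit codes
packed = torch.empty(num_values // 2, 1, dtype=torch.uint8)
print(packed.shape)  # torch.Size([2097152, 1]) vs. the fp16 torch.Size([2048, 2048])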

Member Author:

@SunMarc cc @muellerzr @Titus-von-Koeller

I've added a comment to explain. Here's a small repro example for the shape mismatch.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


MODEL_ID = "facebook/opt-1.3b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# Cap GPU memory so that part of the model is offloaded to CPU; the offloaded
# Params4bit weights are what trigger set_module_tensor_to_device at generation time.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    max_memory={0: "0.5GiB", "cpu": "8GiB"},
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

print(model, model.hf_device_map)

inputs = tokenizer("What is the meaning of life, the universe, and everything?", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=16, num_return_sequences=1)

print(f"{tokenizer.decode(output[0])}")

Output:

Some parameters are on the meta device device because they were offloaded to the cpu.
OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear4bit(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear4bit(in_features=8192, out_features=2048, bias=True)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=2048, out_features=50272, bias=False)
) {'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0, 'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 'cpu', 'model.decoder.layers.10': 'cpu', 'model.decoder.layers.11': 'cpu', 'model.decoder.layers.12': 'cpu', 'model.decoder.layers.13': 'cpu', 'model.decoder.layers.14': 'cpu', 'model.decoder.layers.15': 'cpu', 'model.decoder.layers.16': 'cpu', 'model.decoder.layers.17': 'cpu', 'model.decoder.layers.18': 'cpu', 'model.decoder.layers.19': 'cpu', 'model.decoder.layers.20': 'cpu', 'model.decoder.layers.21': 'cpu', 'model.decoder.layers.22': 'cpu', 'model.decoder.layers.23': 'cpu'}
Traceback (most recent call last):
  File "/home/matt/code/accelerate/./sandbox/inference.py", line 35, in <module>
    output = model.generate(**inputs, max_new_tokens=16, num_return_sequences=1)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
              ^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 1118, in forward
    outputs = self.model.decoder(
              ^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 884, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 525, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/transformers/models/opt/modeling_opt.py", line 155, in forward
    query_states = self.q_proj(hidden_states) * self.scaling
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 164, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/matt/code/accelerate/src/accelerate/hooks.py", line 354, in pre_forward
    set_module_tensor_to_device(
  File "/home/matt/code/accelerate/src/accelerate/utils/modeling.py", line 366, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([2048, 2048]) in "weight" (which has shape torch.Size([2097152, 1])), this look incorrect.

Member:

Nice! In a follow-up PR in transformers, I think we can finally remove the need for llm_int8_enable_fp32_cpu_offload in the 4-bit case. In the 8-bit case it will still be required, since we still need to avoid converting the offloaded layers.
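
For context, a minimal sketch of the 8-bit configuration where the flag would still be required; the memory limits are illustrative, reusing the model from the repro above:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 8-bit case: offloaded layers must stay in fp32, so the flag is still
# needed to tell bitsandbytes not to convert them.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    quantization_config=quantization_config,
    device_map="auto",
    max_memory={0: "0.5GiB", "cpu": "8GiB"},
)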


         if dtype is None:
@@ -369,9 +372,6 @@ def set_module_tensor_to_device(
         elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
             value = value.to(dtype)
 
-    param = module._parameters[tensor_name] if tensor_name in module._parameters else None
-    param_cls = type(param)
-
     device_quantization = None
     with torch.no_grad():
         # leave it on cpu first before moving them to cuda
@@ -411,7 +411,7 @@ def set_module_tensor_to_device(
     elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device):
         param_cls = type(module._parameters[tensor_name])
         kwargs = module._parameters[tensor_name].__dict__
-        if param_cls.__name__ in ["Int8Params", "FP4Params"]:
+        if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
             if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
                 # downcast to fp16 if any - needed for 8bit serialization
                 new_value = new_value.to(torch.float16)
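
As an aside on why this branch captures param_cls and kwargs rather than simply calling .to(device): quantized parameter subclasses carry metadata in __dict__ that a plain move would drop, so the parameter is reconstructed from its class. A runnable stand-in sketch with a hypothetical MetaParam class (not bitsandbytes, and assuming the reconstruction pattern suggested by the captured param_cls/kwargs above):

import torch
import torch.nn as nn

# Hypothetical Parameter subclass that carries extra metadata, standing in
# for Int8Params/FP4Params/Params4bit.
class MetaParam(nn.Parameter):
    def __new__(cls, data, requires_grad=False, note=None):
        self = super().__new__(cls, data, requires_grad)
        self.note = note  # analogous to quant_state, blocksize, ...
        return self

linear = nn.Linear(4, 4)
linear.weight = MetaParam(linear.weight.data, note="quant metadata lives here")

# Re-create the parameter from its class and metadata instead of moving it,
# so the metadata survives the swap.
old = linear._parameters["weight"]
param_cls = type(old)                       # MetaParam
kwargs = {"note": old.note}                 # analogous to old.__dict__
new_value = torch.zeros_like(old.data)
linear._parameters["weight"] = param_cls(new_value, requires_grad=old.requires_grad, **kwargs)
print(type(linear.weight).__name__, linear.weight.note)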
2 changes: 1 addition & 1 deletion tests/test_modeling_utils.py
@@ -186,7 +186,7 @@ def test_set_module_tensor_checks_shape(self):
             set_module_tensor_to_device(model, "linear1.weight", "cpu", value=tensor)
         assert (
             str(cm.exception)
-            == 'Trying to set a tensor of shape torch.Size([2, 2]) in "weight" (which has shape torch.Size([4, 3])), this look incorrect.'
+            == 'Trying to set a tensor of shape torch.Size([2, 2]) in "weight" (which has shape torch.Size([4, 3])), this looks incorrect.'
         )
 
     def test_named_tensors(self):