[bnb] Add fp4 support for dispatch #1505

Merged
younesbelkada merged 3 commits into main from add-dispatch-4bit on Jun 1, 2023

Conversation

younesbelkada
Contributor

What does this PR do?

Fixes #1504

This PR applies a similar enhancement to the one in #1228, this time for FP4 layers.

Now the script below outputs the desired dtype:

import torch
from transformers import AutoConfig, AutoModel, BitsAndBytesConfig
from transformers.utils.bitsandbytes import replace_8bit_linear

from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from accelerate import load_checkpoint_and_dispatch


with init_empty_weights():
    model = AutoModel.from_config(AutoConfig.from_pretrained("bigscience/bloom-560m"))

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True
)

model = replace_8bit_linear(model, quantization_config=quantization_config)

# For some reason replace_8bit_linear creates parameters with requires_grad=True but it's irrelevant rn
for p in model.parameters():
    p.requires_grad = False

model_path = hf_hub_download("bigscience/bloom-560m", "pytorch_model.bin")

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=model_path,
    device_map="auto",
)

print(f"{model.h[0].self_attention.query_key_value.weight.device}\n{model.h[0].self_attention.query_key_value.weight.dtype}")
>>> torch.uint8
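(Seeing torch.uint8 here is expected: bitsandbytes stores 4-bit weights packed into uint8 storage, so a uint8 weight sitting on a CUDA device is the sign that the FP4 quantization actually happened.)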

cc @sgugger @BlackSamorez

# quantize only if necessary
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
if not getattr(module.weight, "quant_state", None) and device_index is not None:
    module.weight = module.weight.cuda(device_index)
Collaborator

The .cuda function is very, very deprecated. You should use .to instead.

Contributor Author

Hmm, I think the way bitsandbytes has designed its Linear4bit layers, we need to call cuda: https://github.com/TimDettmers/bitsandbytes/blob/ac5550a0238286377ee3f58a85aeba1c40493e17/bitsandbytes/nn/modules.py#L152 seems to be the only way to quantize the weights :/ I tried it with .to and it didn't work. (Note that at that point module.weight is a bnb.nn.Params4bit parameter.)

Collaborator

Oh ok. Not very PyTorch-ic then.

Contributor Author

Yeah! :/
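
For reference, a minimal self-contained sketch of the dispatch-time pattern discussed in this thread. The helper name maybe_quantize_4bit and its module/device arguments are illustrative assumptions modelled on the diff above, not the exact accelerate implementation; the point is that bnb 4-bit weights only get quantized when .cuda() is called on them.

import torch

def maybe_quantize_4bit(module, device):
    # Illustrative helper (assumed name), mirroring the diff reviewed in this PR.
    # Only CUDA devices carry an index; CPU and meta tensors are left untouched.
    device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
    # bitsandbytes sets `quant_state` on the weight once it has been quantized,
    # and quantization happens inside .cuda() (see the linked modules.py line),
    # so .to(device) would move the tensor without quantizing it.
    if not getattr(module.weight, "quant_state", None) and device_index is not None:
        module.weight = module.weight.cuda(device_index)
    return module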

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 1, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada younesbelkada requested a review from sgugger June 1, 2023 15:11
"""Tests that `dispatch_model` quantizes int8 layers"""
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModel, BitsAndBytesConfig
from transformers.utils.bitsandbytes import replace_8bit_linear
Collaborator

Isn't this function renamed to something else?

Contributor Author

Ah yes, let me modify that.
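
As an aside, a hedged sketch of what the updated import could look like. The name replace_with_bnb_linear is an assumption about the rename in newer transformers releases and is not confirmed in this thread:

# Assumed new helper name; fall back to the old one on older transformers versions.
try:
    from transformers.utils.bitsandbytes import replace_with_bnb_linear
except ImportError:
    from transformers.utils.bitsandbytes import replace_8bit_linear as replace_with_bnb_linear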

@slow
@unittest.skip("Un-skip in the next transformers release")
def test_dipatch_model_fp4_simple(self):
"""Tests that `dispatch_model` quantizes int8 layers"""
Collaborator

To adapt.
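
Presumably the adaptation is just the docstring, along these lines (decorators and test name kept exactly as in the snippet above; only "int8" becomes "fp4"):

@slow
@unittest.skip("Un-skip in the next transformers release")
def test_dipatch_model_fp4_simple(self):
    """Tests that `dispatch_model` quantizes fp4 layers"""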

@younesbelkada younesbelkada requested a review from sgugger June 1, 2023 15:17
@younesbelkada younesbelkada merged commit 8ae56dc into main Jun 1, 2023
@younesbelkada younesbelkada deleted the add-dispatch-4bit branch June 1, 2023 18:41
Development

Successfully merging this pull request may close these issues.

set_module_tensor_to_device doesn't properly deploy BitsAndBytes Linear4bit