AQLM quantizer support #28928
Conversation
A model to test it: BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch
A Google Colab demo: Mixtral in 2 bits.
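For anyone who wants to try this quickly, here is a minimal loading sketch. It mirrors the snippets quoted later in this thread; the exact flags and prompt are illustrative, and it assumes the `aqlm` package and a GPU are available.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# The checkpoint ships an AQLM quantization config, so from_pretrained
# dispatches to the AQLM quantizer added in this PR.
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)

inputs = tokenizer("Hello", return_tensors="pt").to(quantized_model.device)
output = quantized_model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0]))
```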
if isinstance(module, nn.Linear):
    # Check if the current key is not in the `linear_weights_not_to_quantize`
    if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
I saw in the config of the model you pushed on the Hub that you also included layer norm weights inside `linear_weights_not_to_quantize`. I think these can be excluded from the config since they are not an instance of `nn.Linear`, right?
They certainly can be excluded. It's just that, when converting from a freshly quantized AQLM format, it would be troublesome to check whether an unquantized `.weight` parameter belongs to an `nn.Linear` or not, so I simply included all of them just in case. That Mixtral config can, indeed, be made somewhat shorter.
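For readers following the exchange above, here is a simplified sketch of the check being discussed (the function name mirrors the one in this PR, but the body is illustrative and the actual replacement call is elided). Layer norm weights never reach the `nn.Linear` branch, which is why listing them in `linear_weights_not_to_quantize` is harmless but unnecessary.

```python
import torch.nn as nn


def _replace_with_aqlm_linear_sketch(model, linear_weights_not_to_quantize, current_key_name=None):
    """Recursively walk the model and decide which linear layers to quantize (sketch only)."""
    current_key_name = current_key_name if current_key_name is not None else []
    for name, module in model.named_children():
        current_key_name.append(name)
        if isinstance(module, nn.Linear):
            # Skip any linear layer whose fully qualified weight name is explicitly excluded.
            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
                pass  # the real code swaps `module` for an AQLM quantized linear here
        else:
            # Layer norms and other non-linear modules only get recursed into, never quantized.
            _replace_with_aqlm_linear_sketch(module, linear_weights_not_to_quantize, current_key_name)
        current_key_name.pop()
    return model
```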
Looks pretty clean already ! Thanks so much for working on this and for converting Mixtral to the AQLM format - thanks also for sharing the Google Colab ! Amazing work !
I assume the method also works on a T4 since you shared a Colab demo; would you mind adding some simple tests?
You can copy-paste the tests from AWQ: https://github.com/huggingface/transformers/blob/main/tests/quantization/autoawq/test_awq.py and have config tests plus very simple model tests that check that the model has been successfully converted to the AQLM format, together with a generation test. I would also add a simple test to make sure the model loads well on CPU.
Can you also share some insights on generation speed for CPU & GPU ? 🙏
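A rough sketch of the kind of tests requested here, modeled on the AWQ suite. The class names, the `require_aqlm` marker, and the exact assertions are assumptions for illustration, not the tests that were eventually merged.

```python
import unittest

from transformers import AqlmConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_aqlm, require_torch_gpu


class AqlmConfigTest(unittest.TestCase):
    def test_to_dict(self):
        # Every attribute should survive the round trip through to_dict().
        config = AqlmConfig()
        for key, value in config.to_dict().items():
            self.assertEqual(getattr(config, key), value)


@require_torch_gpu
@require_aqlm
class AqlmModelTest(unittest.TestCase):
    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"

    def test_quantized_model_generation(self):
        # Load the 2-bit checkpoint and check that generation produces new tokens.
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto")
        inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
        output = model.generate(**inputs, max_new_tokens=5)
        self.assertGreater(output.shape[-1], inputs["input_ids"].shape[-1])
```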
One last thing: could you add installation instructions to our testing Dockerfile: https://github.com/huggingface/transformers/blob/main/docker/transformers-all-latest-gpu/Dockerfile
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @oobabooga this might be of interest to you !
I updated the docker recipe and added tests, but they are skipped because
Thanks a lot ! I left two minor comments ! Can you also run `make fixup`? This will redirect you to run `make fix-copies`, which should fix the tests !
    return True

def _replace_with_aqlm_linear(
You can move this method and make it public under `integrations/aqlm.py`, and import the method locally inside `_process_model_before_weight_loading`.
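For concreteness, a structural sketch of that suggestion; the public helper's signature and the surrounding class are assumptions based on this thread rather than the merged code.

```python
# src/transformers/integrations/aqlm.py -- public helper
def replace_with_aqlm_linear(model, quantization_config=None, linear_weights_not_to_quantize=None):
    ...  # walk the model and swap nn.Linear layers for AQLM quantized linears


# src/transformers/quantizers/quantizer_aqlm.py -- method on the AQLM quantizer class
def _process_model_before_weight_loading(self, model, **kwargs):
    # Imported locally so that importing transformers never requires aqlm itself.
    from ..integrations import replace_with_aqlm_linear

    replace_with_aqlm_linear(
        model,
        quantization_config=self.quantization_config,
        linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
    )
```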
Done
tests/quantization/aqlm/test_aqlm.py (Outdated)
@require_torch_gpu
class AwqConfigTest(unittest.TestCase):
- class AwqConfigTest(unittest.TestCase):
+ class AqlmConfigTest(unittest.TestCase):
Done
Looks very clean ! Thanks so much for the integration !
As discussed with @ArthurZucker offline, we could leverage `# Copied from` on the tests, but this is clearly not a blocker for me and we can merge as is! Looking forward to the integration !
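For readers unfamiliar with the convention, a `# Copied from` marker on a test class would look roughly like this; the source path and the `Awq->Aqlm` replacement are an assumed example, not necessarily what was applied here.

```python
import unittest


# Copied from tests.quantization.autoawq.test_awq.AwqConfigTest with Awq->Aqlm
class AqlmConfigTest(unittest.TestCase):
    ...
```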
LGTM thanks for the clean PR 🤗
I also left one nit ! I tested it and it seems to work only for Python >= 3.10.
@BlackSamorez on a Google Colab env the inference script works great; however, on my VM, in a Python 3.10 env with the latest torch + CUDA 11.8, I constantly get:
Traceback (most recent call last):
File "/transformers/scratch.py", line 11, in <module>
output = quantized_model.generate(tokenizer("", return_tensors="pt")["input_ids"].cuda(), max_new_tokens=10)
File "/miniconda3/envs/aqlm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/transformers/src/transformers/generation/utils.py", line 1495, in generate
return self.greedy_search(
File "/transformers/src/transformers/generation/utils.py", line 2366, in greedy_search
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
RuntimeError: CUDA error: device kernel image is invalid
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Do you have an idea what might be wrong here?
The only difference I see between the Colab instance and mine is the CUDA version; I'll update it to 12.1 and loop back here.
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Arthur <[email protected]>
@younesbelkada
import os
from typing import Optional

import torch
from torch.utils.cpp_extension import load

CUDA_FOLDER = os.path.dirname(os.path.abspath(__file__))
CUDA_KERNEL = load(
    name="codebook_cuda",
    sources=[os.path.join(CUDA_FOLDER, "cuda_kernel.cpp"), os.path.join(CUDA_FOLDER, "cuda_kernel.cu")],
)
Maybe your
CUDA 11.8 seems to work fine on my machine on an A100 GPU.
FYI: I've released
Thanks for the integration !
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Looks like some network error occurred.
Thanks a lot @BlackSamorez !
🤗 🚀
* aqlm init
* calibration and dtypes
* docs
* Readme update
* is_aqlm_available
* Simpler link in docs
* Test TODO real reference
* init _import_structure fix
* AqlmConfig autodoc
* integration aqlm
* integrations in tests
* docstring fix
* legacy typing
* Less typings
* More kernels information
* Performance -> Accuracy
* correct tests
* remoced multi-gpu test
* Update docs/source/en/quantization.md
Co-authored-by: Younes Belkada <[email protected]>
* Update src/transformers/utils/quantization_config.py
Co-authored-by: Arthur <[email protected]>
* Brought back multi-gpu tests
* Update src/transformers/integrations/aqlm.py
Co-authored-by: Marc Sun <[email protected]>
* Update tests/quantization/aqlm_integration/test_aqlm.py
Co-authored-by: Marc Sun <[email protected]>
---------
Co-authored-by: Andrei Panferov <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Hi @BlackSamorez !
@younesbelkada
AttributeError Traceback (most recent call last)
[<ipython-input-2-68b1b199d504>](https://localhost:8080/#) in <cell line: 3>()
1 from transformers import AutoTokenizer, AutoModelForCausalLM
2
----> 3 quantized_model = AutoModelForCausalLM.from_pretrained(
4 "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch",
5 torch_dtype="auto", device_map="auto", low_cpu_mem_usage=True,
4 frames
[/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py](https://localhost:8080/#) in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
565 elif type(config) in cls._model_mapping.keys():
566 model_class = _get_model_class(config, cls._model_mapping)
--> 567 return model_class.from_pretrained(
568 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
569 )
[/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py](https://localhost:8080/#) in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
3561
3562 if hf_quantizer is not None:
-> 3563 hf_quantizer.postprocess_model(model)
3564 model.hf_quantizer = hf_quantizer
3565
[/usr/local/lib/python3.10/dist-packages/transformers/quantizers/base.py](https://localhost:8080/#) in postprocess_model(self, model, **kwargs)
177 The keyword arguments that are passed along `_process_model_after_weight_loading`.
178 """
--> 179 return self._process_model_after_weight_loading(model, **kwargs)
180
181 @abstractmethod
[/usr/local/lib/python3.10/dist-packages/transformers/quantizers/quantizer_aqlm.py](https://localhost:8080/#) in _process_model_after_weight_loading(self, model, **kwargs)
78
79 def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
---> 80 model._is_quantized_training_enabled = False
81 return model
82
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in __setattr__(self, name, value)
1745 buffers[name] = value
1746 else:
-> 1747 super().__setattr__(name, value)
1748
1749 def __delattr__(self, name):
AttributeError: can't set attribute '_is_quantized_training_enabled'
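For context on the error itself: this is the standard Python failure when assigning to a read-only property, which suggests `_is_quantized_training_enabled` is defined as a property without a setter on the model class. A minimal standalone reproduction with a toy module (not the actual transformers code):

```python
import torch.nn as nn


class TinyModel(nn.Module):
    # A read-only property: there is no corresponding setter.
    @property
    def _is_quantized_training_enabled(self):
        return False


model = TinyModel()
# nn.Module.__setattr__ falls through to object.__setattr__, which raises
# AttributeError because the property cannot be assigned on the instance.
model._is_quantized_training_enabled = False
```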
What does this PR do?
Fixes Vahe1994/AQLM#11
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@younesbelkada