GPTQ integration #25062

Merged

merged 50 commits into main from gptq_integration on Aug 10, 2023
Changes from 8 commits

Commits (50)
d963b97  GTPQ integration (SunMarc, Jul 24, 2023)
93f0d84  Add tests for gptq (SunMarc, Jul 24, 2023)
380baea  support for more quantization model (SunMarc, Jul 25, 2023)
810d537  fix style (SunMarc, Jul 25, 2023)
c3f5248  typo (SunMarc, Jul 25, 2023)
fc70ef4  fix method (SunMarc, Jul 25, 2023)
6a04bb8  Update src/transformers/modeling_utils.py (SunMarc, Jul 25, 2023)
271dab6  add dataclass and fix quantization_method (SunMarc, Jul 25, 2023)
992881e  fix doc (SunMarc, Jul 26, 2023)
3c2d940  Update tests/quantization/gptq/test_gptq.py (SunMarc, Jul 26, 2023)
9bbb336  Apply suggestions from code review (SunMarc, Jul 26, 2023)
0134c79  modify dataclass (SunMarc, Jul 26, 2023)
a2a7f5d  add gtpqconfig import (SunMarc, Jul 26, 2023)
70e1416  fix typo (SunMarc, Jul 26, 2023)
0e2014b  fix tests (SunMarc, Jul 26, 2023)
69e3c88  remove dataset as req arg (SunMarc, Jul 26, 2023)
cb46d75  remove tokenizer import (SunMarc, Jul 26, 2023)
9a3cafd  add offload cpu quantization test (SunMarc, Jul 26, 2023)
27e9b79  fix check dataset (SunMarc, Jul 26, 2023)
f47ecb4  modify dockerfile (SunMarc, Jul 26, 2023)
19d05d3  protect trainer (SunMarc, Jul 26, 2023)
76dffe2  style (SunMarc, Jul 26, 2023)
0f61037  test for config (SunMarc, Jul 26, 2023)
b0eccd5  add more log (SunMarc, Jul 27, 2023)
2e7a025  overwrite torch_dtype (SunMarc, Jul 27, 2023)
a07126a  draft doc (SunMarc, Jul 27, 2023)
c9d3f26  modify quantization_config docstring (SunMarc, Jul 31, 2023)
ecce1da  fix class name in docstring (SunMarc, Jul 31, 2023)
2226184  Apply suggestions from code review (SunMarc, Jul 31, 2023)
eff99cb  more warning (SunMarc, Jul 31, 2023)
159cf87  fix 8bit kwargs tests (SunMarc, Jul 31, 2023)
98db723  peft compatibility (SunMarc, Jul 31, 2023)
0144760  remove var (SunMarc, Aug 1, 2023)
fd8d70c  fix is_gptq_quantized (SunMarc, Aug 1, 2023)
0f96fb2  Merge branch 'main' into gptq_integration (SunMarc, Aug 1, 2023)
be19916  remove is_gptq_quantized (SunMarc, Aug 2, 2023)
9e8f487  Merge remote-tracking branch 'upstream/main' into gptq_integration (SunMarc, Aug 2, 2023)
4b4336e  fix wrap (SunMarc, Aug 2, 2023)
42d0049  Update src/transformers/modeling_utils.py (SunMarc, Aug 8, 2023)
a9658e2  Merge remote-tracking branch 'upstream/main' into gptq_integration (SunMarc, Aug 8, 2023)
62aa293  add exllama (SunMarc, Aug 9, 2023)
39137eb  skip test (SunMarc, Aug 9, 2023)
f23ce7e  Merge remote-tracking branch 'upstream/main' into gptq_integration (SunMarc, Aug 9, 2023)
0b0633b  overwrite float16 (SunMarc, Aug 9, 2023)
c3c4a16  style (SunMarc, Aug 9, 2023)
a45b5b0  fix skip test (SunMarc, Aug 9, 2023)
69c8fce  Apply suggestions from code review (SunMarc, Aug 10, 2023)
bf98799  fix docsting formatting (SunMarc, Aug 10, 2023)
7adf9cb  add doc (SunMarc, Aug 10, 2023)
c93d1d0  better test (SunMarc, Aug 10, 2023)

69 changes: 62 additions & 7 deletions src/transformers/modeling_utils.py
@@ -33,6 +33,7 @@
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss

from . import AutoTokenizer
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
@@ -64,6 +65,7 @@
download_url,
has_file,
is_accelerate_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_offline_mode,
is_optimum_available,
@@ -75,7 +77,7 @@
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled
from .utils.quantization_config import BitsAndBytesConfig
from .utils.quantization_config import AutoGPTQConfig, BitsAndBytesConfig, QuantizationMethod
from .utils.versions import require_version_core


@@ -2256,13 +2258,17 @@ def from_pretrained(
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)

if quantization_config is None:
quantization_method_from_args = None
if quantization_config is not None:
quantization_method_from_args = quantization_config.get("quant_method", QuantizationMethod.BITS_AND_BYTES)

if quantization_config is None and (load_in_8bit or load_in_4bit):
quantization_config, kwargs = BitsAndBytesConfig.from_dict(
config_dict={"load_in_8bit": load_in_8bit, "load_in_4bit": load_in_4bit},
return_unused_kwargs=True,
**kwargs,
)
elif quantization_config is not None:
elif quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES:
load_in_8bit = quantization_config.load_in_8bit
load_in_4bit = quantization_config.load_in_4bit

@@ -2344,15 +2350,53 @@ def from_pretrained(
else:
model_kwargs = kwargs

if is_8bit_serializable and quantization_config is not None and load_in_8bit:
if hasattr(config, "quantization_config"):
quantizer = None
quantization_method_from_config = None
if hasattr(config, "quantization_config"):
quantization_method_from_config = config.quantization_config.get(
"quant_method", QuantizationMethod.BITS_AND_BYTES
)

if quantization_method_from_config == QuantizationMethod.GPTQ and quantization_method_from_args is not None:
quantization_method_from_args = None
logger.warning(
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a"
" `quantization_config` attribute and has already quantized weights. We will not perform quantization"
"with the given `quantization config` that you have passed."
)
if (
quantization_method_from_args == QuantizationMethod.GPTQ
or quantization_method_from_config == QuantizationMethod.GPTQ
):
if not (is_optimum_available() and is_auto_gptq_available()):
raise ImportError(
"Loading GTPQ quantized model requires optimum library : `pip install optimum` and auto-gptq library 'pip install auto-gptq'"
)
else:
# Need to protect the import
from optimum.gptq import GPTQQuantizer
if quantization_method_from_config == QuantizationMethod.GPTQ:
quantization_config = AutoGPTQConfig.from_dict(config.quantization_config)
torch_dtype = config.torch_dtype
quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())

if (
is_8bit_serializable
and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES
and load_in_8bit
):
if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
logger.warning(
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a"
" `quantization_config` attribute. The `quantization_config` attribute will be overwritten with the"
" one you passed to `from_pretrained`."
)
config.quantization_config = quantization_config
elif is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
elif (
is_8bit_serializable
and not load_in_8bit
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
):
quantization_config = config.quantization_config
if isinstance(quantization_config, dict):
quantization_config = BitsAndBytesConfig.from_dict(quantization_config, return_unused_kwargs=False)
@@ -2382,7 +2426,11 @@ def from_pretrained(
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True

elif not is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
elif (
not is_8bit_serializable
and not load_in_8bit
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
):
logger.warning(
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
@@ -2767,6 +2815,8 @@ def from_pretrained(
"All non-linear modules will be loaded in full precision."
" If you want to load the other modules in other precision, please specify a `torch_dtype` attribute."
)
if quantization_method_from_config == QuantizationMethod.GPTQ:
model = quantizer.convert_model(model)

if isinstance(device_map, str):
special_dtypes = {}
@@ -2962,6 +3012,11 @@ def from_pretrained(
kwargs["skip_keys"] = model._skip_keys_device_placement
dispatch_model(model, **kwargs)

if quantization_method_from_args == QuantizationMethod.GPTQ:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, use_fast=True)
quantizer.quantize_model(model, tokenizer)
model.config.quantization_config = AutoGPTQConfig.from_dict(quantizer.to_dict())

if output_loading_info:
if loading_info is None:
loading_info = {
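Note: taken together, the hunks above resolve a `quant_method` from both the `quantization_config` argument and the checkpoint's config, guard the optimum/auto-gptq imports, call `quantizer.convert_model(model)` when loading an already-quantized checkpoint, and run `quantizer.quantize_model(model, tokenizer)` after dispatch when quantizing on the fly. The sketch below shows how that path might be exercised at this commit; the model ids and the `bits`/`dataset` arguments are illustrative assumptions (the config class is still named `AutoGPTQConfig` here and was renamed later in review), not something this diff pins down.

```python
# Minimal usage sketch of the GPTQ path added above (argument names and model
# ids are assumptions for illustration, not taken from this diff).
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import AutoGPTQConfig

# Quantize on the fly: passing a GPTQ config triggers GPTQQuantizer creation,
# an automatic AutoTokenizer load for calibration, and quantize_model() after
# the model has been dispatched.
gptq_config = AutoGPTQConfig(bits=4, dataset="c4")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",          # illustrative model id
    quantization_config=gptq_config,
    device_map="auto",
)

# Load weights that were already quantized: the checkpoint's own
# quantization_config wins and only convert_model() is applied.
quantized = AutoModelForCausalLM.from_pretrained(
    "some-user/opt-125m-gptq",    # hypothetical repo name
    device_map="auto",
)
```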
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -51,6 +51,7 @@
from .utils import (
is_accelerate_available,
is_apex_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_bs4_available,
is_cython_available,
@@ -770,6 +771,13 @@ def require_optimum(test_case):
return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)


def require_auto_gptq(test_case):
"""
Decorator for auto_gptq dependency
"""
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer
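A small sketch of how the new decorator might be combined with the existing `require_optimum` in a GPTQ test module; the test class and its body are illustrative, not part of this PR.

```python
import unittest

from transformers.testing_utils import require_auto_gptq, require_optimum


@require_optimum
@require_auto_gptq
class GPTQDependencyTest(unittest.TestCase):
    # unittest reports this as skipped unless both optimum and auto-gptq are installed.
    def test_quantizer_import(self):
        from optimum.gptq import GPTQQuantizer  # noqa: F401
```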
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -103,6 +103,7 @@
get_torch_version,
is_accelerate_available,
is_apex_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_bs4_available,
is_coloredlogs_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -98,6 +98,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
@@ -554,6 +555,10 @@ def is_optimum_available():
return _optimum_available


def is_auto_gptq_available():
return _auto_gptq_available


def is_optimum_neuron_available():
return _optimum_available and _is_package_available("optimum.neuron")

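The new helper mirrors `is_optimum_available`, so callers can protect the optional import instead of failing at module import time; a minimal sketch of that guard, paraphrasing the check added to `modeling_utils.py`:

```python
from transformers.utils import is_auto_gptq_available, is_optimum_available

# Protect the optional dependency: only pull in optimum's GPTQQuantizer when
# both optimum and auto-gptq are installed.
if is_optimum_available() and is_auto_gptq_available():
    from optimum.gptq import GPTQQuantizer  # noqa: F401
else:
    raise ImportError(
        "Loading a GPTQ quantized model requires the optimum library (`pip install optimum`) "
        "and the auto-gptq library (`pip install auto-gptq`)."
    )
```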