From cfefe8f3e92bbf91ffab603398cef0f8ea8f836c Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Thu, 15 Aug 2024 17:56:07 +0800
Subject: [PATCH 1/8] layer_config should be a parameter in AutoRoundQuantizeConfig

---
 gptqmodel/models/base.py         | 21 ++++++---------------
 gptqmodel/quantization/config.py |  2 ++
 2 files changed, 8 insertions(+), 15 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index b56a80a47..cd0c44bb6 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -258,17 +258,8 @@ def quantize(
 
         if isinstance(self.quantize_config, AutoRoundQuantizeConfig):
             from auto_round import AutoRound
-            from transformers import modeling_utils
-
-            weight_config = {}
-            for n, m in self.model.named_modules():
-                if isinstance(m, torch.nn.Linear) or isinstance(m, modeling_utils.Conv1D):
-                    if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
-                        weight_config[n] = {"data_type": "fp"}
-                        print(
-                            f"{n} will not be quantized due to its shape not being divisible by 32, resulting in an exporting issue to gptqmodel")
             if self.quantize_config.lm_head:
-                weight_config['lm_head'] = {"data_type": "int"}
+                self.quantize_config.layer_config['lm_head'] = {"data_type": "int"}
 
             import torch.nn.functional as F
             from torch.utils.data import DataLoader
@@ -317,17 +308,17 @@ def collate_batch(batch):
                 low_gpu_mem_usage=self.quantize_config.low_gpu_mem_usage,
                 seed=self.quantize_config.seed,
                 gradient_accumulate_steps=self.quantize_config.gradient_accumulate_steps,
-                scale_dtype=self.quantize_config.scale_dtype, weight_config=weight_config,
+                scale_dtype=self.quantize_config.scale_dtype, layer_config=self.quantize_config.layer_config,
                 enable_minmax_tuning=self.quantize_config.enable_minmax_tuning)
 
             model, _ = self.autoround.quantize()
 
             quantizers = {}
-            for key in self.autoround.weight_config:
-                info = self.autoround.weight_config[key]
+            for key in self.autoround.layer_config:
+                info = self.autoround.layer_config[key]
                 if not check_to_quantized(info):
                     continue
-                quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32), info["g_idx"])
+                quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32))
 
             self.qlinear_kernel = pack_model(
                 model=self.model,
@@ -335,7 +326,7 @@ def collate_batch(batch):
                 bits=self.quantize_config.bits,
                 dynamic=self.quantize_config.dynamic,
                 group_size=self.quantize_config.group_size,
-                backend=BACKEND.TRITON,
+                backend=BACKEND.AUTO,
                 desc_act=self.quantize_config.desc_act,
                 force_layer_back_to_cpu=True,
                 format=self.quantize_config.format,
diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py
index 4fd99cc5e..52d9cd810 100644
--- a/gptqmodel/quantization/config.py
+++ b/gptqmodel/quantization/config.py
@@ -331,6 +331,7 @@ def to_dict(self):
 
 @dataclass
 class AutoRoundQuantizeConfig(QuantizeConfig):
+    layer_config: dict = {},
     enable_full_range: bool = False ##for symmetric, TODO support later
     batch_size: int = 1
     amp: bool = True
@@ -355,6 +356,7 @@ def to_dict(self):
         # inject auto-round specific meta data
         self.meta_set("auto_round", pkg_version(PKG_AUTO_ROUND))
         self.meta_set("enable_full_range", self.enable_full_range)
+        self.meta_set("layer_config", self.layer_config)
         self.meta_set("batch_size", self.batch_size)
         self.meta_set("amp", self.amp)
         self.meta_set("lr_scheduler", self.lr_scheduler)
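Note: PATCH 1/8 moves the per-layer overrides out of the hard-coded weight_config detection and into
AutoRoundQuantizeConfig.layer_config, which is forwarded verbatim to AutoRound(layer_config=...). A rough
usage sketch follows; it is not part of the patch, assumes AutoRoundQuantizeConfig is importable from
gptqmodel.quantization.config as the diff paths suggest, uses the corrected field default from PATCH 4-5,
and the layer name is only an example:

    from gptqmodel.quantization.config import AutoRoundQuantizeConfig

    quantize_config = AutoRoundQuantizeConfig(
        bits=4,
        group_size=128,
        # skip quantization for a specific layer by forcing float, mirroring what the
        # removed shape-based auto-detection used to do
        layer_config={"model.decoder.layers.0.fc1": {"data_type": "fp"}},
    )

From 479b8f8906672520d507db512e96400d89d0cfd6 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Thu, 15 Aug 2024 17:56:28 +0800
Subject: [PATCH 2/8] fix g_idx None issue

---
 gptqmodel/models/base.py | 2 +-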
 gptqmodel/utils/model.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index cd0c44bb6..d73f24e9f 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -318,7 +318,7 @@ def collate_batch(batch):
                 info = self.autoround.layer_config[key]
                 if not check_to_quantized(info):
                     continue
-                quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32))
+                quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32), None)
 
             self.qlinear_kernel = pack_model(
                 model=self.model,
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 0a859ef42..e0af8d249 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -276,7 +276,7 @@ def pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar):
             layers[name].to(CPU),
             scale.to(CPU),
             zero.to(CPU),
-            g_idx.to(CPU),
+            g_idx.to(CPU) if g_idx is not None else None,
         )
         if QuantLinear is MarlinQuantLinear:
             qlayers[name].pack(layers[name], scale)

From e6dd9a13788bed83270fd42e53c9b897d173151d Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Thu, 15 Aug 2024 17:57:44 +0800
Subject: [PATCH 3/8] auto_round now supports more formats

---
 gptqmodel/quantization/config.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py
index 52d9cd810..5364c6e84 100644
--- a/gptqmodel/quantization/config.py
+++ b/gptqmodel/quantization/config.py
@@ -61,6 +61,9 @@ class QUANT_METHOD:
     },
     QUANT_METHOD.AUTO_ROUND: {
         FORMAT.GPTQ,
+        FORMAT.GPTQ_V2,
+        FORMAT.MARLIN,
+        FORMAT.BITBLAS,
     }
 }
 

From f92ea3c9e0f0bfb25091ec8f5b059afc5562b394 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Thu, 15 Aug 2024 18:55:47 +0800
Subject: [PATCH 4/8] cleanup

---
 gptqmodel/quantization/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py
index 5364c6e84..d8f4fd09b 100644
--- a/gptqmodel/quantization/config.py
+++ b/gptqmodel/quantization/config.py
@@ -334,7 +334,7 @@ def to_dict(self):
 
 @dataclass
 class AutoRoundQuantizeConfig(QuantizeConfig):
-    layer_config: dict = {},
+    layer_config: dict = {}
     enable_full_range: bool = False ##for symmetric, TODO support later
     batch_size: int = 1
     amp: bool = True

From e51f8cc8fe2cd3126682b46c16caea2871099bf3 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Thu, 15 Aug 2024 19:01:48 +0800
Subject: [PATCH 5/8] use default_factory

---
 gptqmodel/quantization/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py
index d8f4fd09b..79b23d941 100644
--- a/gptqmodel/quantization/config.py
+++ b/gptqmodel/quantization/config.py
@@ -334,7 +334,7 @@ def to_dict(self):
 
 @dataclass
 class AutoRoundQuantizeConfig(QuantizeConfig):
-    layer_config: dict = {}
+    layer_config: dict = field(default_factory=dict)
     enable_full_range: bool = False ##for symmetric, TODO support later
     batch_size: int = 1
     amp: bool = True
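Note: the follow-ups in PATCH 4/8 and PATCH 5/8 are needed because the field added in PATCH 1/8 is not a
valid dataclass default. The trailing comma makes the default a one-element tuple ({},) rather than a
dict, and even a plain {} would fail because dataclasses reject mutable class-level defaults. A minimal
sketch of the difference (illustrative only, not from the patch):

    from dataclasses import dataclass, field

    @dataclass
    class Example:
        # layer_config: dict = {}   # raises ValueError: mutable default ... use default_factory
        layer_config: dict = field(default_factory=dict)  # each instance gets its own empty dict
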
From 88ece8a0ffe13abaa7b48202d1069a72986b531e Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Thu, 15 Aug 2024 19:54:02 +0800
Subject: [PATCH 6/8] check auto_round version

---
 gptqmodel/models/base.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index d73f24e9f..02761c6c9 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -40,6 +40,7 @@
                      simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
 from ..version import __version__
 from ._const import CPU, CUDA_0, DEVICE, SUPPORTED_MODELS
+from packaging import version
 
 logger = logging.getLogger(__name__)
 handler = logging.StreamHandler()
@@ -258,6 +259,11 @@ def quantize(
 
         if isinstance(self.quantize_config, AutoRoundQuantizeConfig):
             from auto_round import AutoRound
+            from auto_round import __version__ as auto_round_version
+
+            if version.parse(auto_round_version) < version.parse("0.3.0"):
+                raise ValueError(f"AutoRound version must be >= 0.3.0. Current version: {auto_round_version}")
+
             if self.quantize_config.lm_head:
                 self.quantize_config.layer_config['lm_head'] = {"data_type": "int"}
 

From 93dea77adbc298fa18e6140aa583fb5557043a37 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Thu, 15 Aug 2024 20:04:58 +0800
Subject: [PATCH 7/8] update dep version

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 8d43f0d19..09669b796 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,6 +14,6 @@ packaging>=24.1
 ninja>=1.11.1.1
 protobuf>=4.25.3
 intel_extension_for_transformers>=1.4.2
-auto-round==0.2
+auto-round==0.3
 huggingface-hub>=0.24.2
 lm_eval==0.4.3

From 2783cfdb3a4ee7bbc5621e4fcac17a6074ffddb7 Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Thu, 15 Aug 2024 20:05:51 +0800
Subject: [PATCH 8/8] Update base.py

---
 gptqmodel/models/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 02761c6c9..acf41b41a 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -262,7 +262,7 @@ def quantize(
             from auto_round import __version__ as auto_round_version
 
             if version.parse(auto_round_version) < version.parse("0.3.0"):
-                raise ValueError(f"AutoRound version must be >= 0.3.0. Current version: {auto_round_version}")
+                raise ValueError(f"AutoRound version must be >= 0.3.0: actual = {auto_round_version}")
 
             if self.quantize_config.lm_head:
                 self.quantize_config.layer_config['lm_head'] = {"data_type": "int"}
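Note: with the series applied, an AutoRound quantization run through gptqmodel looks roughly like the
sketch below. The GPTQModel entry points (from_pretrained / quantize / save_quantized), the model id and
the calibration data format are assumptions based on gptqmodel's AutoGPTQ-style API, not something this
patch series defines:

    from gptqmodel import GPTQModel
    from gptqmodel.quantization.config import AutoRoundQuantizeConfig

    cfg = AutoRoundQuantizeConfig(bits=4, group_size=128)

    # requires auto-round >= 0.3.0 at runtime (PATCH 6/8) and auto-round==0.3 in requirements (PATCH 7/8)
    calibration_dataset = [...]  # tokenized calibration samples; exact format depends on gptqmodel's quantize API
    model = GPTQModel.from_pretrained("facebook/opt-125m", cfg)  # model id is only an example
    model.quantize(calibration_dataset)
    model.save_quantized("opt-125m-autoround-4bit")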