
[CORE] Compat autoround 0.3 #368

Merged
merged 8 commits on Aug 15, 2024
Changes from all commits
25 changes: 11 additions & 14 deletions gptqmodel/models/base.py
@@ -40,6 +40,7 @@
simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
from ..version import __version__
from ._const import CPU, CUDA_0, DEVICE, SUPPORTED_MODELS
from packaging import version

logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
@@ -258,17 +259,13 @@ def quantize(

if isinstance(self.quantize_config, AutoRoundQuantizeConfig):
from auto_round import AutoRound
from transformers import modeling_utils
weight_config = {}
for n, m in self.model.named_modules():
if isinstance(m, torch.nn.Linear) or isinstance(m, modeling_utils.Conv1D):
if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
weight_config[n] = {"data_type": "fp"}
print(
f"{n} will not be quantized due to its shape not being divisible by 32, resulting in an exporting issue to gptqmodel")
from auto_round import __version__ as auto_round_version

if version.parse(auto_round_version) < version.parse("0.3.0"):
raise ValueError(f"AutoRound version must be >= 0.3.0: actual = {auto_round_version}")

if self.quantize_config.lm_head:
weight_config['lm_head'] = {"data_type": "int"}
self.quantize_config.layer_config['lm_head'] = {"data_type": "int"}

import torch.nn.functional as F
from torch.utils.data import DataLoader
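
Aside (not part of the diff): the new gate compares versions with packaging.version.parse rather than plain string comparison, which matters once release numbers reach double digits or carry pre-release tags. A minimal sketch:

```python
from packaging import version

# Release segments compare numerically, so "0.10.0" correctly sorts above "0.3.0" ...
assert version.parse("0.10.0") > version.parse("0.3.0")
# ... whereas naive string comparison would get this wrong.
assert not ("0.10.0" > "0.3.0")

# Pre-releases sort below their final release, so an installed "0.3.0rc1" would
# still be rejected by the >= 0.3.0 check added in this hunk.
assert version.parse("0.3.0rc1") < version.parse("0.3.0")
```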
@@ -317,25 +314,25 @@ def collate_batch(batch):
low_gpu_mem_usage=self.quantize_config.low_gpu_mem_usage,
seed=self.quantize_config.seed,
gradient_accumulate_steps=self.quantize_config.gradient_accumulate_steps,
scale_dtype=self.quantize_config.scale_dtype, weight_config=weight_config,
scale_dtype=self.quantize_config.scale_dtype, layer_config=self.quantize_config.layer_config,
enable_minmax_tuning=self.quantize_config.enable_minmax_tuning)

model, _ = self.autoround.quantize()

quantizers = {}
for key in self.autoround.weight_config:
info = self.autoround.weight_config[key]
for key in self.autoround.layer_config:
info = self.autoround.layer_config[key]
if not check_to_quantized(info):
continue
quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32), info["g_idx"])
quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32), None)

self.qlinear_kernel = pack_model(
model=self.model,
quantizers=quantizers,
bits=self.quantize_config.bits,
dynamic=self.quantize_config.dynamic,
group_size=self.quantize_config.group_size,
backend=BACKEND.TRITON,
backend=BACKEND.AUTO,
desc_act=self.quantize_config.desc_act,
force_layer_back_to_cpu=True,
format=self.quantize_config.format,
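A self-contained sketch of the per-layer tuple the loop above now builds from auto-round's layer_config; the stand-in `info` dict is an assumption that models only the keys the diff actually touches:

```python
import torch

# Stand-in for one entry of self.autoround.layer_config; only "scale" and "zp" are
# modeled because the old "g_idx" lookup is gone from this code path.
info = {"scale": torch.ones(2, 8), "zp": torch.zeros(2, 8, dtype=torch.int32)}

# Same 4-tuple layout as before, but with None in the trailing g_idx slot.
entry = (None, info["scale"], info["zp"].to(torch.float32), None)

# pack_layer in gptqmodel/utils/model.py (further down in this PR) is updated to
# tolerate the missing g_idx instead of unconditionally calling .to(CPU) on it.
```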
5 changes: 5 additions & 0 deletions gptqmodel/quantization/config.py
@@ -61,6 +61,9 @@ class QUANT_METHOD:
},
QUANT_METHOD.AUTO_ROUND: {
FORMAT.GPTQ,
FORMAT.GPTQ_V2,
FORMAT.MARLIN,
FORMAT.BITBLAS,
}
}
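For orientation: the set above is a compatibility table listing which serialization formats the AUTO_ROUND quant method may export to, and this hunk widens it beyond plain GPTQ. A sketch of that kind of check, using placeholder names rather than the repo's actual constants:

```python
# Placeholder table mirroring the structure above; names are illustrative only.
SUPPORTED_FORMATS = {
    "auto_round": {"gptq", "gptq_v2", "marlin", "bitblas"},  # widened by this PR
}

def validate_format(method: str, fmt: str) -> None:
    if fmt not in SUPPORTED_FORMATS.get(method, set()):
        raise ValueError(f"format {fmt!r} is not supported for quant method {method!r}")

validate_format("auto_round", "marlin")  # passes after this change
```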

@@ -331,6 +334,7 @@ def to_dict(self):

@dataclass
class AutoRoundQuantizeConfig(QuantizeConfig):
layer_config: dict = field(default_factory=dict)
enable_full_range: bool = False ##for symmetric, TODO support later
batch_size: int = 1
amp: bool = True
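
A hedged usage sketch of the new layer_config field. The import path is taken from this diff's file location; the layer name is hypothetical, and the per-entry keys follow what the diff itself uses ("data_type": "int" for lm_head, "data_type": "fp" in the removed shape check to leave a layer unquantized); exact semantics are defined by auto-round 0.3:

```python
from gptqmodel.quantization.config import AutoRoundQuantizeConfig  # path assumed from this diff

# Hypothetical per-layer override: keep one projection in full precision while the
# rest of the model follows the global 4-bit settings.
quant_config = AutoRoundQuantizeConfig(
    bits=4,
    group_size=128,
    layer_config={
        "model.layers.0.mlp.down_proj": {"data_type": "fp"},  # hypothetical layer name
    },
)
```

The hunk below then records the same dict via meta_set, so any overrides are persisted alongside the rest of the quantization metadata.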
@@ -355,6 +359,7 @@ def to_dict(self):
# inject auto-round specific meta data
self.meta_set("auto_round", pkg_version(PKG_AUTO_ROUND))
self.meta_set("enable_full_range", self.enable_full_range)
self.meta_set("layer_config", self.layer_config)
self.meta_set("batch_size", self.batch_size)
self.meta_set("amp", self.amp)
self.meta_set("lr_scheduler", self.lr_scheduler)
2 changes: 1 addition & 1 deletion gptqmodel/utils/model.py
@@ -276,7 +276,7 @@ def pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar):
layers[name].to(CPU),
scale.to(CPU),
zero.to(CPU),
g_idx.to(CPU),
g_idx.to(CPU) if g_idx is not None else None,
)
if QuantLinear is MarlinQuantLinear:
qlayers[name].pack(layers[name], scale)
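Aside (general GPTQ background, not from this repo): g_idx only carries information when activation reordering (desc_act) permutes columns; otherwise it is just the monotone column-to-group mapping, which is why the packer can accept None here and fall back to the trivial ordering. A sketch of that fallback with an assumed helper name:

```python
import torch

def default_g_idx(in_features: int, group_size: int) -> torch.Tensor:
    # Trivial mapping of each input column to its quantization group, i.e. what a
    # packer may assume when it is handed g_idx=None.
    return torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32)

print(default_g_idx(8, 4))  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```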
2 changes: 1 addition & 1 deletion requirements.txt
@@ -14,6 +14,6 @@ packaging>=24.1
ninja>=1.11.1.1
protobuf>=4.25.3
intel_extension_for_transformers>=1.4.2
auto-round==0.2
auto-round==0.3
huggingface-hub>=0.24.2
lm_eval==0.4.3