Commit
Gptq refactor (#1770)
* refactor gptq with prepare and convert API

Signed-off-by: xin3he <[email protected]>

* fix bug

Signed-off-by: xin3he <[email protected]>

* update quantizer and model relationship

Signed-off-by: xin3he <[email protected]>

* fix bug

Signed-off-by: xin3he <[email protected]>

* add UT for quantize API

Signed-off-by: xin3he <[email protected]>

---------

Signed-off-by: xin3he <[email protected]>
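
The refactor replaces the one-shot gptq_quantize helper with a GPTQuantizer class that exposes prepare and convert (see the gptq.py diff below). A minimal sketch of the resulting flow; the model, tokenizer, calibration text, and weight_config keys here are illustrative placeholders, not part of this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer

    # Placeholder model/tokenizer; any causal LM with standard transformer blocks should do.
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # Illustrative per-op settings; in the framework path these are built from GPTQConfig.
    weight_config = {"model.decoder.layers.0.self_attn.k_proj": {"bits": 4, "group_size": 128, "sym": False}}

    quantizer = GPTQuantizer(quant_config=weight_config)
    model = quantizer.prepare(model, nsamples=1, use_max_length=False)

    # Calibration is now driven by the caller between prepare() and convert().
    # The hooked first transformer block interrupts the forward pass once its inputs
    # are captured, hence the except ValueError (mirrors the loop removed below).
    batch = tokenizer("GPTQ calibration sample.", return_tensors="pt")["input_ids"]
    try:
        with torch.no_grad():
            model(batch)
    except ValueError:
        pass

    q_model = quantizer.convert(model)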
xin3he authored May 7, 2024
1 parent 5f3f388 commit 84d7055
Showing 3 changed files with 139 additions and 91 deletions.
132 changes: 61 additions & 71 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -183,7 +183,7 @@ def quantize(x, scale, zero, maxq):
return scale * (q - zero)


class GPTQuantizer(object):
class RAWGPTQuantizer(object):
"""Main API for GPTQ algorithm.
Please refer to:
@@ -195,15 +195,14 @@ def __init__(
self,
model,
weight_config={},
dataloader=None,
nsamples=128,
use_max_length=True,
max_seq_length=2048,
device=None,
export_compressed_model=False,
use_layer_wise=False,
model_path="",
run_fn=None,
dataloader=None,
*args,
**kwargs,
):
@@ -226,7 +225,6 @@ def __init__(
export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False.
use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
model_path (str): Model path that is used to load state_dict per layer.
run_fn: a function to run model inference for collecting input information.
device: cpu or cuda
"""
# model
@@ -271,9 +269,7 @@ def __init__(
self.dataloader_original = dataloader
self.dataloader = []
self.nsamples = nsamples
self.run_fn = run_fn
self.run_args = kwargs.get("run_args", None)
if run_fn is None:
if dataloader is not None:
self.prepare_dataloader()

def prepare_dataloader(self):
Expand Down Expand Up @@ -489,7 +485,7 @@ def track_hidden_states(self, data):
return data[0]

@torch.no_grad()
def pre_quantization(self):
def prepare_for_calibration(self):
"""Prepare input calibration data and other attributes which are critical for gptq execution."""
try:
self.cache_key_arguments = {
@@ -532,34 +528,13 @@ def forward(layer, *args, **kwargs):
# Step2: modify the first transformer block's forward function to obtain inputs for calibration
if not self.use_layer_wise:
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
forward_cache = self.gptq_related_blocks["transformers"][0].forward
self.forward_cache = self.gptq_related_blocks["transformers"][0].forward
self.gptq_related_blocks["transformers"][0].forward = partial(
forward, self.gptq_related_blocks["transformers"][0]
)

# Step3: run forward to obtain calibration datasets
logger.info("Collecting calibration inputs...")
logger.info("Collecting calibration inputs by running the run_fn provided by user.")
if self.run_fn:
if self.run_args:
self.run_fn(self.model, *self.run_args)
accelerator.mark_step()
else:
self.run_fn(self.model)
accelerator.mark_step()
else:
for batch in tqdm(self.dataloader):
if not self.use_layer_wise:
batch = move_input_to_device(batch, self.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
self.model(batch[0])
elif isinstance(batch, dict):
self.model(**batch)
else:
self.model(batch)
except ValueError:
pass
@torch.no_grad()
def remove_prepare_for_calibration(self):
# output inp data shape
logger.info("All calibration data's shape =>")
# check all hidden_states shape
@@ -571,7 +546,7 @@ def forward(layer, *args, **kwargs):
logger.info("Done.")

# Step 4: restore original forward function, relocate layers back to cpu.
self.gptq_related_blocks["transformers"][0].forward = forward_cache
self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
if not self.use_layer_wise:
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
@@ -606,7 +581,6 @@ def execute_quantization(self, means=None, stds=None):
# Step1: prepare quantization (calibration datasets)

logger.info("Begin ====>")
self.pre_quantization()
model_path = self.model_path

# Step2: run gptq quantization in a transformer block-wise manner.
@@ -1144,41 +1118,57 @@ def ready(self):
return torch.all(self.scale != 0)


def gptq_quantize(
model,
weight_config={},
dataloader=None,
nsamples=128,
max_seq_length=2048,
use_max_length=True,
device=None,
export_compressed_model=False,
use_layer_wise=False,
model_path=None,
run_fn=None,
run_args=None,
):
"""Run weight-only quantization with."""
# TODO: unify weight_config keys, add docstring, and support default config
assert isinstance(model, torch.nn.Module), "only support torch module"
if use_layer_wise:
assert model_path is not None, "model_path should not be None when use layer wise mode"
from .gptq import GPTQuantizer

gptq_quantizer = GPTQuantizer(
from neural_compressor.torch.algorithms import Quantizer as INCQuantizer


class GPTQuantizer(INCQuantizer):
def __init__(self, quant_config={}):
"""Init a RTNQuantizer object.
Args:
quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
"""
super().__init__(quant_config)

@torch.no_grad()
def prepare(
self,
model,
weight_config,
dataloader,
nsamples,
use_max_length,
max_seq_length,
device,
export_compressed_model=export_compressed_model,
use_layer_wise=use_layer_wise,
model_path=model_path,
run_fn=run_fn,
run_args=run_args,
)
fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
logger.info("GPTQ quantizing done.")
return fp32_modified_model, gptq_config
nsamples=128,
max_seq_length=2048,
use_max_length=True,
device=None,
export_compressed_model=False,
use_layer_wise=False,
model_path=None,
*args,
**kwargs,
):
"""Run weight-only quantization with."""
# TODO: unify weight_config keys, add docstring, and support default config
assert isinstance(model, torch.nn.Module), "only support torch module"
if use_layer_wise:
assert model_path is not None, "model_path should not be None when using layer-wise mode"

self.gptq_quantizer = RAWGPTQuantizer(
model,
weight_config=self.quant_config,
nsamples=nsamples,
use_max_length=use_max_length,
max_seq_length=max_seq_length,
device=device,
export_compressed_model=export_compressed_model,
use_layer_wise=use_layer_wise,
model_path=model_path,
)
self.gptq_quantizer.prepare_for_calibration()
return self.gptq_quantizer.model

@torch.no_grad()
def convert(self, model, *args, **kwargs):
self.gptq_quantizer.model = model
self.gptq_quantizer.remove_prepare_for_calibration()
q_model, gptq_config = self.gptq_quantizer.execute_quantization()
q_model.gptq_config = gptq_config
logger.info("GPTQ quantizing done.")
return q_model
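
For reference, the prepare/convert pair above is a thin wrapper over the renamed RAWGPTQuantizer stages shown earlier in this diff. A rough equivalent using those stages directly (a sketch; run_calibration is a hypothetical stand-in for whatever forward passes the caller performs, matching the loop removed from the class):

    raw = RAWGPTQuantizer(model, weight_config=weight_config, nsamples=128,
                          use_max_length=True, max_seq_length=2048)
    raw.prepare_for_calibration()         # hook the first transformer block to capture its inputs
    run_calibration(raw.model)            # caller-driven calibration forwards (hypothetical helper)
    raw.remove_prepare_for_calibration()  # restore the original forward, log captured shapes
    q_model, gptq_config = raw.execute_quantization()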
26 changes: 17 additions & 9 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -72,10 +72,14 @@ def rtn_entry(
@register_algo(GPTQ)
@torch.no_grad()
def gptq_entry(
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs
model: torch.nn.Module,
configs_mapping: Dict[Tuple[str, callable], GPTQConfig],
mode: Mode = Mode.QUANTIZE,
*args,
**kwargs,
) -> torch.nn.Module:
logger.info("Quantize model with the GPTQ algorithm.")
from neural_compressor.torch.algorithms.weight_only.gptq import gptq_quantize
from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer

# rebuild weight_config for gptq_quantize function
weight_config = {}
@@ -106,12 +110,16 @@ def gptq_entry(
}
)
kwargs.pop("example_inputs")
kwargs.pop("mode") # TODO: will be removed after GPTQ refactoring

logger.warning("lm_head in transformer model is skipped by GPTQ")
model, quantization_perm = gptq_quantize(model=model, weight_config=weight_config, *args, **kwargs)
# Assign the gptq config as an attribute of model
model._gptq_quantization_perm = quantization_perm
if getattr(model, "quantizer", False):
quantizer = model.quantizer
else:
quantizer = GPTQuantizer(quant_config=weight_config)
model = quantizer.execute(model, mode=mode, *args, **kwargs)
if getattr(model, "quantizer", False):
del model.quantizer
else:
model.quantizer = quantizer
return model
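
The branches added here let a single GPTQuantizer instance span both halves of a prepare/convert round trip: the prepare call parks the quantizer on model.quantizer, and the convert call picks it up again and deletes the attribute. A rough sketch of that round trip; Mode.PREPARE and Mode.CONVERT are assumed members of the Mode enum used by this module, and configs_mapping plus run_fn are caller-supplied placeholders:

    prepared = gptq_entry(model, configs_mapping, mode=Mode.PREPARE,
                          example_inputs=None)           # example_inputs is consumed by the entry
    run_fn(prepared)                                      # caller-driven calibration forwards
    q_model = gptq_entry(prepared, configs_mapping, mode=Mode.CONVERT,
                         example_inputs=None)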


@@ -123,7 +131,7 @@ def static_quant_entry(
configs_mapping: Dict[Tuple[str, callable], StaticQuantConfig],
mode: Mode = Mode.QUANTIZE,
*args,
**kwargs
**kwargs,
) -> torch.nn.Module:
logger.info("Quantize model with the static quant algorithm.")
from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer
@@ -333,7 +341,7 @@ def autoround_quantize_entry(
configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig],
mode: Mode = Mode.QUANTIZE,
*args,
**kwargs
**kwargs,
) -> torch.nn.Module:
from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer
