Skip to content

Commit

Permalink
Support auto_round integration 3.x (#1810)
Browse files Browse the repository at this point in the history
Signed-off-by: Kaihui-intel <[email protected]>
  • Loading branch information
Kaihui-intel authored May 30, 2024
1 parent 19ff13e commit a3a0650
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 292 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/ut/env_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
fi

if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
pip install git+https://github.com/intel/auto-round.git@ecca5349981044e1278773a251b3fc5c0a11fe7b
pip install auto-round
fi

# test deps
Expand Down
226 changes: 46 additions & 180 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import time
from typing import Union

import torch
from auto_round import AutoRound # pylint: disable=E0401
from auto_round.calib_dataset import CALIB_DATASETS # pylint: disable=E0401
from auto_round.utils import get_block_names # pylint: disable=E0401
from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils import get_accelerator, logger

from .utility import CapturedDataloader, InputCaptureModule


class AutoRoundQuantizer(Quantizer):
def __init__(
self,
quant_config: dict = None,
quant_config: dict = {},
enable_full_range: bool = False,
batch_size: int = 8,
amp: bool = True,
device=None,
lr_scheduler=None,
use_quant_input: bool = True,
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
Expand All @@ -46,7 +50,9 @@ def __init__(
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
scale_dtype="fp32",
data_type: str = "int",
scale_dtype: str = "fp16",
**kwargs,
):
"""Init a AutQRoundQuantizer object.
Expand Down Expand Up @@ -86,17 +92,17 @@ def __init__(
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
"""
super().__init__(quant_config)
self.tokenizer = None
self.enable_full_range = enable_full_range
self.batch_size = batch_size
self.amp = amp
self.device = device
self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
self.lr_scheduler = lr_scheduler
self.use_quant_input = use_quant_input
self.enable_quanted_input = enable_quanted_input
self.enable_minmax_tuning = enable_minmax_tuning
self.lr = lr
self.minmax_lr = minmax_lr
Expand All @@ -110,7 +116,7 @@ def __init__(
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.data_type = "int"
self.data_type = data_type
self.scale_dtype = scale_dtype

def prepare(self, model: torch.nn.Module, *args, **kwargs):
Expand All @@ -121,16 +127,23 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
Returns:
A prepared model.
"""
self.rounder = AutoRoundProcessor(
prepare_model = InputCaptureModule(model)
return prepare_model

def convert(self, model: torch.nn.Module, *args, **kwargs):
dataloader = CapturedDataloader(model.args_list, model.kwargs_list)
model = model.orig_model
rounder = AutoRound(
model=model,
tokenizer=None,
dataset=dataloader,
weight_config=self.quant_config or {},
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
device=self.device,
lr_scheduler=self.lr_scheduler,
use_quant_input=self.use_quant_input,
enable_quanted_input=self.enable_quanted_input,
enable_minmax_tuning=self.enable_minmax_tuning,
lr=self.lr,
minmax_lr=self.minmax_lr,
Expand All @@ -147,179 +160,32 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
data_type=self.data_type,
scale_dtype=self.scale_dtype,
)
self.rounder.prepare()
return model

def convert(self, model: torch.nn.Module, *args, **kwargs):
model, weight_config = self.rounder.convert()
model, weight_config = rounder.quantize()
model.autoround_config = weight_config
model = pack_model(model, weight_config, device=self.device, inplace=True)
return model


@torch.no_grad()
def get_autoround_default_run_fn(
model,
tokenizer,
dataset_name="NeelNanda/pile-10k",
n_samples=512,
seqlen=2048,
seed=42,
bs=8,
dataset_split: str = "train",
dataloader=None,
):
"""Perform calibration for quantization.
This method calibrates the model for quantization by processing a specified
number of samples from the calibration dataset. It ensures that the data is
properly formatted and feeds it to the model. If the number of samples processed
is less than the specified number, it logs a warning. If no samples are processed,
it logs an error and exits.
def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
"""Generate a DataLoader for calibration using specified parameters.
Args:
n_samples (int): The number of samples to use for calibration.
tokenizer (Tokenizer): The tokenizer to use for tokenization.
seqlen (int): The exact sequence length. samples < seqlen will be dropped,
samples longer than seqlen will be truncated
dataset_name (str, optional): The name of the dataset or datasets separated by commas.
Defaults to "NeelNanda/pile-10k".
split (str, optional): The data split to use. Defaults to None.
seed (int, optional): The random seed for reproducibility. Defaults to 42.
bs (int, optional): The batch size. Defaults to 4.
n_samples (int, optional): The total number of samples to include. Defaults to 512.
Returns:
DataLoader: The DataLoader for the calibrated dataset.
"""
if dataloader is None:
get_dataloader = CALIB_DATASETS.get(dataset_name, CALIB_DATASETS["NeelNanda/pile-10k"])
dataloader = get_dataloader(
tokenizer,
seqlen,
seed=seed,
bs=bs,
split=dataset_split,
dataset_name=dataset_name,
)
total_cnt = 0
for data in dataloader:
if data is None:
continue
if isinstance(data, torch.Tensor):
data_new = data.to(model.device)
input_ids = data_new
else:
data_new = {}
for key in data.keys():
data_new[key] = data[key].to(model.device)
input_ids = data_new["input_ids"]
# if input_ids.shape[-1] < seqlen:
# continue
if total_cnt + input_ids.shape[0] > n_samples:
input_ids = input_ids[: n_samples - total_cnt, ...]
try:
if isinstance(data_new, torch.Tensor):
model(data_new)
elif isinstance(data_new, dict):
model(**data_new)
else:
# Handle cases where data_new is neither a Tensor nor a dict
raise NotImplementedError(f"Handling not implemented for data type {type(data)}")
except Exception as error:
logger.error(error)
total_cnt += input_ids.shape[0]
if total_cnt >= n_samples:
break
if total_cnt == 0:
logger.error(
"no data has been cached, please provide more data with sequence length >= {} in the ".format(seqlen)
+ "dataloader or decease the sequence length."
)
exit()
elif total_cnt < n_samples:
logger.warning(
"Insufficient number of samples collected may affect the quantification. "
"Effective samples size: {}, Target sample size: {}".format(total_cnt, n_samples)
)


class AutoRoundProcessor(AutoRound):

def prepare(self):
"""Prepares a given model for quantization."""
# logger.info("cache block input")
self.start_time = time.time()
self.block_names = get_block_names(self.model)
if len(self.block_names) == 0:
logger.warning("could not find blocks, exit with original model")
return
if self.amp:
self.model = self.model.to(self.amp_dtype)
if not self.low_gpu_mem_usage:
self.model = self.model.to(self.device)
# inputs = self.cache_block_input(block_names[0], self.n_samples)

# cache block input
self.inputs = {}
self.tmp_block_name = self.block_names[0]
self._replace_forward()

def convert(self):
"""Converts a prepared model to a quantized model."""
self._recover_forward()
inputs = self.inputs[self.tmp_block_name]
del self.tmp_block_name

del self.inputs
if "input_ids" in inputs.keys():
dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
total_samples = inputs["input_ids"].shape[dim]
self.n_samples = total_samples
if total_samples < self.train_bs:
self.train_bs = total_samples
logger.warning(f"force the train batch size to {total_samples} ")
self.model = self.model.to("cpu")
torch.cuda.empty_cache()
self.qdq_weight_round(
self.model,
inputs,
self.block_names,
n_blocks=self.n_blocks,
device=self.device,
)
for n, m in self.model.named_modules():
if n in self.weight_config.keys():
if hasattr(m, "scale"):
self.weight_config[n]["scale"] = m.scale
self.weight_config[n]["zp"] = m.zp
if self.group_size <= 0:
self.weight_config[n]["g_idx"] = torch.tensor(
[0 for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
)
else:
self.weight_config[n]["g_idx"] = torch.tensor(
[i // self.group_size for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
)
delattr(m, "scale")
delattr(m, "zp")
else:
self.weight_config[n]["data_type"] = "float"
if self.amp_dtype == torch.bfloat16:
self.weight_config[n]["data_type"] = "bfloat"
self.weight_config[n]["bits"] = 16
self.weight_config[n]["group_size"] = None
self.weight_config[n]["sym"] = None

end_time = time.time()
cost_time = end_time - self.start_time
logger.info(f"quantization tuning time {cost_time}")
## dump a summary
quantized_layers = []
unquantized_layers = []
for n, m in self.model.named_modules():
if isinstance(m, tuple(self.supported_types)):
if self.weight_config[n]["bits"] == 16:
unquantized_layers.append(n)
else:
quantized_layers.append(n)
summary_info = (
f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
)
if len(unquantized_layers) > 0:
summary_info += f", {unquantized_layers} have not been quantized"

logger.info(summary_info)
if len(unquantized_layers) > 0:
logger.info(f"Summary: {unquantized_layers} have not been quantized")
from auto_round.calib_dataset import get_dataloader # pylint: disable=E0401

self.quantized = True
self.model = self.model.to(self.model_orig_dtype)
return self.model, self.weight_config
dataloader = get_dataloader(
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
)
return dataloader
29 changes: 29 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,3 +1072,32 @@ def _hook(module, inputs, outputs):
for h in hook_list:
h.remove()
return total_values


class CapturedDataloader:
def __init__(self, args_list, kwargs_list) -> None:
self.args_list = args_list
self.kwargs_list = kwargs_list

def __iter__(self):
for args, kwargs in zip(self.args_list, self.kwargs_list):
if not args:
yield kwargs
elif not kwargs:
yield args
else:
yield args, kwargs


class InputCaptureModule(torch.nn.Module):

def __init__(self, model) -> None:
super().__init__()
self.args_list = []
self.kwargs_list = []
self.orig_model = model

def forward(self, *args, **kwargs):
with torch.no_grad():
self.args_list.append(args)
self.kwargs_list.append(kwargs)
4 changes: 2 additions & 2 deletions neural_compressor/torch/quantization/algorithm_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def autoround_quantize_entry(
enable_full_range = quant_config.enable_full_range
batch_size = quant_config.batch_size
lr_scheduler = quant_config.lr_scheduler
use_quant_input = quant_config.use_quant_input
enable_quanted_input = quant_config.enable_quanted_input
enable_minmax_tuning = quant_config.enable_minmax_tuning
lr = quant_config.lr
minmax_lr = quant_config.minmax_lr
Expand All @@ -474,7 +474,7 @@ def autoround_quantize_entry(
enable_full_range=enable_full_range,
batch_size=batch_size,
lr_scheduler=lr_scheduler,
use_quant_input=use_quant_input,
enable_quanted_input=enable_quanted_input,
enable_minmax_tuning=enable_minmax_tuning,
lr=lr,
minmax_lr=minmax_lr,
Expand Down
Loading

0 comments on commit a3a0650

Please sign in to comment.