Support auto_round integration 3.x #1810

Merged: 50 commits, May 30, 2024
Commits (50)
32a6612
adapt v0.2
Kaihui-intel May 22, 2024
961cd69
clean ut
Kaihui-intel May 22, 2024
13e7455
enhance packing model ut
Kaihui-intel May 22, 2024
6e399c7
fix tmp_dtype
Kaihui-intel May 22, 2024
60d2c5d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2024
4c42048
Merge branch 'kaihui/ar_v02_3x' of https://github.com/intel/neural-co…
Kaihui-intel May 22, 2024
86f41da
support lm_head quant
Kaihui-intel May 23, 2024
6be70fa
support xpu packing
Kaihui-intel May 23, 2024
78af942
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
3f3b2fd
Merge branch 'kaihui/ar_v02_3x' of https://github.com/intel/neural-co…
Kaihui-intel May 23, 2024
d868e96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
8b54922
use new prepare/convert
Kaihui-intel May 24, 2024
8939cf7
Merge branch 'kaihui/ar_v02_3x' of https://github.com/intel/neural-co…
Kaihui-intel May 24, 2024
23be27f
update config
Kaihui-intel May 24, 2024
92bfffa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
f585690
Merge branch 'kaihui/ar_v02_3x' of https://github.com/intel/neural-co…
Kaihui-intel May 24, 2024
de11f0e
minor fix
Kaihui-intel May 24, 2024
194026e
Merge branch 'master' of https://github.com/intel/neural-compressor
Kaihui-intel May 24, 2024
c98da6f
Merge branch 'master' into kaihui/ar_v02_3x
Kaihui-intel May 24, 2024
22f2ac4
update entry
Kaihui-intel May 24, 2024
e8696f3
rm xpu&update autoround commit
Kaihui-intel May 24, 2024
65d5401
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
76b0bcd
remove autoround run_fn
Kaihui-intel May 24, 2024
c05407e
Merge branch 'kaihui/ar_v02_3x' of https://github.com/intel/neural-co…
Kaihui-intel May 24, 2024
e450d02
remove pack_model
Kaihui-intel May 24, 2024
82c9f4a
refine captured dataloader
Kaihui-intel May 24, 2024
b03cc7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
ffe5ac2
rebase
Kaihui-intel May 24, 2024
2513392
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
32fb250
fix
Kaihui-intel May 24, 2024
a4e4dbe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
2c9540a
use auto-round commit
Kaihui-intel May 27, 2024
42ad3f1
Merge branch 'kaihui/ar_v02_3x' of https://github.com/intel/neural-co…
Kaihui-intel May 27, 2024
e6c2fbc
update dependency
Kaihui-intel May 27, 2024
ff349e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
493c0ba
update dependency
Kaihui-intel May 27, 2024
8f7e778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
43e8581
update commit
Kaihui-intel May 27, 2024
9f134f2
update commit
Kaihui-intel May 27, 2024
9e00cf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
13db80c
fix _ipex_available
Kaihui-intel May 27, 2024
dfbbfbb
Merge branch 'kaihui/ar_v02_3x' of https://github.com/intel/neural-co…
Kaihui-intel May 27, 2024
49f8330
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
2a44ccb
reset torch
Kaihui-intel May 27, 2024
4c78d60
Merge branch 'kaihui/ar_v02_3x' of https://github.com/intel/neural-co…
Kaihui-intel May 27, 2024
474f7c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
e98608e
add no_grad forward
Kaihui-intel May 30, 2024
5483762
add no_grad for run_fn
Kaihui-intel May 30, 2024
b2f296c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
a8fbd53
remove commit version
Kaihui-intel May 30, 2024
226 changes: 46 additions & 180 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -12,27 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import time
from typing import Union

import torch
from auto_round import AutoRound # pylint: disable=E0401
from auto_round.calib_dataset import CALIB_DATASETS # pylint: disable=E0401
from auto_round.utils import get_block_names # pylint: disable=E0401
from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils import get_accelerator, logger

from .utility import CapturedDataloader, InputCaptureModule


class AutoRoundQuantizer(Quantizer):
def __init__(
self,
quant_config: dict = None,
quant_config: dict = {},
enable_full_range: bool = False,
batch_size: int = 8,
amp: bool = True,
device=None,
lr_scheduler=None,
use_quant_input: bool = True,
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
@@ -46,7 +50,9 @@ def __init__(
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
scale_dtype="fp32",
data_type: str = "int",
scale_dtype: str = "fp16",
**kwargs,
):
"""Init a AutQRoundQuantizer object.

@@ -86,17 +92,17 @@ def __init__(
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
"""
super().__init__(quant_config)
self.tokenizer = None
self.enable_full_range = enable_full_range
self.batch_size = batch_size
self.amp = amp
self.device = device
self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
self.lr_scheduler = lr_scheduler
self.use_quant_input = use_quant_input
self.enable_quanted_input = enable_quanted_input
self.enable_minmax_tuning = enable_minmax_tuning
self.lr = lr
self.minmax_lr = minmax_lr
@@ -110,7 +116,7 @@ def __init__(
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.data_type = "int"
self.data_type = data_type
self.scale_dtype = scale_dtype

def prepare(self, model: torch.nn.Module, *args, **kwargs):
@@ -121,16 +127,23 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
Returns:
A prepared model.
"""
self.rounder = AutoRoundProcessor(
prepare_model = InputCaptureModule(model)
return prepare_model

def convert(self, model: torch.nn.Module, *args, **kwargs):
dataloader = CapturedDataloader(model.args_list, model.kwargs_list)
model = model.orig_model
rounder = AutoRound(
model=model,
tokenizer=None,
dataset=dataloader,
weight_config=self.quant_config or {},
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
device=self.device,
lr_scheduler=self.lr_scheduler,
use_quant_input=self.use_quant_input,
enable_quanted_input=self.enable_quanted_input,
enable_minmax_tuning=self.enable_minmax_tuning,
lr=self.lr,
minmax_lr=self.minmax_lr,
@@ -147,179 +160,32 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
data_type=self.data_type,
scale_dtype=self.scale_dtype,
)
self.rounder.prepare()
return model

def convert(self, model: torch.nn.Module, *args, **kwargs):
model, weight_config = self.rounder.convert()
model, weight_config = rounder.quantize()
model.autoround_config = weight_config
model = pack_model(model, weight_config, device=self.device, inplace=True)
return model


@torch.no_grad()
def get_autoround_default_run_fn(
model,
tokenizer,
dataset_name="NeelNanda/pile-10k",
n_samples=512,
seqlen=2048,
seed=42,
bs=8,
dataset_split: str = "train",
dataloader=None,
):
"""Perform calibration for quantization.

This method calibrates the model for quantization by processing a specified
number of samples from the calibration dataset. It ensures that the data is
properly formatted and feeds it to the model. If the number of samples processed
is less than the specified number, it logs a warning. If no samples are processed,
it logs an error and exits.
def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
"""Generate a DataLoader for calibration using specified parameters.

Args:
n_samples (int): The number of samples to use for calibration.
tokenizer (Tokenizer): The tokenizer to use for tokenization.
seqlen (int): The exact sequence length. samples < seqlen will be dropped,
samples longer than seqlen will be truncated
dataset_name (str, optional): The name of the dataset or datasets separated by commas.
Defaults to "NeelNanda/pile-10k".
split (str, optional): The data split to use. Defaults to None.
seed (int, optional): The random seed for reproducibility. Defaults to 42.
bs (int, optional): The batch size. Defaults to 8.
n_samples (int, optional): The total number of samples to include. Defaults to 512.

Returns:
DataLoader: The DataLoader for the calibrated dataset.
"""
if dataloader is None:
get_dataloader = CALIB_DATASETS.get(dataset_name, CALIB_DATASETS["NeelNanda/pile-10k"])
dataloader = get_dataloader(
tokenizer,
seqlen,
seed=seed,
bs=bs,
split=dataset_split,
dataset_name=dataset_name,
)
total_cnt = 0
for data in dataloader:
if data is None:
continue
if isinstance(data, torch.Tensor):
data_new = data.to(model.device)
input_ids = data_new
else:
data_new = {}
for key in data.keys():
data_new[key] = data[key].to(model.device)
input_ids = data_new["input_ids"]
# if input_ids.shape[-1] < seqlen:
# continue
if total_cnt + input_ids.shape[0] > n_samples:
input_ids = input_ids[: n_samples - total_cnt, ...]
try:
if isinstance(data_new, torch.Tensor):
model(data_new)
elif isinstance(data_new, dict):
model(**data_new)
else:
# Handle cases where data_new is neither a Tensor nor a dict
raise NotImplementedError(f"Handling not implemented for data type {type(data)}")
except Exception as error:
logger.error(error)
total_cnt += input_ids.shape[0]
if total_cnt >= n_samples:
break
if total_cnt == 0:
logger.error(
"no data has been cached, please provide more data with sequence length >= {} in the ".format(seqlen)
+ "dataloader or decease the sequence length."
)
exit()
elif total_cnt < n_samples:
logger.warning(
"Insufficient number of samples collected may affect the quantification. "
"Effective samples size: {}, Target sample size: {}".format(total_cnt, n_samples)
)


class AutoRoundProcessor(AutoRound):

def prepare(self):
"""Prepares a given model for quantization."""
# logger.info("cache block input")
self.start_time = time.time()
self.block_names = get_block_names(self.model)
if len(self.block_names) == 0:
logger.warning("could not find blocks, exit with original model")
return
if self.amp:
self.model = self.model.to(self.amp_dtype)
if not self.low_gpu_mem_usage:
self.model = self.model.to(self.device)
# inputs = self.cache_block_input(block_names[0], self.n_samples)

# cache block input
self.inputs = {}
self.tmp_block_name = self.block_names[0]
self._replace_forward()

def convert(self):
"""Converts a prepared model to a quantized model."""
self._recover_forward()
inputs = self.inputs[self.tmp_block_name]
del self.tmp_block_name

del self.inputs
if "input_ids" in inputs.keys():
dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
total_samples = inputs["input_ids"].shape[dim]
self.n_samples = total_samples
if total_samples < self.train_bs:
self.train_bs = total_samples
logger.warning(f"force the train batch size to {total_samples} ")
self.model = self.model.to("cpu")
torch.cuda.empty_cache()
self.qdq_weight_round(
self.model,
inputs,
self.block_names,
n_blocks=self.n_blocks,
device=self.device,
)
for n, m in self.model.named_modules():
if n in self.weight_config.keys():
if hasattr(m, "scale"):
self.weight_config[n]["scale"] = m.scale
self.weight_config[n]["zp"] = m.zp
if self.group_size <= 0:
self.weight_config[n]["g_idx"] = torch.tensor(
[0 for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
)
else:
self.weight_config[n]["g_idx"] = torch.tensor(
[i // self.group_size for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
)
delattr(m, "scale")
delattr(m, "zp")
else:
self.weight_config[n]["data_type"] = "float"
if self.amp_dtype == torch.bfloat16:
self.weight_config[n]["data_type"] = "bfloat"
self.weight_config[n]["bits"] = 16
self.weight_config[n]["group_size"] = None
self.weight_config[n]["sym"] = None

end_time = time.time()
cost_time = end_time - self.start_time
logger.info(f"quantization tuning time {cost_time}")
## dump a summary
quantized_layers = []
unquantized_layers = []
for n, m in self.model.named_modules():
if isinstance(m, tuple(self.supported_types)):
if self.weight_config[n]["bits"] == 16:
unquantized_layers.append(n)
else:
quantized_layers.append(n)
summary_info = (
f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
)
if len(unquantized_layers) > 0:
summary_info += f", {unquantized_layers} have not been quantized"

logger.info(summary_info)
if len(unquantized_layers) > 0:
logger.info(f"Summary: {unquantized_layers} have not been quantized")
from auto_round.calib_dataset import get_dataloader # pylint: disable=E0401

self.quantized = True
self.model = self.model.to(self.model_orig_dtype)
return self.model, self.weight_config
dataloader = get_dataloader(
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
)
return dataloader
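With the custom AutoRoundProcessor and run function removed, the flow above becomes: prepare() wraps the model in an InputCaptureModule, ordinary calibration forwards record the inputs, and convert() replays them through AutoRound before packing the result. A minimal end-to-end sketch of that flow; the model name and the per-layer weight config values are illustrative assumptions, not part of this PR:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import (
    AutoRoundQuantizer,
    get_dataloader,
)

model_name = "facebook/opt-125m"  # placeholder model; any HF causal LM should do
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Illustrative per-layer weight config in auto_round's format (layer name and values are examples).
quant_config = {
    "model.decoder.layers.0.self_attn.q_proj": {
        "bits": 4,
        "group_size": 128,
        "sym": False,
        "data_type": "int",
    },
}
quantizer = AutoRoundQuantizer(quant_config=quant_config)

# prepare() returns an InputCaptureModule; the forwards below only record inputs (under no_grad).
prepared = quantizer.prepare(model)
dataloader = get_dataloader(tokenizer, seqlen=2048, seed=42, bs=8, n_samples=128)
for batch in dataloader:
    if isinstance(batch, dict):
        prepared(**batch)
    elif isinstance(batch, (list, tuple)):
        prepared(*batch)
    else:
        prepared(batch)

# convert() replays the captured inputs through AutoRound.quantize() and packs the result.
q_model = quantizer.convert(prepared)
```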
29 changes: 29 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/utility.py
@@ -1072,3 +1072,32 @@ def _hook(module, inputs, outputs):
for h in hook_list:
h.remove()
return total_values


class CapturedDataloader:
def __init__(self, args_list, kwargs_list) -> None:
self.args_list = args_list
self.kwargs_list = kwargs_list

def __iter__(self):
for args, kwargs in zip(self.args_list, self.kwargs_list):
if not args:
yield kwargs
elif not kwargs:
yield args
else:
yield args, kwargs


class InputCaptureModule(torch.nn.Module):

def __init__(self, model) -> None:
super().__init__()
self.args_list = []
self.kwargs_list = []
self.orig_model = model

def forward(self, *args, **kwargs):
with torch.no_grad():
self.args_list.append(args)
self.kwargs_list.append(kwargs)
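These two helpers are what the new AutoRound flow relies on: InputCaptureModule records every (args, kwargs) pair passed to the wrapped model under no_grad, and CapturedDataloader replays the recorded calls as an iterable. A small self-contained sketch; the toy Linear module is purely illustrative:

```python
import torch

from neural_compressor.torch.algorithms.weight_only.utility import (
    CapturedDataloader,
    InputCaptureModule,
)

toy = torch.nn.Linear(4, 4)  # stand-in for the model being calibrated
capture = InputCaptureModule(toy)

# Forward calls are recorded rather than executed on the wrapped module.
capture(torch.randn(2, 4))            # stored as positional args
capture(input_ids=torch.randn(2, 4))  # stored as keyword args

replay = CapturedDataloader(capture.args_list, capture.kwargs_list)
for item in replay:
    # Yields the args tuple when kwargs are empty, or the kwargs dict when args are empty.
    print(type(item))
```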
4 changes: 2 additions & 2 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -447,7 +447,7 @@ def autoround_quantize_entry(
enable_full_range = quant_config.enable_full_range
batch_size = quant_config.batch_size
lr_scheduler = quant_config.lr_scheduler
use_quant_input = quant_config.use_quant_input
enable_quanted_input = quant_config.enable_quanted_input
enable_minmax_tuning = quant_config.enable_minmax_tuning
lr = quant_config.lr
minmax_lr = quant_config.minmax_lr
@@ -472,7 +472,7 @@
enable_full_range=enable_full_range,
batch_size=batch_size,
lr_scheduler=lr_scheduler,
use_quant_input=use_quant_input,
enable_quanted_input=enable_quanted_input,
enable_minmax_tuning=enable_minmax_tuning,
lr=lr,
minmax_lr=minmax_lr,
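At the entry level the same rename applies: the flag formerly read as use_quant_input now comes from the config as enable_quanted_input. A hedged sketch of the high-level 3.x API, assuming AutoRoundConfig exposes the renamed field and the usual prepare/convert entry points; the model name and sample counts are placeholders:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

model_name = "facebook/opt-125m"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# enable_quanted_input replaces the old use_quant_input flag.
config = AutoRoundConfig(enable_quanted_input=True)

prepared = prepare(model, config)
for batch in get_dataloader(tokenizer, seqlen=2048, n_samples=128):
    if isinstance(batch, dict):
        prepared(**batch)
    else:
        prepared(batch)
q_model = convert(prepared)
```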