Migrate AutoRound to Torch new 3x API (#1763)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel authored Apr 29, 2024
1 parent 044e6db commit e3c736f
Showing 3 changed files with 288 additions and 167 deletions.
355 changes: 224 additions & 131 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -12,13 +12,151 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import torch
from auto_round import AutoRound # pylint: disable=E0401
from auto_round.calib_dataset import CALIB_DATASETS # pylint: disable=E0401
from auto_round.utils import get_block_names # pylint: disable=E0401

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import logger


class AutoRoundQuantizer(Quantizer):
def __init__(
self,
weight_config: dict = {},
enable_full_range: bool = False,
batch_size: int = 8,
amp: bool = True,
device=None,
lr_scheduler=None,
use_quant_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
low_gpu_mem_usage: bool = True,
iters: int = 200,
seqlen: int = 2048,
n_samples: int = 512,
sampler: str = "rand",
seed: int = 42,
n_blocks: int = 1,
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
scale_dtype="fp32",
):
"""Init a AutQRoundQuantizer object.
Args:
weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
weight_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 32,
'sym': False,
}
...
}
keys:
data_type (str): The data type to be used (default is "int").
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether to use symmetric quantization (default is None).
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True). Automatically detected and set.
device: The device to be used for tuning (default is None). Automatically detected and set.
lr_scheduler: The learning rate scheduler to be used.
use_quant_input (bool): Whether to use quantized input data (default is True).
enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the calibration sequence (default is 2048).
n_samples (int): Number of samples (default is 512).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
n_blocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to skip selecting the parameters with the best mean squared error during tuning (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of the quantization scale to be used (default is "fp32");
different kernels support different choices.
"""

self.tokenizer = None
self.weight_config = weight_config
self.enable_full_range = enable_full_range
self.batch_size = batch_size
self.amp = amp
self.device = device
self.lr_scheduler = lr_scheduler
self.use_quant_input = use_quant_input
self.enable_minmax_tuning = enable_minmax_tuning
self.lr = lr
self.minmax_lr = minmax_lr
self.low_gpu_mem_usage = low_gpu_mem_usage
self.iters = iters
self.seqlen = seqlen
self.n_samples = n_samples
self.sampler = sampler
self.seed = seed
self.n_blocks = n_blocks
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.data_type = "int"
self.scale_dtype = scale_dtype

def prepare(self, model: torch.nn.Module, *args, **kwargs):
"""Prepares a given model for quantization.
Args:
model (torch.nn.Module): The model to be prepared.
Returns:
A prepared model.
"""
self.rounder = AutoRoundProcessor(
model=model,
tokenizer=None,
weight_config=self.weight_config,
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
device=self.device,
lr_scheduler=self.lr_scheduler,
use_quant_input=self.use_quant_input,
enable_minmax_tuning=self.enable_minmax_tuning,
lr=self.lr,
minmax_lr=self.minmax_lr,
low_gpu_mem_usage=self.low_gpu_mem_usage,
iters=self.iters,
seqlen=self.seqlen,
n_samples=self.n_samples,
sampler=self.sampler,
seed=self.seed,
n_blocks=self.n_blocks,
gradient_accumulate_steps=self.gradient_accumulate_steps,
not_use_best_mse=self.not_use_best_mse,
dynamic_max_gap=self.dynamic_max_gap,
data_type=self.data_type,
scale_dtype=self.scale_dtype,
)
self.rounder.prepare()
return model

def convert(self, model: torch.nn.Module, *args, **kwargs):
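"""Converts a prepared model to a quantized model.
Args:
model (torch.nn.Module): The prepared model.
Returns:
The quantized model.
"""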
model, weight_config = self.rounder.convert()
model.autoround_config = weight_config
return model
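
# Illustrative usage sketch (not part of this file): the layer name, `model`, and
# `tokenizer` below are placeholder assumptions, and the default calibration helper is
# assumed to take the tokenizer as its first run argument (see the tokenizer assertion
# in autoround_quantize below).
#
#   quant_config = {
#       "model.layers.0.self_attn.q_proj": {"data_type": "int", "bits": 4, "group_size": 32, "sym": False},
#   }
#   quantizer = AutoRoundQuantizer(weight_config=quant_config, iters=200, seqlen=2048, n_samples=512)
#   model = quantizer.prepare(model)                # creates an AutoRoundProcessor and hooks block inputs for caching
#   get_autoround_default_run_fn(model, tokenizer)  # calibration pass populates the cached block inputs
#   model = quantizer.convert(model)                # restores forward, runs AutoRound tuning, attaches autoround_config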


@torch.no_grad()
def get_autoround_default_run_fn(
model,
@@ -94,140 +232,95 @@ def get_autoround_default_run_fn
)


class InputCaptureModule(torch.nn.Module):
class AutoRoundProcessor(AutoRound):

def __init__(self) -> None:
super().__init__()
self.data_pairs = []
self.device = "cpu"
def prepare(self):
"""Prepares a given model for quantization."""
# logger.info("cache block input")
self.start_time = time.time()
self.block_names = get_block_names(self.model)
if len(self.block_names) == 0:
logger.warning("could not find blocks, exit with original model")
return
if self.amp:
self.model = self.model.to(self.amp_dtype)
if not self.low_gpu_mem_usage:
self.model = self.model.to(self.device)
# inputs = self.cache_block_input(block_names[0], self.n_samples)

def forward(self, *args, **kwargs):
if kwargs and len(args) == 0:
# Handle cases where input data is a dict
self.data_pairs.append(kwargs)
elif args and len(args) == 1:
# Handle cases where input data is a Tensor
self.data_pairs.append(args[0])
else:
logger.error("Handle cases where input data is neither a Tensor nor a dict")
# cache block input
self.inputs = {}
self.tmp_block_name = self.block_names[0]
self._replace_forward()

def convert(self):
"""Converts a prepared model to a quantized model."""
self._recover_forward()
inputs = self.inputs[self.tmp_block_name]
del self.tmp_block_name

def recover_dataloader_from_calib_fn(run_fn, run_args):
input_capture_model = InputCaptureModule()
input_capture_model.eval()
run_fn(input_capture_model, *run_args)
dataloader = torch.utils.data.DataLoader(input_capture_model.data_pairs)
return dataloader
del self.inputs
if "input_ids" in inputs.keys():
dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
total_samples = inputs["input_ids"].shape[dim]
self.n_samples = total_samples
if total_samples < self.train_bs:
self.train_bs = total_samples
logger.warning(f"force the train batch size to {total_samples} ")
self.model = self.model.to("cpu")
torch.cuda.empty_cache()
self.qdq_weight_round(
self.model,
inputs,
self.block_names,
n_blocks=self.n_blocks,
device=self.device,
)
for n, m in self.model.named_modules():
if n in self.weight_config.keys():
if hasattr(m, "scale"):
self.weight_config[n]["scale"] = m.scale
self.weight_config[n]["zp"] = m.zp
if self.group_size <= 0:
self.weight_config[n]["g_idx"] = torch.tensor(
[0 for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
)
else:
self.weight_config[n]["g_idx"] = torch.tensor(
[i // self.group_size for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
)
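# e.g. with group_size=128 and 512 weight columns, g_idx = [0]*128 + [1]*128 + [2]*128 + [3]*128
# (one group index per input column of the weight)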
delattr(m, "scale")
delattr(m, "zp")
else:
self.weight_config[n]["data_type"] = "float"
if self.amp_dtype == torch.bfloat16:
self.weight_config[n]["data_type"] = "bfloat"
self.weight_config[n]["bits"] = 16
self.weight_config[n]["group_size"] = None
self.weight_config[n]["sym"] = None

end_time = time.time()
cost_time = end_time - self.start_time
logger.info(f"quantization tuning time {cost_time}")
## dump a summary
quantized_layers = []
unquantized_layers = []
for n, m in self.model.named_modules():
if isinstance(m, tuple(self.supported_types)):
if self.weight_config[n]["bits"] == 16:
unquantized_layers.append(n)
else:
quantized_layers.append(n)
summary_info = (
f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
)
if len(unquantized_layers) > 0:
summary_info += f", {unquantized_layers} have not been quantized"

def autoround_quantize(
model,
weight_config: dict = {},
enable_full_range: bool = False, ##for symmetric, TODO support later
batch_size: int = 8,
amp: bool = True,
device=None,
lr_scheduler=None,
use_quant_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
low_gpu_mem_usage: bool = True,
iters: int = 200,
seqlen: int = 2048,
n_samples: int = 512,
sampler: str = "rand",
seed: int = 42,
n_blocks: int = 1,
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
scale_dtype="fp16",
run_fn=None,
run_args=None,
):
"""The entry point of the autoround weight-only quantization.
Args:
model: The PyTorch model to be quantized.
weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
weight_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 32,
'sym': False,
}
...
}
keys:
data_type (str): The data type to be used (default is "int").
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether to use symmetric quantization. (default is None).
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
device: The device to be used for tuning (default is None). Automatically detect and set.
lr_scheduler: The learning rate scheduler to be used.
use_quant_input (bool): Whether to use quantized input data (default is True).
enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
n_samples (int): Number of samples (default is 512).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
n_blocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to skip selecting the parameters with the best mean squared error during tuning (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of the quantization scale to be used (default is "fp16");
different kernels support different choices.
run_fn: a calibration function for calibrating the model. Defaults to None.
run_args: positional arguments for `run_fn`. Defaults to None.
Returns:
The quantized model.
"""
if run_fn is None or run_fn == get_autoround_default_run_fn:
assert run_args is not None, "Please provide tokenizer for AutoRound default calibration."
run_fn = get_autoround_default_run_fn
dataloader = recover_dataloader_from_calib_fn(run_fn, run_args)

rounder = AutoRound(
model=model,
tokenizer=None,
bits=4,
group_size=128,
sym=False,
weight_config=weight_config,
enable_full_range=enable_full_range, ##for symmetric, TODO support later
batch_size=batch_size,
amp=amp,
device=device,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
use_quant_input=use_quant_input,
enable_minmax_tuning=enable_minmax_tuning,
lr=lr,
minmax_lr=minmax_lr,
low_gpu_mem_usage=low_gpu_mem_usage,
iters=iters,
seqlen=seqlen,
n_samples=n_samples,
sampler=sampler,
seed=seed,
n_blocks=n_blocks,
gradient_accumulate_steps=gradient_accumulate_steps,
not_use_best_mse=not_use_best_mse,
dynamic_max_gap=dynamic_max_gap,
data_type="int",
scale_dtype=scale_dtype,
run_fn=run_fn,
run_args=run_args,
)
qdq_model, weight_config = rounder.quantize()
return qdq_model, weight_config
logger.info(summary_info)
if len(unquantized_layers) > 0:
logger.info(f"Summary: {unquantized_layers} have not been quantized")

self.quantized = True
self.model = self.model.to(self.model_orig_dtype)
return self.model, self.weight_config