diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 5eb1414c..4f342797 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -23,8 +23,7 @@ from tqdm import tqdm import accelerate from packaging import version -from .quantizer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer, \ - WrapperTransformerConv1d +from .quantizer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer from .special_model_handler import ( shareable_keywords, init_cache_for_special_model, @@ -55,6 +54,7 @@ ) from .low_cpu_mem.utils import get_layers_before_block + class AutoRound(object): """For more information, please refer to Cheng, Wenhua, et al. "Optimize weight rounding via signed gradient descent for the quantization of llms." arXiv preprint arXiv:2309.05516 (2023). @@ -108,6 +108,8 @@ class AutoRound(object): act_bits (int): Number of bits for activation quantization. Default is 16. act_group_size (int): Group size for activation quantization. Default is None. act_sym (bool): Whether to use symmetric activation quantization. Default is None. + act_data_type (str): Specifies the data type for activations. + Defaults to None, in which case it inherits the weight data type. act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. to_quant_block_names (str|list): A string or list whose elements are list of block's layer names to be quantized. @@ -150,6 +152,7 @@ def __init__( act_bits: int = 16, act_group_size: int = None, act_sym: bool = None, + act_data_type: str = None, act_dynamic: bool = True, to_quant_block_names: Union[str, list] = None, enable_norm_bias_tuning: bool = False, @@ -188,12 +191,6 @@ def __init__( self.lr = lr or (1.0 / self.iters) ##must after iter setting self.minmax_lr = minmax_lr or self.lr - ##activation - self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size - self.act_bits = act_bits if not (act_bits is None) else self.bits - self.act_sym = act_sym if not (act_sym is None) else self.sym - self.act_dynamic = act_dynamic - self.data_type = data_type self.supported_types = [torch.nn.Linear, transformers.modeling_utils.Conv1D] self.model = model.eval() @@ -201,11 +198,18 @@ def __init__( self.device = detect_device(device) self.scale_dtype = convert_dtype_str2torch(scale_dtype) self.set_amp_dtype() + self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device if not hasattr(self, 'to_quant_block_names'): all_blocks = get_block_names(model) self.to_quant_block_names = find_matching_blocks(model, all_blocks, to_quant_block_names) - + + ##activation + self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size + self.act_bits = act_bits if not (act_bits is None) else self.bits + self.act_sym = act_sym if not (act_sym is None) else self.sym + self.act_dynamic = act_dynamic + self.act_data_type = act_data_type if act_data_type is not None else data_type self.sampler = sampler self.not_use_best_mse = not_use_best_mse @@ -215,15 +219,21 @@ def __init__( self.batch_dim = None self.infer_bs_coeff = 1 - self.set_layerwise_config(self.layer_config) ##better place in the end torch.set_printoptions(precision=3, sci_mode=True) self.check_configs() - logger.info(f"using {self.model.dtype} for quantization tuning") + if self.act_bits <= 8 and self.amp_dtype == torch.float16: + logger.warning("force to use bf16 to for quantization tuning when enabling activation quantization") + 
self.amp_dtype = torch.bfloat16 + self.model = self.model.to(torch.bfloat16) + else: + logger.info(f"using {self.model.dtype} for quantization tuning") self.enable_torch_compile = enable_torch_compile if is_optimum_habana_available(): logger.info("Optimum Habana is available, import htcore explicitly.") import habana_frameworks.torch.core as htcore # pylint: disable=E0401 - import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401 + import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] + + self.set_layerwise_config(self.layer_config) ##better place in the end def check_configs(self): """Checks if the configurations are valid. @@ -242,12 +252,11 @@ def check_configs(self): assert self.seqlen > 0, "seqlen must be positive" assert self.nblocks > 0, "nblocks must be positive" assert self.gradient_accumulate_steps > 0, "gradient accumulate step must be positive" - assert self.act_dynamic is True, "only support dynamic quantization for activation currently" # assert self.tokenizer != None or self.dataloader != None if self.act_bits <= 8: logger.warning( - "please save the quantized model to fake format " - "as real deployment is not supported for activation quantization currently") + "Activation quantization is an experimental feature with limited support and a complex API." + "And please save the quantized model to fake format as real deployment is not supported currently") if "mx_fp" in self.data_type: logger.warning( @@ -257,7 +266,6 @@ def check_configs(self): if "mx_fp" in self.data_type and self.group_size != 32: logger.warning("mx_fp should only support group_size of 32 in real deployment") - if self.nsamples < self.gradient_accumulate_steps * self.batch_size: self.batch_size = min(self.batch_size, self.nsamples) self.gradient_accumulate_steps = min(self.nsamples // self.batch_size, self.gradient_accumulate_steps) @@ -428,7 +436,7 @@ def set_layerwise_config(self, layer_config): """ layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.to_quant_block_names) keys = ["data_type", "bits", "group_size", "sym", "scale_dtype", "act_bits", "act_group_size", "act_sym", - "act_dynamic"] + "act_dynamic", "act_data_type"] for n, m in self.model.named_modules(): if not isinstance(m, tuple(self.supported_types)): continue @@ -452,7 +460,7 @@ def set_layerwise_config(self, layer_config): setattr(m, key, layer_config[n][key]) @torch.no_grad() - def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device): + def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device, save_output=True): """Compute the output of a given block of the model for a given input. 
Args: @@ -483,10 +491,11 @@ def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_de tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to( cache_device ) - if self.batch_size == 1: - output.append(tmp_output) - else: - output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim))) + if save_output: + if self.batch_size == 1: + output.append(tmp_output) + else: + output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim))) if self.low_gpu_mem_usage: clear_memory() @@ -665,7 +674,7 @@ def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_n ## have bug if block name is not the first block if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage: tmp_dtype = self.model.dtype - self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32) + self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32) ##model on cpu self.last_cache_name = last_cache_name if last_cache_name is None and len(block_names) + len(layer_names) == 1: @@ -852,17 +861,15 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp if q_inputs is not None: q_inputs[i] = q_inputs[i].to(layer.weight.dtype) - if isinstance(layer, torch.nn.Linear): - wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to( - device) - else: - wrapper_linear = WrapperTransformerConv1d(layer, enable_minmax_tuning=self.enable_minmax_tuning, - device=device).to(device) + wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to( + device) round_params = [] minmax_params = [] - round_params.append(wrapper_linear.value) - minmax_params.append(wrapper_linear.min_scale) - minmax_params.append(wrapper_linear.max_scale) + for key in wrapper_linear.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(wrapper_linear.params[key]) + else: + round_params.append(wrapper_linear.value) if self.enable_minmax_tuning: optimizer = self.optimizer( [{"params": round_params}, {"params": minmax_params, "lr": self.minmax_lr}], lr=self.lr, weight_decay=0 @@ -959,6 +966,23 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" logger.debug(dump_info) + def register_act_max_hook(self, model): + def get_act_max_hook(module, input, output): + if isinstance(input, (tuple, list)): + input = input[0] + if not hasattr(module, "act_max"): + module.act_max = torch.abs(input).max().item() + else: + module.act_max = max(torch.abs(input).max().item(), module.act_max) + + hook_handles = [] + + for n, m in model.named_modules(): + if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m): + hook = m.register_forward_hook(get_act_max_hook) + hook_handles.append(hook) + return hook_handles + def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")): """Quantize the weights of a given block of the model. 
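Note: the `register_act_max_hook` added above records a running maximum of the absolute input for every layer that is statically quantized (`act_dynamic == False` and `check_to_quantized`), so the calibrated `act_max` can be reused when activations are later fake-quantized. A minimal, self-contained sketch of the same hook pattern; the helper name `track_act_max` and the toy model are illustrative, not part of this patch:

import torch

def track_act_max(model: torch.nn.Module):
    """Attach forward hooks that keep a running max of |input| per Linear layer."""
    handles = []

    def hook(module, inputs, output):
        x = inputs[0] if isinstance(inputs, (tuple, list)) else inputs
        current = x.abs().max().item()
        module.act_max = max(getattr(module, "act_max", current), current)

    for _, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):
            handles.append(m.register_forward_hook(hook))
    return handles

# Usage: run calibration batches, then remove the hooks, mirroring how quant_block
# registers the hooks around get_block_outputs and removes them afterwards.
toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
handles = track_act_max(toy)
toy(torch.randn(2, 8))
for h in handles:
    h.remove()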
@@ -972,9 +996,25 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch Returns: Tuple: (q_outputs, output) if self.enable_quanted_input is True, else (None, output) """ + if q_input is None: + hook_handles = self.register_act_max_hook(block) + + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, + device, + self.cache_device) + + for handle in hook_handles: + handle.remove() + else: + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, + device, + self.cache_device) + hook_handles = self.register_act_max_hook(block) + self.get_block_outputs(block, q_input, input_others, self.batch_size * self.infer_bs_coeff, + device, self.cache_device, save_output=False) - output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, - self.cache_device) + for handle in hook_handles: + handle.remove() if q_input is not None: input_ids = q_input @@ -986,13 +1026,11 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch minmax_params = [] for n, m in block.named_modules(): if hasattr(m, "orig_layer"): - if "v" in m.params.keys(): - round_params.append(m.params['v']) - if "max_scale" in m.params.keys(): - minmax_params.append(m.params["min_scale"]) - minmax_params.append(m.params["max_scale"]) - if "bias_v" in m.params.keys(): - round_params.append(m.params["bias_v"]) + for key in m.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(m.params[key]) + else: + round_params.append(m.params[key]) if self.enable_minmax_tuning: optimizer = self.optimizer( @@ -1069,6 +1107,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch total_loss += loss.item() / num_elm self.scale_loss_and_backward(scaler, loss) + if i == 0: init_loss = total_loss @@ -1430,6 +1469,8 @@ class AutoRoundOPT(AutoRound): act_bits (int): Number of bits for activation quantization. Default is 16. act_group_size (int): Group size for activation quantization. Default is None. act_sym (bool): Whether to use symmetric activation quantization. Default is None. + act_data_type (str): Specifies the data type for activations. + Defaults to None, in which case it inherits the weight data type. act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. to_quant_block_names (str|list): A string or list whose elements are list of block's layer names to be quantized. @@ -1474,6 +1515,7 @@ def __init__( act_bits: int = 16, act_group_size: int = None, act_sym: bool = None, + act_data_type: str = None, act_dynamic: bool = True, to_quant_block_names: Union[str, list] = None, enable_norm_bias_tuning: bool = False, @@ -1513,6 +1555,7 @@ def __init__( act_bits=act_bits, act_group_size=act_group_size, act_sym=act_sym, + act_data_type=act_data_type, act_dynamic=act_dynamic, to_quant_block_names=to_quant_block_names, enable_norm_bias_tuning=enable_norm_bias_tuning, @@ -1601,6 +1644,8 @@ class AutoRoundAdam(AutoRoundOPT): act_bits (int): Number of bits for activation quantization. Default is 16. act_group_size (int): Group size for activation quantization. Default is None. act_sym (bool): Whether to use symmetric activation quantization. Default is None. + act_data_type (str): Specifies the data type for activations. + Defaults to None, in which case it inherits the weight data type. act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. 
to_quant_block_names (str|list): A list whose elements are list of block's layer names to be quantized. enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning @@ -1642,6 +1687,7 @@ def __init__( act_bits: int = 16, act_group_size: int = None, act_sym: bool = None, + act_data_type: str = None, act_dynamic: bool = True, to_quant_block_names: Union[str, list] = None, enable_norm_bias_tuning: bool = False, @@ -1681,6 +1727,7 @@ def __init__( act_bits=act_bits, act_group_size=act_group_size, act_sym=act_sym, + act_data_type=act_data_type, act_dynamic=act_dynamic, to_quant_block_names=to_quant_block_names, enable_norm_bias_tuning=enable_norm_bias_tuning, @@ -1688,5 +1735,3 @@ def __init__( optimizer=optimizer, **kwargs, ) - - diff --git a/auto_round/data_type/__init__.py b/auto_round/data_type/__init__.py index 2a68bcdb..4425414f 100644 --- a/auto_round/data_type/__init__.py +++ b/auto_round/data_type/__init__.py @@ -14,5 +14,6 @@ import auto_round.data_type.int import auto_round.data_type.mxfp +import auto_round.data_type.fp8 from auto_round.data_type.register import QUANT_FUNC_WITH_DTYPE from auto_round.data_type.utils import get_quant_func diff --git a/auto_round/data_type/fp8.py b/auto_round/data_type/fp8.py new file mode 100644 index 00000000..76da5e5c --- /dev/null +++ b/auto_round/data_type/fp8.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from auto_round.data_type.register import register_dtype + + +def float8_e4m3fn_ste(x: torch.Tensor): + """Straight-Through Estimator (STE) for float8. + + Applies a quantization and dequantization step with float8 precision while maintaining + gradient flow using a straight-through estimator. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Quantized and dequantized tensor using float8 format. + """ + fp8 = (x.to(torch.float8_e4m3fn).to(x.dtype) - x).detach() + x + + return fp8 + + +@register_dtype("fp8_dynamic_per_token_sym") +def fp8_dynamic_per_token_sym(tensor, max_scale=1.0, **kwargs): + """Dynamic per-token symmetric quantization using float8. + + This function dynamically calculates a per-token scaling factor for each group of tokens + and applies symmetric quantization using float8 format. + + Args: + tensor (torch.Tensor): Input tensor to quantize. + max_scale (float, optional): Maximum scaling factor. Defaults to 1.0. + **kwargs: Additional arguments for compatibility. + + Returns: + tuple: + - Quantized and dequantized tensor (torch.Tensor). + - Scale tensor used for quantization (torch.Tensor). + - Placeholder for zp (None). 
+ """ + orig_shape = tensor.shape + info = torch.finfo(torch.float8_e4m3fn) + orig_dtype = tensor.dtype + + tensor = tensor.reshape(-1, orig_shape[-1]) + max_tensor = torch.max(torch.abs(tensor), dim=-1)[ + 0] * max_scale + + scale = max_tensor.to(torch.float32) / info.max + min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm + scale = torch.clip(scale, min=min_scaling_factor) + if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16 + tensor = tensor.to(torch.bfloat16) + scale = scale.unsqueeze(dim=-1) + fp8_res = (tensor / scale) + fp8_res = torch.clip(fp8_res, info.min, info.max) + fp8_res = float8_e4m3fn_ste(fp8_res) + qdq_res = fp8_res * scale + qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape) + return qdq_res, scale, None + + +@register_dtype("fp8_sym") +def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, **kwargs): + """Symmetric quantization using float8 format. + + Allows both dynamic per-token scaling and tensor-wide quantization depending on input. + + Args: + tensor (torch.Tensor): Input tensor to quantize. + max_scale (float, optional): Maximum scaling factor. Defaults to 1.0. + tensor_max (float, optional): Maximum tensor value for precomputed scale. Defaults to None. + **kwargs: Additional arguments for compatibility. + + Returns: + tuple: + - Quantized and dequantized tensor (torch.Tensor). + - Scale tensor used for quantization (torch.Tensor). + - Placeholder for zp (None). + """ + orig_shape = tensor.shape + info = torch.finfo(torch.float8_e4m3fn) + orig_dtype = tensor.dtype + + if tensor_max is None: ##dynamic per-te + tensor = tensor.reshape(-1, orig_shape[-1]) + max_tensor = torch.max(torch.abs(tensor), dim=-1)[ + 0] * max_scale + else: + max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale + scale = max_tensor.to(torch.float32) / info.max + min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm + scale = torch.clip(scale, min=min_scaling_factor) + if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16 + tensor = tensor.to(torch.bfloat16) + scale = scale.unsqueeze(dim=-1) + fp8_res = (tensor / scale) + fp8_res = torch.clip(fp8_res, info.min, info.max) + fp8_res = float8_e4m3fn_ste(fp8_res) + qdq_res = fp8_res * scale + qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape) + return qdq_res, scale, None + + +@register_dtype("fp8_to_int_sym") +def progressive_quant_fp8_int4(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, q_scale_thresh=1e-5, + weight_fp8_max_scale=1.0, **kwargs): + """Two-stage quantization: quantize tensor to fp8 by per tensor, then quantize fp8 to w4g128 + + This method first quantizes the input tensor into float8 format and then performs + a secondary quantization to int4 with grouping. + + Args: + tensor (torch.Tensor): Input tensor to quantize. + bits (int, optional): Bit precision for secondary quantization. Defaults to 4. + group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping). + v (float, optional): Optional parameter for variance tuning. Defaults to 0. + min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0. + max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0. + q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5. + weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0. + **kwargs: Additional arguments for compatibility. 
+ + Returns: + tuple: + - Quantized and dequantized tensor (torch.Tensor). + - Combined scaling factor (torch.Tensor). + - Placeholder for zp (None). + """ + + info = torch.finfo(torch.float8_e4m3fn) + tensor_max = torch.max(torch.abs(tensor)).to(torch.float32) * weight_fp8_max_scale ## better train a ratio + scale = tensor_max.to(torch.float32) / info.max + min_scaling_factor = 1.0 / (info.max * 512.0) ##copy from vllm + scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor) + fp8_res = tensor / scale_bf16_to_fp8 + fp8_res = torch.clip(fp8_res, info.min, info.max) + fp8_res = float8_e4m3fn_ste(fp8_res) + + ##convert to bf16 + fp8_res_using_16bit = fp8_res.to(tensor.dtype) + ##convert to int4 + from auto_round.data_type.int import quant_tensor_sym + qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym(fp8_res_using_16bit, bits=bits, + group_size=group_size, v=v, + min_scale=min_scale, + max_scale=max_scale, + scale_dtype=torch.bfloat16, + q_scale_thresh=q_scale_thresh) + qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8 + + return qdq_tensor, scale_fp8_to_int4 * scale_bf16_to_fp8, None, diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 6dc00c15..94c9f9d5 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -13,79 +13,84 @@ # limitations under the License. import torch -from .utils import round_ste +from .utils import round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad from auto_round.data_type.register import register_dtype @register_dtype("int_sym") -def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, - weight_max=None, q_scale_thresh=0.0, **kwargs): - """Quantize and de-quantize weight asymmetrically. full range, credict goes to llamacpp community +def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, + tensor_min=None, + tensor_max=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community Args: - weight: Tensor containing the weight to be quantized + tensor: Tensor containing the tensor to be quantized bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization v: Rounding value perturbation - min_scale: Minimum scale coefficient for weight - max_scale: Maximum scale coefficient for weight - weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None. - weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None. + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability Returns: - Quantized and de-quantized weight, scale, zero-point + Quantized and de-quantized tensor, scale, zero-point """ + + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) maxq = torch.tensor(2 ** (bits - 1)) - if weight_min is None or weight_max is None: - wmin_tmp = torch.clamp(weight.min(-1)[0], max=0) - wmax_tmp = torch.clamp(weight.max(-1)[0], min=0) + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) else: - wmin_tmp = weight_min - wmax_tmp = weight_max + wmin_tmp = tensor_min + wmax_tmp = tensor_max - wmin_abs = -(wmin_tmp * min_scale) # pylint: disable=E1130 + wmin_abs = -(wmin_tmp * min_scale) # pylint: disable=E1130 wmax_abs = wmax_tmp * max_scale - max_v = (2 * (wmax_abs < wmin_abs).int() - 1) * torch.max(wmax_abs, wmin_abs) - scale = (max_v / maxq).to(scale_dtype) scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) zp = torch.full_like(scale, maxq) # pylint: disable=E1130 scale = scale.unsqueeze(dim=-1) zp = zp.unsqueeze(dim=-1) - int_w = round_ste(weight / scale + v) + int_w = round_ste(tensor / scale + v) q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) - qdq_result = (scale * (q - zp)).to(weight.dtype) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, scale, zp @register_dtype("int_asym") -def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, - weight_min=None, weight_max=None, q_scale_thresh=0.0, **kwargs): - """Quantize and de-quantize weight asymmetrically. +def quant_tensor_asym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, + tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically. Args: - weight: Tensor containing the weight to be quantized + tensor: Tensor containing the tensor to be quantized bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization v: Rounding value perturbation - min_scale: Minimum scale coefficient for weight - max_scale: Maximum scale coefficient for weight - weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None. - weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None. + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability Returns: - Quantized and de-quantized weight, scale, zero-point + Quantized and de-quantized tensor, scale, zero-point """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) maxq = torch.tensor(2 ** bits - 1) - if weight_min is None or weight_max is None: - wmin_tmp = torch.clamp(weight.min(-1)[0], max=0) - wmax_tmp = torch.clamp(weight.max(-1)[0], min=0) + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) else: - wmin_tmp = weight_min - wmax_tmp = weight_max + wmin_tmp = tensor_min + wmax_tmp = tensor_max if isinstance(min_scale, torch.Tensor): wmin = wmin_tmp * min_scale wmax = wmax_tmp * max_scale @@ -97,38 +102,42 @@ def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_d zp = round_ste(-wmin / scale) # pylint: disable=E1130 scale = scale.unsqueeze(dim=-1) zp = zp.unsqueeze(dim=-1) - int_w = round_ste(weight / scale + v) + int_w = round_ste(tensor / scale + v) q = torch.clamp(int_w + zp, 0, maxq) - qdq_result = (scale * (q - zp)).to(weight.dtype) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, scale, zp @register_dtype("int_sym_gptq") -def quant_tensor_sym_gptq(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, - weight_max=None, q_scale_thresh=0.0, **kwargs): - """Quantize and de-quantize weight asymmetrically. +def quant_tensor_sym_gptq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, + tensor_min=None, + tensor_max=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically. Args: - weight: Tensor containing the weight to be quantized + tensor: Tensor containing the tensor to be quantized bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization v: Rounding value perturbation - min_scale: Minimum scale coefficient for weight - max_scale: Maximum scale coefficient for weight - weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None. - weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None. + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability Returns: - Quantized and de-quantized weight, scale, zero-point + Quantized and de-quantized tensor, scale, zero-point """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) maxq = torch.tensor(2 ** bits - 1) - if weight_min is None or weight_max is None: - wmin_tmp = torch.clamp(weight.min(-1)[0], max=0) - wmax_tmp = torch.clamp(weight.max(-1)[0], min=0) + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) else: - wmin_tmp = weight_min - wmax_tmp = weight_max + wmin_tmp = tensor_min + wmax_tmp = tensor_max if isinstance(min_scale, torch.Tensor): wmin = wmin_tmp * min_scale wmax = wmax_tmp * max_scale @@ -147,37 +156,41 @@ def quant_tensor_sym_gptq(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, sca scale = scale.unsqueeze(dim=-1) zp = torch.full_like(scale, (maxq + 1) / 2) - int_w = round_ste(weight / scale + v) + int_w = round_ste(tensor / scale + v) q = torch.clamp(int_w + zp, 0, maxq) - qdq_result = (scale * (q - zp)).to(weight.dtype) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, scale, zp -def quant_tensor_asym_wo_round(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, - weight_min=None, weight_max=None, q_scale_thresh=0.0, **kwargs): - """Quantize and de-quantize weight asymmetrically without rounding, this is mainly for tuning bias, norm. +def quant_tensor_asym_wo_round(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, + scale_dtype=torch.float16, + tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically without rounding, this is mainly for tuning bias, norm. Args: - weight: Tensor containing the weight to be quantized + tensor: Tensor containing the tensor to be quantized bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization v: Rounding value perturbation - min_scale: Minimum scale coefficient for weight - max_scale: Maximum scale coefficient for weight - weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None. - weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None. + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability Returns: - Quantized and de-quantize weight, scale, zero-point + Quantized and de-quantize tensor, scale, zero-point """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) maxq = torch.tensor(2 ** bits - 1) - if weight_min is None or weight_max is None: - wmin_tmp = torch.clamp(weight.min(-1)[0], max=0) - wmax_tmp = torch.clamp(weight.max(-1)[0], min=0) + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) else: - wmin_tmp = weight_min - wmax_tmp = weight_max + wmin_tmp = tensor_min + wmax_tmp = tensor_max if isinstance(min_scale, torch.Tensor): wmin = wmin_tmp * min_scale wmax = wmax_tmp * max_scale @@ -190,7 +203,8 @@ def quant_tensor_asym_wo_round(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0 zp = -wmin / scale # pylint: disable=E1130 scale = scale.unsqueeze(dim=-1) zp = zp.unsqueeze(dim=-1) - int_w = weight / scale + v + int_w = tensor / scale + v q = torch.clamp(int_w + zp, 0, maxq) - qdq_result = (scale * (q - zp)).to(weight.dtype) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, scale, zp diff --git a/auto_round/data_type/mxfp.py b/auto_round/data_type/mxfp.py index 60229218..f558ac94 100644 --- a/auto_round/data_type/mxfp.py +++ b/auto_round/data_type/mxfp.py @@ -13,7 +13,7 @@ # limitations under the License. import torch -from .utils import floor_ste, round_ste +from .utils import floor_ste, round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad from auto_round.data_type.register import register_dtype, QUANT_FUNC_WITH_DTYPE MXFP_FORMAT_CACHE = { @@ -38,7 +38,8 @@ FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) -def quant_mx(tensor, bits, data_type, v, max_scale, mantissa_rounding="even", **kwargs): +def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, + mantissa_rounding="even", data_type="mx_fp", **kwargs): """Quantize the given tensor using the specified parameters. This function performs quantization on the `tensor` tensor according to the @@ -49,6 +50,7 @@ def quant_mx(tensor, bits, data_type, v, max_scale, mantissa_rounding="even", ** Args: tensor (torch.Tensor): The tensor containing the tensors to be quantized. bits (int): The bit width to be used for quantization. + group_size (int): The group size of sharing scale and exponent. data_type (str): The data type for quantization (e.g., 'mx_fp4'). v (float): A value used for adjusting the tensors. max_scale (float or torch.Tensor): The maximum scale to be applied to the tensors. @@ -60,6 +62,7 @@ def quant_mx(tensor, bits, data_type, v, max_scale, mantissa_rounding="even", ** Raises: KeyError: If `data_type` is not found in `MXFP_FORMAT_CACHE`. 
""" + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) ebits, mbits, emax, max_norm, min_norm = MXFP_FORMAT_CACHE[data_type] orig_dtype = tensor.dtype shared_exp, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True) @@ -113,6 +116,7 @@ def quant_mx(tensor, bits, data_type, v, max_scale, mantissa_rounding="even", ** tensor = torch.clamp(tensor, min=-max_norm, max=max_norm) tensor = tensor * (2 ** shared_exp) + tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len) return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py index 594a7815..4d774c88 100644 --- a/auto_round/data_type/utils.py +++ b/auto_round/data_type/utils.py @@ -16,6 +16,63 @@ from auto_round.data_type.register import QUANT_FUNC_WITH_DTYPE +def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int): + """Reshapes and pads the tensor to ensure that it can be quantized in groups of `group_size`. + + This function adjusts t + he input tensor's shape so that its last dimension is a multiple + of the specified `group_size`. If padding is required, it adds padding to the tensor + to achieve this. If the tensor's last dimension is already divisible by `group_size`, + no padding is applied. + + Args: + data (torch.Tensor): The input tensor to be reshaped and padded. + group_size (int): The size of the groups that the tensor should be reshaped into. + + Returns: + torch.Tensor: The reshaped and padded tensor, if necessary. + tuple: The original shape of the input tensor. + int: The padding length applied to the tensor. Returns 0 if no padding is applied. + """ + orig_shape = data.shape + pad_len = 0 + if len(data.shape) > 2: + data = data.reshape(-1, orig_shape[-1]) + if group_size == -1 or data.shape[1] < group_size: + return data, orig_shape, pad_len + elif data.shape[1] % group_size == 0: + data = data.reshape(-1, group_size) + return data, orig_shape, pad_len + else: + pad_len = (data.shape[1] + group_size - 1) // group_size * group_size - data.shape[1] + data_new = torch.nn.functional.pad(data, (0, pad_len)) + data_new = data_new.reshape(-1, group_size) + return data_new, orig_shape, pad_len + + +def revert_tensor_by_pad(data: torch.Tensor, orig_shape: tuple, pad_len: int): + """Reverts the tensor to its original shape by removing padding. + + This function removes the padding added during reshaping and returns the tensor to + its original shape. + + Args: + data (torch.Tensor): The reshaped and possibly padded tensor. + orig_shape (tuple): The original shape of the tensor before reshaping. + pad_len (int): The length of the padding to be removed. + + Returns: + torch.Tensor: The tensor restored to its original shape. + """ + if pad_len == 0: + return data.reshape(orig_shape) + else: + data_new = data.reshape(data.shape[0], -1) + data_new = data_new[:, :-pad_len] + data_new = data_new.reshape(orig_shape) + return data_new + + def get_quant_func(dtype, bits, sym): """Retrieve the quantization function based on data type, bit width, and symmetry. 
@@ -52,6 +109,14 @@ def get_quant_func(dtype, bits, sym): if key in QUANT_FUNC_WITH_DTYPE.keys(): return QUANT_FUNC_WITH_DTYPE[key], key + if sym: + key = dtype + "_sym" + else: + key = dtype + "_asym" + + if key in QUANT_FUNC_WITH_DTYPE.keys(): + return QUANT_FUNC_WITH_DTYPE[key], key + if sym: key = dtype + str(bits) else: diff --git a/auto_round/quantizer.py b/auto_round/quantizer.py index 97a71446..90d00493 100644 --- a/auto_round/quantizer.py +++ b/auto_round/quantizer.py @@ -13,6 +13,7 @@ # limitations under the License. import torch +from torch.functional import F import transformers from auto_round.data_type import get_quant_func from .utils import ( @@ -23,7 +24,7 @@ ) -def reshape_tensor(v, group_size=-1): +def reshape_and_pad_tensor(v, group_size=-1): """Reshapes the tensor based on the group size. Args: @@ -44,56 +45,270 @@ def reshape_tensor(v, group_size=-1): return v -def quant_tensor( - quant_func, data, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, - weight_min=None, weight_max=None, q_scale_thresh=1e-5, **kwargs, -): - """Quantizes and dequantizes weight, handing the group size issue . +class WrapperLinear(torch.nn.Module): + """A wrapper for linear/conv1d layers to enable quantization and tuning. - Args: - data: Tensor containing the weight to be quantized - bits: Number of bits for quantization (e.g., 2, 3, 4, 8) - group_size: The number of elements shares scale and zero point - sym: Sym or asym - v: Rounding value perturbation - min_scale: Minimum scale coefficient for weight - max_scale: Maximum scale coefficient for weight - weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None. - weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None. + This module wraps an existing linear or conv1d layer and provides additional functionality + for quantization, parameter tuning, and activation/bias normalization. - Returns: - Quantized and dequantized weight, scale, zero-point + Args: + orig_layer (torch.nn.Module): The original layer to be wrapped (linear or conv1d). + enable_minmax_tuning (bool): Whether to enable min-max scale tuning. + enable_norm_bias_tuning (bool): Whether to enable normalization and tuning of the bias term. + device (str): Device on which to run computations (e.g., 'cpu' or 'cuda'). 
""" - orig_shape = data.shape - if len(data.shape) > 2: - data = data.reshape(-1, orig_shape[-1]) - if group_size == -1 or data.shape[1] < group_size: - data, scale, zp = quant_func(data, bits, v=v, min_scale=min_scale, max_scale=max_scale, - scale_dtype=scale_dtype, weight_min=weight_min, weight_max=weight_max, - q_scale_thresh=q_scale_thresh, **kwargs) - data = data.reshape(orig_shape) - return data, scale, zp - - if data.shape[1] % group_size == 0: - data = data.reshape(-1, group_size) - data, scale, zp = quant_func(data, bits, v=v, min_scale=min_scale, max_scale=max_scale, - scale_dtype=scale_dtype, weight_min=weight_min, weight_max=weight_max, - q_scale_thresh=q_scale_thresh, **kwargs) - data = data.reshape(orig_shape) - return data, scale, zp - else: - tmp_shape = data.shape - pad_len = (data.shape[1] + group_size - 1) // group_size * group_size - data.shape[1] - data_new = torch.nn.functional.pad(data, (0, pad_len)) - data_new = data_new.reshape(-1, group_size) - data_new, scale, zp = quant_func(data_new, bits, v=v, min_scale=min_scale, - max_scale=max_scale, scale_dtype=scale_dtype, weight_min=weight_min, - weight_max=weight_max, q_scale_thresh=q_scale_thresh, **kwargs) - data_new = data_new.reshape(tmp_shape[0], -1) - data_new = data_new[:, :-pad_len] - data_new = data_new.reshape(orig_shape) - return data_new, scale, zp + def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tuning=False, device='cpu'): + """Initializes the WrapperLinear module. + + Args: + orig_layer (torch.nn.Module): The original layer to wrap. + enable_minmax_tuning (bool): Whether to enable min-max scale tuning. + enable_norm_bias_tuning (bool): Whether to enable normalization and tuning for the bias term. + device (str): The computation device, such as 'cpu' or 'cuda'. + """ + super(WrapperLinear, self).__init__() + self.orig_layer = orig_layer + self.device = device + self.enable_minmax_tuning = enable_minmax_tuning + self.enable_norm_bias_tuning = enable_norm_bias_tuning and (orig_layer.bias is not None) + self.enable_act_quant = self.orig_layer.act_bits <= 8 + self.q_scale_thresh = 1e-5 + self._init_tuning_params_and_quant_func() + self.orig_forward = self.linear_forward if isinstance(self.orig_layer, torch.nn.Linear) else self.conv1d_forward + + def _init_tuning_params_and_quant_func(self): + """Initializes tuning parameters and quantization functions. + + This method sets up required parameters and functions for weight quantization, + activation quantization, and bias/normalization. 
+ """ + self.params = {} + p_dtype = torch.float32 ##parameter dtype + + orig_layer = self.orig_layer + orig_weight = getattr(orig_layer, "get_weight", lambda: orig_layer.weight)() + if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): + orig_weight = orig_weight.t() + weight_reshape = reshape_and_pad_tensor(orig_weight.data, orig_layer.group_size) + self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0) + self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0) + self._init_params("value", p_dtype, weight_reshape.shape, 0, True) + + # Min-max scale initialization + shape = get_scale_shape(orig_weight, orig_layer.group_size) + self._init_params("min_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning) + self._init_params("max_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning) + + self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, + orig_layer.sym) + + if self.enable_act_quant: + self.act_quant_func, self.act_data_type = get_quant_func(orig_layer.act_data_type, + orig_layer.act_bits, + orig_layer.act_sym) + self._init_params("act_max_scale", p_dtype, (1), 1.0, not orig_layer.act_dynamic) + + ## bias tuning + if self.enable_norm_bias_tuning: + self._init_params("bias_v", p_dtype, self.orig_layer.bias.shape, 0, True) + from auto_round.data_type.int import quant_tensor_asym_wo_round + self.bias_quant_func = quant_tensor_asym_wo_round + self.params["bias_v"] = self.bias_v + + def _init_params(self, name, dtype, shape, value, bool_condition): + """Initializes a parameter for tuning or uses a constant if tuning is disabled. + + Args: + name (str): Name of the parameter. + dtype (torch.dtype): Data type of the parameter. + shape (tuple): Shape of the parameter. + value (float): Initial value for the parameter. + bool_condition (bool): Whether the parameter should be tunable. + """ + if bool_condition: + p = torch.nn.Parameter(torch.ones(shape, device=self.device, dtype=dtype) * value, requires_grad=True) + self.params.update({name: p}) + else: + p = torch.tensor(1.0 * value, device=self.device, dtype=dtype) + + setattr(self, name, p) + + def _qdq_weight(self, value, min_scale, max_scale): + """Quantizes and dequantizes weights with tuning parameters. + + Args: + value (torch.Tensor): Value added for rounding for tuning. + min_scale (torch.Tensor): Minimum scale for the min value of quantization. + max_scale (torch.Tensor): Maximum scale for the max value of quantization. + + Returns: + tuple: Quantized weight, scale, and zero point. + """ + min_scale.data.clamp_(0, 1.0) + max_scale.data.clamp_(0, 1.0) + weight = self.orig_layer.weight + if weight.device.type == 'meta': + weight = self.orig_layer.get_weight().to(self.device) + if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): + weight = weight.t() + + weight_q, scale, zp = self.weight_quant_func(weight, bits=self.orig_layer.bits, + group_size=self.orig_layer.group_size, v=value, + min_scale=min_scale, max_scale=max_scale, + scale_dtype=self.orig_layer.scale_dtype, + tensor_min=self.weight_min, tensor_max=self.weight_max, + data_type=self.data_type, q_scale_thresh=self.q_scale_thresh) + weight_q = weight_q.to(weight.dtype) + if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): + weight_q = weight_q.t() + return weight_q, scale, zp + + def _qdq_act(self, x, act_max_scale, act_max=None): + """Quantizes and dequantizes activations. + + Args: + x (torch.Tensor): Input activations. 
+ act_max_scale (torch.Tensor): Maximum scale for the act_max + act_max (torch.Tensor, optional): Maximum value for activation quantization. Defaults to None. + + Returns: + tuple: Quantized activation, scale, and zero point. + """ + act_max_scale.data.clamp_(0, 1.0) + x, scale, zp = self.act_quant_func(x, bits=self.orig_layer.act_bits, group_size=self.orig_layer.act_group_size, + scale_dtype=self.orig_layer.scale_dtype, q_scale_thresh=self.q_scale_thresh, + data_type=self.act_data_type, max_scale=act_max_scale, tensor_max=act_max) + return x, scale, zp + + def _qdq_bias(self, bias, bias_v): + """Quantizes and dequantizes bias. + + Args: + bias (torch.Tensor): Bias tensor to be quantized. + bias_v (torch.Tensor): Value added for rounding for tuning. + + Returns: + tuple: Quantized bias, scale, and zero point. + """ + bias_bits = 4 ## hard code + bias_group_size = -1 + bias, scale, zp = self.bias_quant_func(bias, bits=bias_bits, group_size=bias_group_size, v=bias_v, + q_scale_thresh=self.q_scale_thresh) + return bias, scale, zp + + def unwrapper(self, best_params): + """Restores the original layer by applying the best tuning parameters. + + Args: + best_params (dict): Dictionary containing the best tuning parameters. + + Returns: + torch.nn.Module: The unwrapped and restored original layer. + """ + best_params = best_params or {} + v = best_params.get('value', torch.tensor(0.0)).to(self.device) + min_scale = best_params.get('min_scale', torch.tensor(1.0)).to(self.device) + max_scale = best_params.get('max_scale', torch.tensor(1.0)).to(self.device) + + if self.orig_layer.weight.device.type == 'meta': + self.orig_layer.to(self.device) + ##unwrapper weight + qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) + self.orig_layer.weight.data.copy_(qdq_weight) + self.orig_layer.weight.grad = None + + shape = qdq_weight.shape + if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): + shape = qdq_weight.t().shape + scale = scale.reshape(shape[0], -1) + if zp is not None: + zp = zp.reshape(shape[0], -1) + + self.orig_layer.scale = scale.to("cpu") + self.orig_layer.zp = zp.to("cpu") if zp is not None else None + + ##unwrapper bias + if self.enable_norm_bias_tuning and "bias_v" in best_params.keys(): ##fake quant + bias_v = best_params["bias_v"].to(self.device) + bias = self.orig_layer.bias + if bias is not None and bias.device.type == 'meta': + bias = self.orig_layer.get_bias().to(self.device) + bias, _, _ = self._qdq_bias(bias, bias_v) + self.orig_layer.bias.grad = None + self.orig_layer.bias.data.copy_(bias) + + if hasattr(self.orig_layer, 'update'): + self.orig_layer.update() + self.orig_layer.to('meta') + + ##unwrapper act + if self.enable_act_quant: + act_max_scale = best_params.get('act_max_scale', torch.tensor(1.0)).to(self.device) + self.orig_layer.q_scale_thresh = self.q_scale_thresh + self.orig_layer.data_type = self.data_type + if not self.orig_layer.act_dynamic: + self.orig_layer.act_max = self.orig_layer.act_max * act_max_scale.item() + self.orig_layer.act_data_type = self.act_data_type + self.orig_layer.act_quant_func = self.act_quant_func + wrapper_layer = WrapperWALayer(self.orig_layer) + return wrapper_layer + + return self.orig_layer + + def linear_forward(self, x, weight, bias): + """Performs the forward pass for a linear layer. + + Args: + x (torch.Tensor): Input tensor. + weight (torch.Tensor): Weight tensor for the linear layer. + bias (torch.Tensor): Bias tensor for the linear layer. 
+ + Returns: + torch.Tensor: Output tensor after applying the linear layer. + """ + return F.linear(x, weight, bias) + + def conv1d_forward(self, x, weight, bias): + """Performs the forward pass for a Conv1D layer. + + Args: + x (torch.Tensor): Input tensor. + weight (torch.Tensor): Weight tensor for the Conv1D layer. + bias (torch.Tensor): Bias tensor for the Conv1D layer. + + Returns: + torch.Tensor: Output tensor after applying the Conv1D layer. + """ + size_out = x.size()[:-1] + (self.orig_layer.nf,) + x = torch.addmm(bias, x.view(-1, x.size(-1)), weight) + x = x.view(*size_out) + return x + + def forward(self, x): + """Executes the forward pass with quantized weights and optional bias/activation quantization. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the wrapped layer. + """ + weight_q, _, _ = self._qdq_weight(self.value, self.min_scale, self.max_scale) + + if self.enable_act_quant: + act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None + x, _, _ = self._qdq_act(x, act_max_scale=self.act_max_scale, act_max=act_max) + + # pylint: disable=not-callable + bias = self.orig_layer.bias + if bias is not None and bias.device.type == 'meta': + bias = self.orig_layer.get_bias().to(self.device) + + if self.enable_norm_bias_tuning: + bias, _, _ = self._qdq_bias(bias, self.bias_v) + + return self.orig_forward(x, weight_q, bias) class WrapperWALayer(torch.nn.Module): @@ -103,11 +318,13 @@ def __init__(self, orig_layer): self.act_quant_func = self.orig_layer.act_quant_func def forward(self, x): - x, _, _ = quant_tensor(self.orig_layer.act_quant_func, x, self.orig_layer.act_bits, - self.orig_layer.group_size, - scale_dtype=self.orig_layer.scale_dtype, - q_scale_thresh=self.orig_layer.q_scale_thresh, - data_type=self.orig_layer.act_data_type) + tensor_max = self.orig_layer.tensor_max if hasattr(self.orig_layer, "tensor_max") else None + x, _, _ = self.orig_layer.act_quant_func(x, bits=self.orig_layer.act_bits, + group_size=self.orig_layer.group_size, + scale_dtype=self.orig_layer.scale_dtype, + q_scale_thresh=self.orig_layer.q_scale_thresh, + data_type=self.orig_layer.act_data_type, + tensor_max=tensor_max) return self.orig_layer.forward(x) @@ -128,7 +345,7 @@ def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"): weight_dtype = torch.float32 self.q_scale_thresh = 1e-5 self.v = torch.nn.Parameter( - reshape_tensor( + reshape_and_pad_tensor( torch.zeros(self.orig_layer.weight.shape, device=self.device, dtype=weight_dtype), self.group_size), requires_grad=True) @@ -140,15 +357,15 @@ def unwrapper(self, best_params): if best_params is None: return self.orig_layer v = best_params['v'] - weight_q, _, _ = quant_tensor(self.quant_func, self.orig_layer.weight, self.bits, self.group_size, - v, q_scale_thresh=self.q_scale_thresh) + weight_q, _, _ = self.quant_func(self.orig_layer.weight, self.bits, self.group_size, + v, q_scale_thresh=self.q_scale_thresh) self.orig_layer.q_scale_thresh = self.q_scale_thresh self.orig_layer.weight.data.copy_(weight_q) return self.orig_layer def forward(self, input): - weight_q, _, _ = quant_tensor(self.quant_func, self.orig_layer.weight, self.bits, self.group_size, - self.v, q_scale_thresh=self.q_scale_thresh) + weight_q, _, _ = self.quant_func(self.orig_layer.weight, self.bits, self.group_size, + self.v, q_scale_thresh=self.q_scale_thresh) import torch.nn.functional as F return F.layer_norm( input, self.orig_layer.normalized_shape, weight_q, self.orig_layer.bias, 
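These wrappers depend on straight-through estimators (`round_ste`, `float8_e4m3fn_ste`) so that the non-differentiable quantization step still passes gradients back to the tunable parameters (`value`, `min_scale`, `max_scale`, `bias_v`). A minimal illustration of the pattern; `round_ste_sketch` is an illustrative name, not the library function:

import torch

def round_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    # forward: rounded values; backward: identity gradient
    return (x.round() - x).detach() + x

v = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
round_ste_sketch(v).sum().backward()
assert torch.equal(v.grad, torch.ones_like(v))  # gradients flow as if no rounding happened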
self.orig_layer.eps) @@ -171,7 +388,7 @@ def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"): weight_dtype = torch.float32 self.q_scale_thresh = 1e-5 self.v = torch.nn.Parameter( - reshape_tensor( + reshape_and_pad_tensor( torch.zeros(self.orig_layer.weight.shape, device=self.device, dtype=weight_dtype), self.group_size), requires_grad=True) @@ -183,15 +400,15 @@ def unwrapper(self, best_params): if best_params is None: return self.orig_layer v = best_params['v'] - weight_q, _, _ = quant_tensor(self.quant_func, self.orig_layer.weight, self.bits, self.group_size, - v, q_scale_thresh=self.q_scale_thresh) + weight_q, _, _ = self.quant_func(self.orig_layer.weight, self.bits, self.group_size, + v, q_scale_thresh=self.q_scale_thresh) self.orig_layer.q_scale_thresh = self.q_scale_thresh self.orig_layer.weight.data.copy_(weight_q) return self.orig_layer def forward(self, hidden_states): - weight_q, _, _ = quant_tensor(self.quant_func, self.orig_layer.weight, self.bits, self.group_size, - self.v, q_scale_thresh=self.q_scale_thresh) + weight_q, _, _ = self.quant_func(self.orig_layer.weight, self.bits, self.group_size, + self.v, q_scale_thresh=self.q_scale_thresh) input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) @@ -207,360 +424,6 @@ def forward(self, hidden_states): norm_mapping["MistralRMSNorm"] = WrapperLlamaNorm -class WrapperLinear(torch.nn.Module): - def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tuning=False, device='cpu'): - """A wrapper module for linear layers that enables quantization and min-max tuning of weights. - - Args: - - orig_layer (torch.nn.Module): The original linear layer to be wrapped. - - enable_minmax_tuning (bool): Whether to enable min-max scaling tuning. Default is True. - - Attributes: - - orig_layer (torch.nn.Module): The original linear layer being wrapped. - - bits (int): The number of bits for quantization. - - group_size (int): The size of the groups for quantization. - - sym (bool): Whether the symmetric quantization is to be used. - - value (torch.nn.Parameter): The learnable parameter for quantization. - - enable_minmax_tuning (bool): Whether min-max scaling tuning is enabled. - - min_scale (torch.nn.Parameter or torch.Tensor): The minimum scale for min-max tuning. - - max_scale (torch.nn.Parameter or torch.Tensor): The maximum scale for min-max tuning. 
- """ - super(WrapperLinear, self).__init__() - self.orig_layer = orig_layer - self.device = device - self.bits = self.orig_layer.bits - self.group_size = self.orig_layer.group_size - self.scale_dtype = self.orig_layer.scale_dtype - self.sym = self.orig_layer.sym - self.data_type = self.orig_layer.data_type - self.weight_quant_func, self.data_type = get_quant_func(self.orig_layer.data_type, self.bits, self.sym) - self.act_bits = self.orig_layer.act_bits - self.act_group_size = self.orig_layer.act_group_size - self.act_sym = self.orig_layer.act_sym - self.act_dynamic = self.orig_layer.act_dynamic - self.act_quant = self.act_bits <= 8 - self.params = {} - - if self.act_quant: - self.act_quant_func, self.act_data_type = get_quant_func(self.orig_layer.data_type, self.act_bits, - self.act_sym) - - self.q_scale_thresh = 1e-5 - - weight_dtype = torch.float32 - orig_layer_weight = self.orig_layer.weight if not hasattr(self.orig_layer, 'get_weight') \ - else self.orig_layer.get_weight() - self.value = torch.nn.Parameter( - reshape_tensor( - torch.zeros(self.orig_layer.weight.shape, device=self.device, dtype=weight_dtype), - self.group_size), - requires_grad=True) - self.params["v"] = self.value - weight_reshape = reshape_tensor(orig_layer_weight.data, self.group_size) - self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0) - self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0) - - self.enable_minmax_tuning = enable_minmax_tuning - shape = get_scale_shape(self.orig_layer.weight, self.group_size) - if self.enable_minmax_tuning: - self.min_scale = torch.nn.Parameter( - torch.ones(shape, device=self.device, dtype=weight_dtype), requires_grad=True - ) - self.max_scale = torch.nn.Parameter( - torch.ones(shape, device=self.device, dtype=weight_dtype), requires_grad=True - ) - self.params["min_scale"] = self.min_scale - self.params["max_scale"] = self.max_scale - else: - self.min_scale = torch.tensor(1.0, device=self.device, dtype=weight_dtype) - self.max_scale = torch.tensor(1.0, device=self.device, dtype=weight_dtype) - self.enable_norm_bias_tuning = False - if enable_norm_bias_tuning and self.orig_layer.bias is not None: - self.enable_norm_bias_tuning = True - self.bias_bits = 4 ## hard code - self.bias_group_size = -1 - self.bias_v = torch.nn.Parameter( - reshape_tensor( - torch.zeros(self.orig_layer.bias.shape, device=self.device, dtype=weight_dtype), - self.bias_group_size), - requires_grad=True) - from auto_round.data_type.int import quant_tensor_asym_wo_round - self.bias_quant_func = quant_tensor_asym_wo_round - self.params["bias_v"] = self.bias_v - - def unwrapper(self, best_params): - """Unwrapper the layer to the original layer. - - Args: - - v (torch.Tensor): The rounding v parameter for quantization. - - min_scale (torch.nn.Parameter or torch.Tensor): The minimum scale for min-max tuning. - - max_scale (torch.nn.Parameter or torch.Tensor): The maximum scale for min-max tuning. - - Returns: - - torch.nn.Module: The original linear layer with updated weights after quantization and dequantization. 
-        """
-        min_scale = torch.tensor(1.0)
-        max_scale = torch.tensor(1.0)
-        v = torch.tensor(0.0)
-        if best_params is not None:
-            min_scale = best_params.get('min_scale', min_scale)
-            max_scale = best_params.get('max_scale', max_scale)
-            v = best_params.get('v', v)
-
-        min_scale.clamp_(0, 1.0)
-        max_scale.clamp_(0, 1.0)
-        v = v.to(self.device)
-        min_scale = min_scale.to(self.device)
-        max_scale = max_scale.to(self.device)
-
-        if self.orig_layer.weight.device.type == 'meta':
-            self.orig_layer.to(self.device)
-        qdq_weight, scale, zp = quant_tensor(self.weight_quant_func, self.orig_layer.weight, self.bits,
-                                             self.group_size, v,
-                                             min_scale, max_scale, self.scale_dtype, self.weight_min, self.weight_max,
-                                             data_type=self.data_type, q_scale_thresh=self.q_scale_thresh)
-        scale = scale.reshape(qdq_weight.shape[0], -1)
-        if zp is not None:
-            zp = zp.reshape(qdq_weight.shape[0], -1)
-
-        self.orig_layer.weight.data.copy_(qdq_weight)
-        self.orig_layer.weight.grad = None
-        self.orig_layer.scale = scale.to("cpu")
-        self.orig_layer.zp = zp.to("cpu") if zp is not None else None
-        if self.enable_norm_bias_tuning and "bias_v" in best_params.keys():  ##fake quant
-            bias_v = best_params["bias_v"]
-            bias, _, _ = quant_tensor(self.bias_quant_func, self.orig_layer.bias, self.bias_bits, self.bias_group_size,
-                                      bias_v, q_scale_thresh=self.q_scale_thresh)
-            self.orig_layer.bias.grad = None
-            self.orig_layer.bias.data.copy_(bias)
-
-        self.orig_layer.q_scale_thresh = self.q_scale_thresh
-        self.orig_layer.data_type = self.data_type
-        if self.act_quant:
-            self.orig_layer.act_data_type = self.act_data_type
-            self.orig_layer.act_quant_func = self.act_quant_func
-            wrapper_layer = WrapperWALayer(self.orig_layer)
-            return wrapper_layer
-
-        if hasattr(self.orig_layer, 'update'):
-            self.orig_layer.update()
-            self.orig_layer.to('meta')
-
-        return self.orig_layer
-
-    def forward(self, x):
-        """Performs forward pass through the wrapped linear layer with quantized weights.
-
-        Args:
-        - x (torch.Tensor): The input tensor.
-
-        Returns:
-        - torch.Tensor: The output tensor after applying the linear transformation with quantized weights.
-        """
-        from torch.functional import F
-
-        weight = self.orig_layer.weight
-        if weight.device.type == 'meta':
-            weight = self.orig_layer.get_weight().to(self.device)
-        self.min_scale.data.copy_(torch.clamp(self.min_scale.data, 0, 1.0))
-        self.max_scale.data.copy_(torch.clamp(self.max_scale.data, 0, 1.0))
-        weight_q, _, _ = quant_tensor(self.weight_quant_func, weight, self.bits, self.group_size, self.value,
-                                      self.min_scale,
-                                      self.max_scale, self.scale_dtype, self.weight_min, self.weight_max,
-                                      data_type=self.data_type, q_scale_thresh=self.q_scale_thresh)
-        weight_q = weight_q.to(weight.dtype)
-        if self.act_quant:
-            x, _, _ = quant_tensor(self.act_quant_func, x, self.act_bits, self.act_group_size,
-                                   scale_dtype=self.scale_dtype, q_scale_thresh=self.q_scale_thresh,
-                                   data_type=self.act_data_type)
-        # pylint: disable=not-callable
-        bias = self.orig_layer.bias
-        if bias is not None and bias.device.type == 'meta':
-            bias = self.orig_layer.get_bias().to(self.device)
-        if self.enable_norm_bias_tuning:
-            bias, _, _ = quant_tensor(self.bias_quant_func, bias, self.bias_bits, self.bias_group_size, self.bias_v,
-                                      q_scale_thresh=self.q_scale_thresh)
-
-        return F.linear(x, weight_q, bias)
-
-
-class WrapperTransformerConv1d(torch.nn.Module):
-    def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tuning=False, device='cpu'):
-        """A wrapper module for transformers 1D convolutional layers used in transformers,
-        enabling quantization and min-max tuning of weights.
-
-        Args:
-        - orig_layer (torch.nn.Module): The original 1D convolutional layer to be wrapped.
-        - bits (int): The number of bits for quantization.
-        - group_size (int): The size of the groups for quantization.
-        - sym (bool): Whether symmetric quantization is to be used.
-        - enable_minmax_tuning (bool): Whether to enable min-max scaling tuning. Default is True.
-
-        Attributes:
-        - orig_layer (torch.nn.Module): The original 1D convolutional layer being wrapped.
-        - bits (int): The number of bits for quantization.
-        - group_size (int): The size of the groups for quantization.
-        - sym (bool): Whether symmetric quantization is to be used.
-        - weight_t (torch.Tensor): Transposed weight tensor of the original layer.
-        - value (torch.nn.Parameter): The learnable parameter for quantization.
-        - enable_minmax_tuning (bool): Whether min-max scaling tuning is enabled.
-        - min_scale (torch.nn.Parameter or torch.Tensor): The minimum scale for min-max tuning.
-        - max_scale (torch.nn.Parameter or torch.Tensor): The maximum scale for min-max tuning.
-        """
-        super(WrapperTransformerConv1d, self).__init__()
-        self.orig_layer = orig_layer
-        self.bits = self.orig_layer.bits
-        self.group_size = self.orig_layer.group_size
-        self.sym = self.orig_layer.sym
-        self.scale_dtype = self.orig_layer.scale_dtype
-        self.data_type = self.orig_layer.data_type
-        self.act_bits = self.orig_layer.act_bits
-        self.act_group_size = self.orig_layer.act_group_size
-        self.act_sym = self.orig_layer.act_sym
-        self.act_dynamic = self.orig_layer.act_dynamic
-        self.act_quant = self.act_bits <= 8
-        self.weight_quant_func, self.data_type = get_quant_func(self.orig_layer.data_type, self.bits, self.sym)
-        if self.act_quant:
-            self.act_quant_func, self.act_data_type = get_quant_func(self.orig_layer.data_type, self.act_bits,
-                                                                     self.act_sym)
-
-        self.q_scale_thresh = 1e-5
-        weight_dtype = torch.float32
-        self.device = device
-        self.params = {}
-        if hasattr(self.orig_layer, 'get_weight'):
-            self.weight_t = self.orig_layer.get_weight().t()
-        else:
-            self.weight_t = self.orig_layer.weight.t()
-        self.weight_t = self.weight_t.to(self.device)
-        self.value = torch.nn.Parameter(
-            reshape_tensor(torch.zeros(self.weight_t.shape, device=device, dtype=weight_dtype),
-                           group_size=self.group_size),
-            requires_grad=True
-        )
-        self.params["v"] = self.value
-        weight_reshape = reshape_tensor(self.weight_t, self.group_size)
-        self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0)
-        self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0)
-
-        shape = get_scale_shape(self.weight_t, self.group_size)
-
-        if enable_minmax_tuning:
-            self.min_scale = torch.nn.Parameter(
-                torch.ones(shape, device=device, dtype=weight_dtype), requires_grad=True
-            )
-            self.max_scale = torch.nn.Parameter(
-                torch.ones(shape, device=device, dtype=weight_dtype), requires_grad=True
-            )
-            self.params["min_scale"] = self.min_scale
-            self.params["max_scale"] = self.max_scale
-
-        else:
-            self.min_scale = torch.tensor(1.0, device=device, dtype=weight_dtype)
-            self.max_scale = torch.tensor(1.0, device=device, dtype=weight_dtype)
-
-        self.enable_norm_bias_tuning = False
-        if enable_norm_bias_tuning and self.orig_layer.bias is not None:
-            self.enable_norm_bias_tuning = True
-            self.bias_bits = 4  ## hard code
-            self.bias_group_size = -1
-            self.bias_v = torch.nn.Parameter(
-                reshape_tensor(
-                    torch.zeros(self.orig_layer.bias.shape, device=self.device, dtype=weight_dtype),
-                    self.bias_group_size),
-                requires_grad=True)
-            from auto_round.data_type.int import quant_tensor_asym_wo_round
-            self.bias_quant_func = quant_tensor_asym_wo_round
-            self.params["bias_v"] = self.bias_v
-
-    def unwrapper(self, best_params):
-        """Unwrapper the layer to the original conv1d layer.
-
-        Args:
-        - v (torch.Tensor): The scaling parameter for quantization.
-        - min_scale (torch.nn.Parameter or torch.Tensor): The minimum scale for min-max tuning.
-        - max_scale (torch.nn.Parameter or torch.Tensor): The maximum scale for min-max tuning.
-
-        Returns:
-        - torch.nn.Module: The original 1D convolutional layer with updated weights after inverse quantization.
-        """
-        min_scale = torch.tensor(1.0)
-        max_scale = torch.tensor(1.0)
-        v = torch.tensor(0.0)
-        if best_params is not None:
-            min_scale = best_params.get('min_scale', min_scale)
-            max_scale = best_params.get('max_scale', max_scale)
-            v = best_params.get('v', v)
-
-        min_scale.clamp_(0, 1.0)
-        max_scale.clamp_(0, 1.0)
-        v = v.to(self.device)
-        min_scale = min_scale.to(self.device)
-        max_scale = max_scale.to(self.device)
-
-        qdq_weight, scale, zp = quant_tensor(self.weight_quant_func, self.weight_t, self.bits, self.group_size, v,
-                                             min_scale,
-                                             max_scale, self.scale_dtype, self.weight_min, self.weight_max,
-                                             data_type=self.data_type, q_scale_thresh=self.q_scale_thresh)
-        scale = scale.reshape(qdq_weight.shape[0], -1)
-        if zp is not None:
-            zp = zp.reshape(qdq_weight.shape[0], -1)
-        if self.orig_layer.weight.device.type == 'meta':
-            self.orig_layer.weight.to(self.device)
-        self.orig_layer.weight.data.copy_(qdq_weight.t())
-        self.orig_layer.weight.grad = None
-
-        if self.enable_norm_bias_tuning and "bias_v" in best_params.keys():  ##fake quant
-            bias_v = best_params["bias_v"]
-            bias, _, _ = quant_tensor(self.bias_quant_func, self.orig_layer.bias, self.bias_bits, self.bias_group_size,
-                                      bias_v, q_scale_thresh=self.q_scale_thresh)
-            self.orig_layer.bias.grad = None
-            self.orig_layer.bias.data.copy_(bias)
-
-        self.orig_layer.scale = scale.to("cpu")
-        self.orig_layer.zp = zp.to("cpu")
-        self.orig_layer.q_scale_thresh = self.q_scale_thresh
-        self.orig_layer.data_type = self.data_type
-        if self.act_quant:
-            self.orig_layer.act_quant_func = self.act_quant_func
-            self.orig_layer.act_data_type = self.act_data_type
-            wrapper_layer = WrapperWALayer(self.orig_layer)
-            return wrapper_layer
-        if hasattr(self.orig_layer, 'update'):
-            self.orig_layer.update()
-            self.orig_layer.to('meta')
-        return self.orig_layer
-
-    def forward(self, x):
-        """Performs forward pass through the wrapped 1D convolutional layer with quantized weights.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The output tensor after applying the convolutional transformation with quantized weights.
-        """
-        with torch.no_grad():
-            self.min_scale.clamp_(0, 1.0)
-            self.max_scale.clamp_(0, 1.0)
-        weight_q, _, _ = quant_tensor(self.weight_quant_func, self.weight_t, self.bits, self.group_size, self.value,
-                                      self.min_scale, self.max_scale, self.scale_dtype, self.weight_min,
-                                      self.weight_max, data_type=self.data_type, q_scale_thresh=self.q_scale_thresh)
-        weight_q = weight_q.to(self.weight_t.dtype)
-        size_out = x.size()[:-1] + (self.orig_layer.nf,)
-        if self.act_quant:
-            x, _, _ = quant_tensor(self.act_quant_func, x, self.act_bits, self.act_group_size,
-                                   scale_dtype=self.scale_dtype, q_scale_thresh=self.q_scale_thresh,
-                                   data_type=self.act_data_type)
-        bias = self.orig_layer.bias
-        if self.enable_norm_bias_tuning:
-            bias, _, _ = quant_tensor(self.bias_quant_func, bias, self.bias_bits, self.bias_group_size, self.bias_v,
-                                      q_scale_thresh=self.q_scale_thresh)
-        x = torch.addmm(bias, x.view(-1, x.size(-1)), weight_q.t())
-        x = x.view(*size_out)
-        return x
-
-

 class WrapperMultiblock(torch.nn.Module):
     """A wrapper for a list of modules to be act as a single block.
@@ -595,7 +458,7 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device='
     quantized_layers = []
     unquantized_layers = []
     for n, m in block.named_modules():
-        if isinstance(m, torch.nn.Linear):
+        if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D)):
             if not check_to_quantized(m):
                 unquantized_layers.append(n)
                 continue
@@ -604,20 +467,18 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device='
             set_module(block, n, new_m)
            quantized_layers.append(n)

-        if isinstance(m, transformers.modeling_utils.Conv1D):
-            if not check_to_quantized(m):
-                unquantized_layers.append(n)
-                continue
-            new_m = WrapperTransformerConv1d(m, enable_minmax_tuning=enable_minmax_tuning, device=device)
-            set_module(block, n, new_m)
-            quantized_layers.append(n)
-
         if enable_norm_bias_tuning:
             if "norm" in m.__class__.__name__.lower():
                 if m.__class__.__name__ in norm_mapping.keys():
                     wrapper_layer_class = norm_mapping[m.__class__.__name__]
                     new_m = wrapper_layer_class(m, device=device)
                     setattr(block, n, new_m)
+                elif "RMSNorm" in m.__class__.__name__:
+                    logger.warning_once(
+                        f"using the LlamaRMSNorm wrapper for {m.__class__.__name__}, please verify the correctness yourself")
+                    wrapper_layer_class = norm_mapping["LlamaRMSNorm"]
+                    new_m = wrapper_layer_class(m, device=device)
+                    setattr(block, n, new_m)
                 else:
                     logger.warning_once(f"{m.__class__.__name__} is not supported")
diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py
index 3a50ba01..8e92021c 100644
--- a/auto_round/script/llm.py
+++ b/auto_round/script/llm.py
@@ -149,6 +149,12 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--enable_torch_compile", default=None, type=bool,
                           help="whether to enable torch compile")

+        self.add_argument("--act_data_type", default=None, type=str,
+                          help="data type for activation quantization")
+
+        self.add_argument("--disable_act_dynamic", action='store_true',
+                          help="use static activation quantization instead of dynamic")
+

 def setup_parser():
     parser = BasicArgumentParser()
@@ -338,9 +344,9 @@ def tune(args):
     try:
         if args.model_dtype == "float16" or args.model_dtype == "fp16":
             model = model.to(torch.float16)
-        elif args.model_dtype == "bfloat16" or args.model_dtype == "bfp16" or args.model_dtype=="bf16":
+        elif args.model_dtype == "bfloat16" or args.model_dtype == "bfp16" or args.model_dtype == "bf16":
            model = model.to(torch.bfloat16)
-        elif args.model_dtype=="float32" or args.model_dtype=="fp32":
+        elif args.model_dtype == "float32" or args.model_dtype == "fp32":
             model = model.to(torch.float32)
     except:
         logger.error("please use more device to fit the device or just use one device")
@@ -418,7 +424,8 @@ def tune(args):
         enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits,
         low_cpu_mem_usage=low_cpu_mem_usage, data_type=args.data_type,
         enable_norm_bias_tuning=args.enable_norm_bias_tuning, not_use_best_mse=args.not_use_best_mse,
-        to_quant_block_names=args.to_quant_block_names, enable_torch_compile=args.enable_torch_compile)
+        to_quant_block_names=args.to_quant_block_names, enable_torch_compile=args.enable_torch_compile,
+        act_data_type=args.act_data_type, act_dynamic=not args.disable_act_dynamic)
     model, _ = autoround.quantize()
     model_name = args.model.rstrip("/")
     if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2:
@@ -460,6 +467,7 @@ def tune(args):
             user_model = model
     else:
         user_model = model.to(device_str)
+
     if args.eval_bs is None or args.eval_bs == "auto":
         args.eval_bs = 16
     from auto_round.eval.evaluation import simple_evaluate_user_model
diff --git a/test/test_act_quantization.py b/test/test_act_quantization.py
new file mode 100644
index 00000000..a4ada07d
--- /dev/null
+++ b/test/test_act_quantization.py
@@ -0,0 +1,90 @@
+import copy
+import shutil
+import sys
+import unittest
+
+sys.path.insert(0, "..")
+import torch
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from auto_round import AutoRound
+
+
+class LLMDataLoader:
+    def __init__(self):
+        self.batch_size = 1
+
+    def __iter__(self):
+        for i in range(3):
+            yield torch.ones([1, 10], dtype=torch.long)
+
+
+class TestAutoRoundAct(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        model_name = "facebook/opt-125m"
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        self.llm_dataloader = LLMDataLoader()
+
+    @classmethod
+    def tearDownClass(self):
+        shutil.rmtree("./saved", ignore_errors=True)
+        shutil.rmtree("runs", ignore_errors=True)
+
+    def test_mx_fp4(self):
+        model_name = "facebook/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        bits, group_size, sym = 4, 128, True
+        autoround = AutoRound(
+            model,
+            tokenizer,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            act_bits=4,
+            data_type="mx_fp4"
+        )
+        autoround.quantize()
+
+    def test_wint4fp8_dynamic(self):
+        model_name = "facebook/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        bits, group_size = 4, 128
+        autoround = AutoRound(
+            model,
+            tokenizer,
+            bits=bits,
+            group_size=group_size,
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            act_bits=8,
+            data_type="fp8_to_int_sym",
+            act_data_type="fp8_dynamic_per_token"
+        )
+        autoround.quantize()
+
+    def test_wint4fp8_static(self):
+        bits, group_size, sym = 4, 128, True
+        autoround = AutoRound(
+            self.model,
+            self.tokenizer,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            act_bits=8,
+            data_type="fp8_to_int_sym",
+            act_dynamic=False,
+            act_data_type="fp8"
+        )
+        autoround.quantize()
diff --git a/test/test_woq_linear.py b/test/test_woq_linear.py
index 527ddd8d..f049890e 100644
--- a/test/test_woq_linear.py
+++ b/test/test_woq_linear.py
@@ -28,11 +28,9 @@ def test_pack_with_numba(self, bits, compression_dtype):
         group_size = 32
         origin_shape = weight.shape
         from auto_round.data_type.int import quant_tensor_sym
-        from auto_round.quantizer import quant_tensor, reshape_tensor
-
-        data = reshape_tensor(weight, group_size=group_size)
-        qdq, scale, zp = quant_tensor(
-            quant_tensor_sym, data=data, group_size=group_size
+        origin_shape = weight.shape
+        weight = weight.reshape(-1, group_size)
+        qdq, scale, zp = quant_tensor_sym( weight, -1
         )
         int_weight = (
             qdq.div(scale)
@@ -41,8 +39,8 @@
             .to(torch.int32)
             .reshape(origin_shape)
         )
-        scale = scale.reshape(weight.shape[0], -1)
-        zp = zp.reshape(weight.shape[0], -1).to(torch.int32).clamp(0, 2 ** (bits) - 1)
+        scale = scale.reshape(origin_shape[0], -1)
+        zp = zp.reshape(origin_shape[0], -1).to(torch.int32).clamp(0, 2 ** (bits) - 1)
         module_with_legacy_pack = WeightOnlyLinear(
             in_features=m.in_features,
out_features=m.out_features,