From 859f5d0e367977d94612caf66e818174e6d6dfb9 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Sat, 12 Oct 2024 12:53:35 +0800
Subject: [PATCH 01/10] full range sym

---
 auto_round/data_type/int.py | 90 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 88 insertions(+), 2 deletions(-)

diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py
index ef350895..53c574c9 100644
--- a/auto_round/data_type/int.py
+++ b/auto_round/data_type/int.py
@@ -17,6 +17,49 @@
 from auto_round.data_type.register import register_dtype, QUANT_FUNC_WITH_DTYPE
 
+@register_dtype("int_sym")
+def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None,
+                     weight_max=None, q_scale_thresh=0.0, **kwargs):
+    """Quantizes and dequantizes weight symmetrically over the full integer range; credit goes to the llama.cpp community.
+
+    Args:
+        weight: Tensor containing the weight to be quantized
+        bits: Number of bits for quantization (e.g., 2, 3, 4, 8)
+        v: Rounding value perturbation
+        min_scale: Minimum scale coefficient for weight
+        max_scale: Maximum scale coefficient for weight
+        weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None.
+        weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None.
+
+    Returns:
+        Quantized and dequantized weight, scale, zero-point
+    """
+    maxq = torch.tensor(2 ** (bits - 1))
+    if weight_min is None or weight_max is None:
+        wmin_tmp = torch.clamp(weight.min(-1)[0], max=0)
+        wmax_tmp = torch.clamp(weight.max(-1)[0], min=0)
+    else:
+        wmin_tmp = weight_min
+        wmax_tmp = weight_max
+    if isinstance(min_scale, torch.Tensor):
+        wmin = wmin_tmp * min_scale
+        wmax = wmax_tmp * max_scale
+    else:
+        wmin = wmin_tmp
+        wmax = wmax_tmp
+    max_v = (2 * (torch.abs(wmax) < torch.abs(wmin)).int() - 1) * torch.max(torch.abs(wmax), torch.abs(wmin))
+
+    scale = (max_v / maxq).to(scale_dtype)
+    scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
+    zp = torch.full_like(scale, maxq)  # pylint: disable=E1130
+    scale = scale.unsqueeze(dim=-1)
+    zp = zp.unsqueeze(dim=-1)
+    int_w = round_ste(weight / scale + v)
+    q = torch.clamp(int_w + zp, 0, 2 ** bits - 1)
+    qdq_result = (scale * (q - zp)).to(weight.dtype)
+    return qdq_result, scale, zp
+
+
 @register_dtype("int_asym")
 def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None,
                       weight_max=None, q_scale_thresh=0.0, **kwargs):
     """Quantizes and dequantizes weight asymmetrically.
@@ -49,7 +92,7 @@ def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_d
         wmax = wmax_tmp
     scale = ((wmax - wmin) / maxq).to(scale_dtype)
     scale = torch.clamp(scale, min=q_scale_thresh)
-    zp = round_ste(-wmin / scale) # pylint: disable=E1130
+    zp = round_ste(-wmin / scale)  # pylint: disable=E1130
     scale = scale.unsqueeze(dim=-1)
     zp = zp.unsqueeze(dim=-1)
     int_w = round_ste(weight / scale + v)
@@ -58,7 +101,7 @@ def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_d
     return qdq_result, scale, zp
 
 
-@register_dtype("int_sym")
+@register_dtype("int_sym_gptq")
 def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None,
                      weight_max=None, q_scale_thresh=0.0, **kwargs):
     """Quantizes and dequantizes weight symmetrically.
@@ -106,6 +149,49 @@ def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dt return qdq_result, scale, zp +# @register_dtype("int_sym") +# def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, +# weight_max=None, q_scale_thresh=0.0, **kwargs): +# """Quantizes and dequantizes weight symmetrically. +# +# Args: +# weight: Tensor containing the weight to be quantized +# bits: Number of bits for quantization (e.g., 2, 3, 4, 8) +# v: Rounding value perturbation +# min_scale: Minimum scale coefficient for weight +# max_scale: Maximum scale coefficient for weight +# weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None. +# weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None. +# +# Returns: +# Quantized and dequantized weight, scale, zero-point +# """ +# maxq = torch.tensor(2 ** (bits - 1)) ##different +# if weight_min is None or weight_max is None: +# wmin_tmp = torch.clamp(weight.min(-1)[0], max=0) +# wmax_tmp = torch.clamp(weight.max(-1)[0], min=0) +# else: +# wmin_tmp = weight_min +# wmax_tmp = weight_max +# if isinstance(min_scale, torch.Tensor): +# wmin = wmin_tmp * min_scale +# wmax = wmax_tmp * max_scale +# else: +# wmin = wmin_tmp +# wmax = wmax_tmp +# max_v = (2 * (torch.abs(wmax) < torch.abs(wmin)).int() - 1) * torch.max(torch.abs(wmax), torch.abs(wmin)) +# +# scale = (max_v / maxq).to(scale_dtype) +# scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) +# zp = torch.full_like(scale, maxq) # pylint: disable=E1130 +# scale = scale.unsqueeze(dim=-1) +# zp = zp.unsqueeze(dim=-1) +# int_w = round_ste(weight / scale + v) +# q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) +# qdq_result = (scale * (q - zp)).to(weight.dtype) +# return qdq_result, scale, zp + + def quant_tensor_asym_wo_round(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, weight_max=None, q_scale_thresh=0.0, **kwargs): """Quantizes and dequantizes weight asymmetrically without rounding, this is mainly for tuning bias, norm. From 8a59c2ca477aa03afd4c875edd9a30231c8f3758 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Sat, 12 Oct 2024 13:14:16 +0800 Subject: [PATCH 02/10] update --- README.md | 26 +++++++++++++------------- auto_round/__main__.py | 14 ++++++++++---- auto_round/autoround.py | 22 ++++++---------------- examples/language-modeling/main.py | 14 +++++++------- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index e9ee5d9a..bff6695a 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,9 @@ more accuracy data and recipes across various models.
## What's New -* [2024/09] AutoRound format supports several LVM models, check out the examples [Qwen2-Vl](./examples/multimodal-modeling/Qwen-VL),[Phi-3-vision](./examples/multimodal-modeling/Phi-3-vision), [Llava](./examples/multimodal-modeling/Llava) + +* [2024/09] AutoRound format supports several LVM models, check out the + examples [Qwen2-Vl](./examples/multimodal-modeling/Qwen-VL),[Phi-3-vision](./examples/multimodal-modeling/Phi-3-vision), [Llava](./examples/multimodal-modeling/Llava) * [2024/08] AutoRound format supports Intel Gaudi2 devices. Please refer to [Intel/Qwen2-7B-int4-inc](https://huggingface.co/Intel/Qwen2-7B-int4-inc). * [2024/08] AutoRound introduces several experimental features, including fast tuning of norm/bias parameters (for 2-bit @@ -61,14 +63,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_name) from auto_round import AutoRound -bits, group_size, sym = 4, 128, False -autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym) +bits, group_size = 4, 128 +autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size) -## best accuracy, 3X slower, low_gpu_mem_usage could save ~20G but ~30% slower -# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size, sym=sym) +## the best accuracy, 3X slower, low_gpu_mem_usage could save ~20G but ~30% slower +# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size) ## fast and low memory, 2-3X speedup, slight accuracy drop at W4G128 -# autoround = AutoRound(model, tokenizer, nsamples=128, iters=200, seqlen=512, batch_size=4, bits=bits, group_size=group_size, sym=sym) +# autoround = AutoRound(model, tokenizer, nsamples=128, iters=200, seqlen=512, batch_size=4, bits=bits, group_size=group_size) autoround.quantize() output_dir = "./tmp_autoround" @@ -87,7 +89,7 @@ autoround.save_quantized(output_dir, format='auto_round', inplace=True) - `group_size (int)`: Size of the quantization group (default is 128). -- `sym (bool)`: Whether to use symmetric quantization (default is False). +- `sym (bool)`: Whether to use symmetric quantization (default is True). - `enable_quanted_input (bool)`: Whether to use the output of the previous quantized block as the input for the current block for tuning (default is True). @@ -173,7 +175,8 @@ We provide two recipes for best accuracy and fast running speed with low memory. #### Formats -**AutoRound Format**:This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision inference. [2,4] +**AutoRound Format**:This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision +inference. [2,4] bits are supported. It resolves the asymmetric quantization kernel issues found in the AutoGPTQ format and supports both LM-head quantization and mixed precision. However, it has not yet gained widespread community adoption. For CUDA support, you will need to @@ -216,7 +219,7 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) **HPU**: docker image with Gaudi Software Stack is recommended. More details can be found in [Gaudi Guide](https://docs.habana.ai/en/latest/). -**CUDA**: git clone https://github.com/intel/auto-round.git && cd auto-round && pip install --no-build-isolation +**CUDA**: git clone https://github.com/intel/auto-round.git && cd auto-round && pip install --no-build-isolation -e . 
#### CPU/HPU/CUDA on 0.3.0+ @@ -308,8 +311,8 @@ release most of the models ourselves. | bigscience/bloom-3b | [accuracy](./docs/bloom-3B-acc.md), [recipe](./examples/language-modeling/scripts/bloom-3b.sh), [example](./examples/language-modeling/) | | EleutherAI/gpt-j-6b | [accuracy](./docs/gpt-j-6B-acc.md), [recipe](./examples/language-modeling/scripts/gpt-j-6b.sh), [example](./examples/language-modeling/) | - ## Integration + AutoRound has been integrated into multiple repositories. [Intel Neural Compressor](https://github.com/intel/neural-compressor) @@ -318,9 +321,6 @@ AutoRound has been integrated into multiple repositories. [pytorch/ao](https://github.com/pytorch/ao) - - - ## Reference If you find AutoRound useful for your research, please cite our paper: diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 3de56571..8f58a65d 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -56,8 +56,8 @@ def setup_parser(): help="The device to be used for tuning. The default is set to auto/None," "allowing for automatic detection. Currently, device settings support CPU, GPU, and HPU.") - parser.add_argument("--sym", action='store_true', - help=" sym quantization") + parser.add_argument("--asym", action='store_true', + help=" asym quantization") parser.add_argument("--iters", default=200, type=int, help=" iters") @@ -92,7 +92,7 @@ def setup_parser(): parser.add_argument("--format", default=None, type=str, help="The format in which to save the model. " - "The options are 'auto_round', 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'." + "The options are 'auto_round', 'auto_round:marlin', 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'." "default to 'auto_round." ) @@ -161,6 +161,12 @@ def tune(args): tasks = args.tasks if args.format is None: args.format = "auto_round" + if "auto_gptq" in args.format and args.asym is True: + print( + "warning: The auto_gptq kernel has issues with asymmetric quantization. It is recommended to use sym quantization or --format='auto_round'") + + if "marlin" in args.format and args.asym is True: + assert False, "marlin backend only supports sym quantization, please remove --asym" model_name = args.model if model_name[-1] == "/": @@ -284,7 +290,7 @@ def tune(args): raise EnvironmentError(error_message) autoround = round( - model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.batch_size, + model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.batch_size, dataset=args.dataset, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, device=device_str, amp=not args.disable_amp, nsamples=args.nsamples, seed=args.seed, diff --git a/auto_round/autoround.py b/auto_round/autoround.py index ecd0ed0c..c790f044 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -67,7 +67,7 @@ class AutoRound(object): tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied. bits (int): Number of bits for quantization (default is 4). group_size (int): Size of the quantization group (default is 128). - sym (bool): Whether symmetric quantization is to be used (default is False). + sym (bool): Whether symmetric quantization is to be used (default is True). layer_config (dict): Configuration for weight quantization (default is None). 
layer_config={ 'layer1':##layer_name @@ -75,7 +75,7 @@ class AutoRound(object): 'data_type': 'int', 'bits': 4, 'group_size': 128, - 'sym': False + 'sym': True 'act_data_type': None, 'act_bits': 32, 'act_group_size': None, @@ -84,7 +84,6 @@ class AutoRound(object): } ... } - enable_full_range (bool): Whether to enable full range quantization (default is False). batch_size (int): Batch size for training (default is 8). amp (bool): Whether to use automatic mixed precision (default is True). device: The device to be used for tuning (default is "auto"). @@ -126,7 +125,6 @@ def __init__( group_size: int = 128, sym: bool = False, layer_config: dict = None, - enable_full_range: bool = False, ##for symmetric, TODO support later batch_size: int = 8, amp: bool = True, device: str = None, @@ -207,7 +205,6 @@ def __init__( self.gradient_accumulate_steps = gradient_accumulate_steps self.not_use_best_mse = not_use_best_mse self.dynamic_max_gap = dynamic_max_gap - self.enable_full_range = enable_full_range self.lr_scheduler = lr_scheduler self.optimizer = self.get_optimizer(None) self.share_attention_mask_flag = None @@ -243,7 +240,6 @@ def check_configs(self): assert self.seqlen > 0, "seqlen must be positive" assert self.nblocks > 0, "nblocks must be positive" assert self.gradient_accumulate_steps > 0, "gradient accumulate step must be positive" - assert self.enable_full_range is False, "only support enable_full_range=False currently" assert self.act_dynamic is True, "only support dynamic quantization for activation currently" # assert self.tokenizer != None or self.dataloader != None if self.act_bits <= 8: @@ -1345,9 +1341,8 @@ class AutoOPTRound(AutoRound): tokenizer: An optional tokenizer for processing input data. bits (int): Number of bits for quantization (default is 4). group_size (int): Size of the quantization group (default is 128). - sym (bool): Whether sym to be used (default is False). + sym (bool): Whether sym to be used (default is True). layer_config (dict): Configuration for weight quantization (default is None). - enable_full_range (bool): Whether to enable full range quantization (default is False). batch_size (int): Batch size for training (default is 8). amp (bool): Whether to use automatic mixed precision (default is True). device: The device to be used for training (default is "auto"). @@ -1388,9 +1383,8 @@ def __init__( tokenizer=None, bits: int = 4, group_size: int = 128, - sym: bool = False, + sym: bool = True, layer_config=None, - enable_full_range: bool = False, batch_size: int = 8, amp: bool = True, device=None, @@ -1429,7 +1423,6 @@ def __init__( group_size=group_size, sym=sym, layer_config=layer_config, - enable_full_range=enable_full_range, batch_size=batch_size, amp=amp, device=device, @@ -1514,9 +1507,8 @@ class AutoAdamRound(AutoOPTRound): tokenizer: An optional tokenizer for processing input data. bits (int): Number of bits for quantization (default is 4). group_size (int): Size of the quantization group (default is 128). - sym (str): Whether symmetric quantization to be used (default is False). + sym (str): Whether symmetric quantization to be used (default is True). layer_config (dict): Configuration for weight quantization (default is None). - enable_full_range (bool): Whether to enable full range quantization (default is False). batch_size (int): Batch size for training (default is 8). amp (bool): Whether to use automatic mixed precision (default is True). device: The device to be used for training (default is "auto"). 
@@ -1557,9 +1549,8 @@ def __init__( tokenizer=None, bits: int = 4, group_size: int = 128, - sym: bool = False, + sym: bool = True, layer_config=None, - enable_full_range: bool = False, batch_size: int = 8, amp: bool = True, device=None, @@ -1598,7 +1589,6 @@ def __init__( group_size=group_size, sym=sym, layer_config=layer_config, - enable_full_range=enable_full_range, batch_size=batch_size, amp=amp, device=device, diff --git a/examples/language-modeling/main.py b/examples/language-modeling/main.py index e40ed069..b85c80f2 100644 --- a/examples/language-modeling/main.py +++ b/examples/language-modeling/main.py @@ -50,8 +50,8 @@ help="The device to be used for tuning. The default is set to auto/None," "allowing for automatic detection. Currently, device settings support CPU, GPU, and HPU.") - parser.add_argument("--sym", action='store_true', - help=" sym quantization") + parser.add_argument("--asym", action='store_true', + help=" asym quantization") parser.add_argument("--iters", default=200, type=int, help=" iters") @@ -184,12 +184,12 @@ if args.deployment_device: warnings.warn("The deployment_device is deprecated and will be removed in future version." "Please use format instead", DeprecationWarning) - if "gpu" in args.deployment_device and args.sym is False: + if "gpu" in args.deployment_device and args.asym is True: print( - "warning: The auto_gptq kernel has issues with asymmetric quantization. It is recommended to use --format='auto_round'") + "warning: The auto_gptq kernel has issues with asymmetric quantization. It is recommended to use sym quantization or --format='auto_round'") - if "marlin" in args.deployment_device and args.sym is False: - assert False, "marlin backend only supports sym quantization, please set --sym" + if "marlin" in args.deployment_device and args.asym is True: + assert False, "marlin backend only supports sym quantization, please remove --asym" model_name = args.model_name if model_name[-1] == "/": @@ -323,7 +323,7 @@ error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization." raise EnvironmentError(error_message) - autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs, + autoround = round(model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.train_bs, dataset=args.dataset, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, device=device_str, amp=not args.disable_amp, nsamples=args.nsamples, From 02c63deb445ecd82d98b0ae1575ac91a353de062 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Sat, 12 Oct 2024 13:16:49 +0800 Subject: [PATCH 03/10] fix --- auto_round/data_type/int.py | 44 +------------------------------------ 1 file changed, 1 insertion(+), 43 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 53c574c9..2ebd5ced 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -102,7 +102,7 @@ def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_d @register_dtype("int_sym_gptq") -def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, +def quant_tensor_sym_gptq(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, weight_max=None, q_scale_thresh=0.0, **kwargs): """Quantizes and dequantizes weight symmetrically. 
@@ -149,48 +149,6 @@ def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dt return qdq_result, scale, zp -# @register_dtype("int_sym") -# def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, -# weight_max=None, q_scale_thresh=0.0, **kwargs): -# """Quantizes and dequantizes weight symmetrically. -# -# Args: -# weight: Tensor containing the weight to be quantized -# bits: Number of bits for quantization (e.g., 2, 3, 4, 8) -# v: Rounding value perturbation -# min_scale: Minimum scale coefficient for weight -# max_scale: Maximum scale coefficient for weight -# weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None. -# weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None. -# -# Returns: -# Quantized and dequantized weight, scale, zero-point -# """ -# maxq = torch.tensor(2 ** (bits - 1)) ##different -# if weight_min is None or weight_max is None: -# wmin_tmp = torch.clamp(weight.min(-1)[0], max=0) -# wmax_tmp = torch.clamp(weight.max(-1)[0], min=0) -# else: -# wmin_tmp = weight_min -# wmax_tmp = weight_max -# if isinstance(min_scale, torch.Tensor): -# wmin = wmin_tmp * min_scale -# wmax = wmax_tmp * max_scale -# else: -# wmin = wmin_tmp -# wmax = wmax_tmp -# max_v = (2 * (torch.abs(wmax) < torch.abs(wmin)).int() - 1) * torch.max(torch.abs(wmax), torch.abs(wmin)) -# -# scale = (max_v / maxq).to(scale_dtype) -# scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) -# zp = torch.full_like(scale, maxq) # pylint: disable=E1130 -# scale = scale.unsqueeze(dim=-1) -# zp = zp.unsqueeze(dim=-1) -# int_w = round_ste(weight / scale + v) -# q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) -# qdq_result = (scale * (q - zp)).to(weight.dtype) -# return qdq_result, scale, zp - def quant_tensor_asym_wo_round(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, weight_max=None, q_scale_thresh=0.0, **kwargs): From be31d341ebf1f1b93e683752459b6ae298865d06 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Sat, 12 Oct 2024 14:09:54 +0800 Subject: [PATCH 04/10] update kernel side --- README.md | 15 ++++++++------- auto_round/autoround.py | 3 +-- auto_round/export/export_to_autoround/export.py | 4 ++++ requirements.txt | 3 ++- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index bff6695a..3c0ed7d2 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,9 @@ more accuracy data and recipes across various models. ## What's New +* [2024/10] Important update: We now support full-range symmetric quantization and have made it the default + configuration. This approach is typically better or comparable to asymmetric quantization and significantly + outperforms other symmetric variants, especially at low bit-widths like 2-bit. * [2024/09] AutoRound format supports several LVM models, check out the examples [Qwen2-Vl](./examples/multimodal-modeling/Qwen-VL),[Phi-3-vision](./examples/multimodal-modeling/Phi-3-vision), [Llava](./examples/multimodal-modeling/Llava) * [2024/08] AutoRound format supports Intel Gaudi2 devices. Please refer @@ -177,9 +180,9 @@ We provide two recipes for best accuracy and fast running speed with low memory. **AutoRound Format**:This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision inference. [2,4] -bits are supported. 
It
-resolves the asymmetric quantization kernel issues found in the AutoGPTQ format and supports both LM-head quantization
-and mixed precision. However, it has not yet gained widespread community adoption. For CUDA support, you will need to
+bits are supported. It also benefits
+from the Marlin kernel, which can boost inference performance notably. However, it has not yet gained widespread
+community adoption. For CUDA support, you will need to
 install from the source.
 
 **AutoGPTQ Format**: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the
@@ -190,8 +193,7 @@ models. Additionally, symmetric quantization tends to perform poorly at 2-bit
 precision.
 
 **AutoAWQ Format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted
-within the community, only 4-bits quantization is supported. Asymmetric quantization typically improves
-accuracy but may reduce inference speed. It features
+within the community, only 4-bits quantization is supported. It features
 specialized layer fusion tailored for Llama models.
 
 ## Model Inference
@@ -219,8 +221,7 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
 **HPU**: docker image with Gaudi Software Stack is recommended. More details can be found in
 [Gaudi Guide](https://docs.habana.ai/en/latest/).
 
-**CUDA**: git clone https://github.com/intel/auto-round.git && cd auto-round && pip install --no-build-isolation
--e .
+**CUDA**: pip install auto-gptq for sym quantization, for asym quantization, need to install auto-round from source
 
 #### CPU/HPU/CUDA on 0.3.0+
 
diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index c790f044..3bde5b6f 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -58,8 +58,7 @@
 
 
 class AutoRound(object):
-    """This is Signround+ which is an advanced version of Signround. For more information,
-    please refer to Cheng, Wenhua, et al. "Optimize weight rounding via signed gradient descent
+    """For more information, please refer to Cheng, Wenhua, et al. "Optimize weight rounding via signed gradient descent
     for the quantization of llms." arXiv preprint arXiv:2309.05516 (2023).
Args: diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index e46b4800..1e50edb2 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -180,6 +180,10 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex backend = "auto_round:exllamav2" backend = backend.replace("autoround", "auto_round") backend = backend.replace("auto-round", "auto_round") + ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source + if kwargs.get("sym") is None or kwargs["sym"] == True and "gptq" not in backend: + backend = backend.replace('auto_round','auto_round:gptq') + if not ("triton" in backend or "exllamav2" in backend or "awq" in backend or "gptq" in backend): logger.info(f"auto_round format does not support {backend}, try to pack each layer with autogptq") backend = backend.replace("auto_round", "auto_gptq") diff --git a/requirements.txt b/requirements.txt index 0cc1327b..5755cf0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ numpy < 2.0 threadpoolctl lm-eval==0.4.4 tqdm -packaging \ No newline at end of file +packaging +auto-gptq>=0.7.1 \ No newline at end of file From 18136fb3d56e734ebed5f68500a995f8fb0c9ad5 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Sat, 12 Oct 2024 14:43:06 +0800 Subject: [PATCH 05/10] add accuracy data --- README.md | 8 ++++++-- auto_round/__main__.py | 6 ++++-- auto_round/utils.py | 8 ++++---- docs/full_range_sym.md | 15 +++++++++++++++ 4 files changed, 29 insertions(+), 8 deletions(-) create mode 100644 docs/full_range_sym.md diff --git a/README.md b/README.md index 3c0ed7d2..8c2b3ae9 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,9 @@ more accuracy data and recipes across various models. * [2024/10] Important update: We now support full-range symmetric quantization and have made it the default configuration. This approach is typically better or comparable to asymmetric quantization and significantly - outperforms other symmetric variants, especially at low bit-widths like 2-bit. + outperforms other symmetric variants, especially at low bit-widths like 2-bit. No need to compile from source to run + AutoRound format anymore. + * [2024/09] AutoRound format supports several LVM models, check out the examples [Qwen2-Vl](./examples/multimodal-modeling/Qwen-VL),[Phi-3-vision](./examples/multimodal-modeling/Phi-3-vision), [Llava](./examples/multimodal-modeling/Llava) * [2024/08] AutoRound format supports Intel Gaudi2 devices. Please refer @@ -221,7 +223,7 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) **HPU**: docker image with Gaudi Software Stack is recommended. More details can be found in [Gaudi Guide](https://docs.habana.ai/en/latest/). 
-**CUDA**: pip install auto-gptq for sym quantization, for asym quantization, need to install auto-round from source +**CUDA**: pip install auto-gptq for sym quantization(tuning needs auto-round 0.30+), for asym quantization, need to install auto-round from source #### CPU/HPU/CUDA on 0.3.0+ @@ -244,6 +246,8 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) #### CPU/HPU/CUDA on 0.3.0 +**CUDA**: need to install auto-round from source + ```python from transformers import AutoModelForCausalLM, AutoTokenizer from auto_round.auto_quantizer import AutoHfQuantizer ## must import diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 8f58a65d..8f8fcf41 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -92,7 +92,8 @@ def setup_parser(): parser.add_argument("--format", default=None, type=str, help="The format in which to save the model. " - "The options are 'auto_round', 'auto_round:marlin', 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'." + "The options are 'auto_round', 'auto_round:gptq','auto_round:marlin'," + " 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'." "default to 'auto_round." ) @@ -163,7 +164,8 @@ def tune(args): args.format = "auto_round" if "auto_gptq" in args.format and args.asym is True: print( - "warning: The auto_gptq kernel has issues with asymmetric quantization. It is recommended to use sym quantization or --format='auto_round'") + "warning: The auto_gptq kernel has issues with asymmetric quantization. " + "It is recommended to use sym quantization or --format='auto_round'") if "marlin" in args.format and args.asym is True: assert False, "marlin backend only supports sym quantization, please remove --asym" diff --git a/auto_round/utils.py b/auto_round/utils.py index 05593c02..2f1c0891 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -808,8 +808,8 @@ def get_autogptq_infer_linear(backend, bits=4, group_size=128, sym=False): disable_exllama=disable_exllamav1, disable_exllamav2=disable_exllamav2, use_qigen=use_qigen, - use_marlin=not disable_marlin, - use_tritonv2=use_tritonv2 + use_marlin=not disable_marlin, # pylint: disable=E1123 + use_tritonv2=use_tritonv2 # pylint: disable=E1123 ) return QuantLinear @@ -967,8 +967,8 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False): bits=bits, disable_exllama=disable_exllamav1, disable_exllamav2=disable_exllamav2, - use_qigen=use_qigen, - use_marlin=not disable_marlin, + use_qigen=use_qigen, # pylint: disable=E1123 + use_marlin=not disable_marlin, # pylint: disable=E1123 ) return QuantLinear diff --git a/docs/full_range_sym.md b/docs/full_range_sym.md new file mode 100644 index 00000000..2165bb24 --- /dev/null +++ b/docs/full_range_sym.md @@ -0,0 +1,15 @@ +W2G32 nsamples 512,iter 200, average accuracy of 10 tasks + +| Models | gptq_sym | asym | full_range_sym | +|----------------------------|----------|------------|----------------| +| Meta-Llama-3.1-8B-Instruct | 0.4500 | 0.52802 | **0.5381** | +| Qwen2-7B | 0.5229 | **0.5559** | 0.5486 | + +W4G128 nsamples 128,iter 200, average accuracy of 10 tasks + +| Models | asym | full_range_sym | +|----------------------------|------------|----------------| +| Meta-Llama-3.1-8B-Instruct | 0.6342 | **0.6370** | +| Qwen2-7B | 0.6143 | **0.6167** | +| Mistral-7B-Instruct-v0.2 | 0.6606 | **0.6635** | +| Phi-3-mini-4k-instruct | **0.6475** | 0.6432 | From 53ce9f3af4ce74d7aadad045cb85cfa4362264b1 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Sat, 12 Oct 2024 14:51:09 +0800 Subject: [PATCH 
06/10] fix one issue --- auto_round/export/export_to_autoround/export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 1e50edb2..0ab2f505 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -181,11 +181,11 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex backend = backend.replace("autoround", "auto_round") backend = backend.replace("auto-round", "auto_round") ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source - if kwargs.get("sym") is None or kwargs["sym"] == True and "gptq" not in backend: + if (kwargs.get("sym") is None or kwargs.get("sym") == True) and ("gptq" not in backend and "awq" not in backend): backend = backend.replace('auto_round','auto_round:gptq') if not ("triton" in backend or "exllamav2" in backend or "awq" in backend or "gptq" in backend): - logger.info(f"auto_round format does not support {backend}, try to pack each layer with autogptq") + logger.info(f"AutoRound format does not support {backend}, try to pack each layer with AutoGPTQ") backend = backend.replace("auto_round", "auto_gptq") model = kwargs["model"] From f0c330d49c754d32e9627839bdb80f5939c1dc2c Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Sat, 12 Oct 2024 15:01:17 +0800 Subject: [PATCH 07/10] try fo fix pylint issue --- auto_round/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/auto_round/utils.py b/auto_round/utils.py index 2f1c0891..10beff71 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -800,7 +800,7 @@ def get_autogptq_infer_linear(backend, bits=4, group_size=128, sym=False): disable_marlin=disable_marlin ) else: - QuantLinear = dynamically_import_QuantLinear( + QuantLinear = dynamically_import_QuantLinear( # pylint: disable=E1123 use_triton=use_triton, desc_act=False, group_size=group_size, @@ -808,8 +808,8 @@ def get_autogptq_infer_linear(backend, bits=4, group_size=128, sym=False): disable_exllama=disable_exllamav1, disable_exllamav2=disable_exllamav2, use_qigen=use_qigen, - use_marlin=not disable_marlin, # pylint: disable=E1123 - use_tritonv2=use_tritonv2 # pylint: disable=E1123 + use_marlin=not disable_marlin, + use_tritonv2=use_tritonv2 ) return QuantLinear @@ -960,15 +960,15 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False): disable_marlin=disable_marlin, ) else: - QuantLinear = dynamically_import_QuantLinear( + QuantLinear = dynamically_import_QuantLinear(# pylint: disable=E1123 use_triton=use_triton, desc_act=False, group_size=group_size, bits=bits, disable_exllama=disable_exllamav1, disable_exllamav2=disable_exllamav2, - use_qigen=use_qigen, # pylint: disable=E1123 - use_marlin=not disable_marlin, # pylint: disable=E1123 + use_qigen=use_qigen, + use_marlin=not disable_marlin, ) return QuantLinear From ac63b5f08ffc561be4fc2069e6b54d3a34693b51 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Sat, 12 Oct 2024 15:27:27 +0800 Subject: [PATCH 08/10] try to fix unit test --- test/test_autoround_acc.py | 2 +- test/test_block_names.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_autoround_acc.py b/test/test_autoround_acc.py index 53f64ae5..fe33c482 100644 --- a/test/test_autoround_acc.py +++ b/test/test_autoround_acc.py @@ -67,7 +67,7 @@ def test_default_acc(self): out1 = model_tmp(inp) assert out0[0].equal(out1[0]) - 
self.assertTrue(isclose(float(out0[0][0][0][0]), -0.02076812833547592, rel_tol=1e-04))
+        self.assertTrue(isclose(float(out0[0][0][0][0]), -0.02076812833547592, rel_tol=1e-03))
 
 
 if __name__ == "__main__":
diff --git a/test/test_block_names.py b/test/test_block_names.py
index bee555a2..6e101f4a 100644
--- a/test/test_block_names.py
+++ b/test/test_block_names.py
@@ -168,6 +168,8 @@ def test_block_name_quant(self):
             import auto_gptq
         except:
             return
+        if not torch.cuda.is_available():
+            return
         quantized_model_path = "./saved"
         autoround.save_quantized(quantized_model_path, inplace=False, safe_serialization=False, format="auto_round")
 

From 55f52a93f03a21ee37c4f2916cabbafbc273e8ac Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Sat, 12 Oct 2024 17:40:03 +0800
Subject: [PATCH 09/10] fix ut as sym logic has been changed

---
 README.md                  | 3 +--
 test/test_autoround_acc.py | 3 ++-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 8c2b3ae9..31ace144 100644
--- a/README.md
+++ b/README.md
@@ -29,9 +29,8 @@ more accuracy data and recipes across various models.
 
 * [2024/10] Important update: We now support full-range symmetric quantization and have made it the default
   configuration. This approach is typically better or comparable to asymmetric quantization and significantly
-  outperforms other symmetric variants, especially at low bit-widths like 2-bit. No need to compile from source to run
+  outperforms other symmetric variants, especially at low bit-widths like 2-bit. And, no need to compile from source to run
   AutoRound format anymore.
-
 * [2024/09] AutoRound format supports several LVM models, check out the
   examples [Qwen2-Vl](./examples/multimodal-modeling/Qwen-VL),[Phi-3-vision](./examples/multimodal-modeling/Phi-3-vision), [Llava](./examples/multimodal-modeling/Llava)
 * [2024/08] AutoRound format supports Intel Gaudi2 devices. Please refer
diff --git a/test/test_autoround_acc.py b/test/test_autoround_acc.py
index fe33c482..ba4ed7a9 100644
--- a/test/test_autoround_acc.py
+++ b/test/test_autoround_acc.py
@@ -67,7 +67,8 @@ def test_default_acc(self):
         out1 = model_tmp(inp)
 
         assert out0[0].equal(out1[0])
-        self.assertTrue(isclose(float(out0[0][0][0][0]), -0.02076812833547592, rel_tol=1e-03))
+        print(float(out0[0][0][0][0]))
+        self.assertTrue(isclose(float(out0[0][0][0][0]), -0.021002087742090225, rel_tol=1e-04))
 
 
 if __name__ == "__main__":

From 93f5951f6161686cd288155830b502dcb6533171 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Sat, 12 Oct 2024 17:47:06 +0800
Subject: [PATCH 10/10] remove print

---
 test/test_autoround_acc.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/test_autoround_acc.py b/test/test_autoround_acc.py
index ba4ed7a9..9e741bd7 100644
--- a/test/test_autoround_acc.py
+++ b/test/test_autoround_acc.py
@@ -67,7 +67,6 @@ def test_default_acc(self):
         out1 = model_tmp(inp)
 
         assert out0[0].equal(out1[0])
-        print(float(out0[0][0][0][0]))
         self.assertTrue(isclose(float(out0[0][0][0][0]), -0.021002087742090225, rel_tol=1e-04))
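
Note: as a reading aid for the series above, the snippet below is a minimal, standalone sketch of the full-range symmetric scheme that PATCH 01 registers as `int_sym`. It mirrors the arithmetic of that function but drops the AutoRound-specific pieces (`round_ste`, the tunable `v`/`min_scale`/`max_scale` terms, and the `scale_dtype` cast), so the helper name `fullrange_sym_qdq` and the demo values are illustrative assumptions, not part of the library API.

```python
import torch


def fullrange_sym_qdq(weight: torch.Tensor, bits: int = 4, q_scale_thresh: float = 1e-5):
    """Quantize/dequantize each row of `weight` with full-range symmetric integer quantization."""
    maxq = 2 ** (bits - 1)  # e.g. 8 for 4-bit: unsigned grid is [0, 15] with a fixed zero point of 8
    wmin = torch.clamp(weight.min(dim=-1).values, max=0)
    wmax = torch.clamp(weight.max(dim=-1).values, min=0)
    # Signed dominant extreme: +max(|wmax|, |wmin|) when |wmin| wins, otherwise -max(|wmax|, |wmin|).
    # Keeping the sign lets the largest-magnitude value land exactly on the end of the integer grid.
    max_v = (2 * (wmax.abs() < wmin.abs()).int() - 1) * torch.maximum(wmax.abs(), wmin.abs())
    scale = max_v / maxq
    # Keep |scale| away from zero while preserving its sign, as the patch does with q_scale_thresh.
    scale = torch.where(scale < 0,
                        torch.clamp(scale, max=-q_scale_thresh),
                        torch.clamp(scale, min=q_scale_thresh)).unsqueeze(-1)
    zp = maxq  # fixed zero point in the middle of the unsigned grid
    q = torch.clamp(torch.round(weight / scale) + zp, 0, 2 ** bits - 1)
    return scale * (q - zp), scale, q


if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.randn(4, 128)  # four rows, one quantization group per row
    qdq, scale, q = fullrange_sym_qdq(w, bits=4)
    print("max abs reconstruction error:", (w - qdq).abs().max().item())
```

The defining detail is the signed scale: whichever extreme has the larger magnitude maps exactly onto the end of the integer range, which is the property the accuracy comparison added in `docs/full_range_sym.md` is probing, most visibly at 2-bit.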