[Important Change] Set full-range sym as the default #278

Merged
merged 10 commits on Oct 12, 2024
Changes from 8 commits
43 changes: 24 additions & 19 deletions README.md
@@ -26,7 +26,14 @@ more accuracy data and recipes across various models.
<div align="left">

## What's New
* [2024/09] AutoRound format supports several LVM models, check out the examples [Qwen2-Vl](./examples/multimodal-modeling/Qwen-VL),[Phi-3-vision](./examples/multimodal-modeling/Phi-3-vision), [Llava](./examples/multimodal-modeling/Llava)

* [2024/10] Important update: We now support full-range symmetric quantization and have made it the default
configuration. This approach is typically better than, or comparable to, asymmetric quantization and significantly
outperforms other symmetric variants, especially at low bit-widths such as 2-bit. Compiling from source is no longer
required to run the AutoRound format.

* [2024/09] AutoRound format supports several LVM models, check out the
examples [Qwen2-Vl](./examples/multimodal-modeling/Qwen-VL),[Phi-3-vision](./examples/multimodal-modeling/Phi-3-vision), [Llava](./examples/multimodal-modeling/Llava)
* [2024/08] AutoRound format supports Intel Gaudi2 devices. Please refer
to [Intel/Qwen2-7B-int4-inc](https://huggingface.co/Intel/Qwen2-7B-int4-inc).
* [2024/08] AutoRound introduces several experimental features, including fast tuning of norm/bias parameters (for 2-bit
@@ -61,14 +68,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)

from auto_round import AutoRound

bits, group_size, sym = 4, 128, False
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym)
bits, group_size = 4, 128
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size)

## best accuracy, 3X slower, low_gpu_mem_usage could save ~20G but ~30% slower
# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size, sym=sym)
## the best accuracy, 3X slower, low_gpu_mem_usage could save ~20G but ~30% slower
# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size)

## fast and low memory, 2-3X speedup, slight accuracy drop at W4G128
# autoround = AutoRound(model, tokenizer, nsamples=128, iters=200, seqlen=512, batch_size=4, bits=bits, group_size=group_size, sym=sym)
# autoround = AutoRound(model, tokenizer, nsamples=128, iters=200, seqlen=512, batch_size=4, bits=bits, group_size=group_size)

autoround.quantize()
output_dir = "./tmp_autoround"
@@ -87,7 +94,7 @@ autoround.save_quantized(output_dir, format='auto_round', inplace=True)

- `group_size (int)`: Size of the quantization group (default is 128).

- `sym (bool)`: Whether to use symmetric quantization (default is False).
- `sym (bool)`: Whether to use symmetric quantization (default is True).

- `enable_quanted_input (bool)`: Whether to use the output of the previous quantized block as the input for the current
block for tuning (default is True).
@@ -173,10 +180,11 @@ We provide two recipes for best accuracy and fast running speed with low memory.

#### Formats

**AutoRound Format**: This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision inference. [2,4]
bits are supported. It
resolves the asymmetric quantization kernel issues found in the AutoGPTQ format and supports both LM-head quantization
and mixed precision. However, it has not yet gained widespread community adoption. For CUDA support, you will need to
**AutoRound Format**: This format is well-suited for CPU and HPU devices, 2-bit quantization, and mixed-precision
inference. [2,4] bits are supported. It also benefits from the Marlin kernel, which can notably boost inference
performance. However, it has not yet gained widespread community adoption. For CUDA support, you will need to
install from the source.

**AutoGPTQ Format**: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the
@@ -187,8 +195,7 @@ models.
Additionally, symmetric quantization tends to perform poorly at 2-bit precision.

**AutoAWQ Format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted
within the community, only 4-bits quantization is supported. Asymmetric quantization typically improves
accuracy but may reduce inference speed. It features
within the community; only 4-bit quantization is supported. It features
specialized layer fusion tailored for Llama models.
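
The snippet below is a minimal sketch of how these formats are selected at export time. It reuses the `autoround` object and the `save_quantized` call from the quick-start above; the commented alternatives are assumptions about which format strings fit your setup, not a verbatim recipe.

```python
## sketch only: `autoround` is the tuned object from the quick-start above
output_dir = "./tmp_autoround"
autoround.save_quantized(output_dir, format='auto_round', inplace=True)   ## CPU/HPU, 2-bit and mixed precision
# autoround.save_quantized(output_dir, format='auto_gptq', inplace=True)  ## CUDA, widest community support
# autoround.save_quantized(output_dir, format='auto_awq', inplace=True)   ## CUDA, 4-bit only, Llama layer fusion
```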

## Model Inference
@@ -216,8 +223,7 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
**HPU**: docker image with Gaudi Software Stack is recommended. More details can be found
in [Gaudi Guide](https://docs.habana.ai/en/latest/).

**CUDA**: git clone https://github.com/intel/auto-round.git && cd auto-round && pip install --no-build-isolation
-e .
**CUDA**: pip install auto-gptq for sym quantization (tuning requires auto-round 0.3.0+); for asym quantization, you need to install auto-round from source

#### CPU/HPU/CUDA on 0.3.0+

@@ -240,6 +246,8 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

#### CPU/HPU/CUDA on 0.3.0

**CUDA**: auto-round needs to be installed from source

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round.auto_quantizer import AutoHfQuantizer ## must import
@@ -308,8 +316,8 @@ release most of the models ourselves.
| bigscience/bloom-3b | [accuracy](./docs/bloom-3B-acc.md), [recipe](./examples/language-modeling/scripts/bloom-3b.sh), [example](./examples/language-modeling/) |
| EleutherAI/gpt-j-6b | [accuracy](./docs/gpt-j-6B-acc.md), [recipe](./examples/language-modeling/scripts/gpt-j-6b.sh), [example](./examples/language-modeling/) |


## Integration

AutoRound has been integrated into multiple repositories.

[Intel Neural Compressor](https://github.com/intel/neural-compressor)
@@ -318,9 +326,6 @@ AutoRound has been integrated into multiple repositories.

[pytorch/ao](https://github.com/pytorch/ao)




## Reference

If you find AutoRound useful for your research, please cite our paper:
16 changes: 12 additions & 4 deletions auto_round/__main__.py
@@ -56,8 +56,8 @@ def setup_parser():
help="The device to be used for tuning. The default is set to auto/None,"
"allowing for automatic detection. Currently, device settings support CPU, GPU, and HPU.")

parser.add_argument("--sym", action='store_true',
help=" sym quantization")
parser.add_argument("--asym", action='store_true',
help=" asym quantization")

parser.add_argument("--iters", default=200, type=int,
help=" iters")
@@ -92,7 +92,8 @@ def setup_parser():

parser.add_argument("--format", default=None, type=str,
help="The format in which to save the model. "
"The options are 'auto_round', 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'."
"The options are 'auto_round', 'auto_round:gptq','auto_round:marlin',"
" 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'."
"default to 'auto_round."
)

@@ -161,6 +162,13 @@ def tune(args):
tasks = args.tasks
if args.format is None:
args.format = "auto_round"
if "auto_gptq" in args.format and args.asym is True:
print(
"warning: The auto_gptq kernel has issues with asymmetric quantization. "
"It is recommended to use sym quantization or --format='auto_round'")

if "marlin" in args.format and args.asym is True:
assert False, "marlin backend only supports sym quantization, please remove --asym"

model_name = args.model
if model_name[-1] == "/":
@@ -284,7 +292,7 @@ def tune(args):
raise EnvironmentError(error_message)

autoround = round(
model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.batch_size,
model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.batch_size,
dataset=args.dataset, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr,
minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input,
device=device_str, amp=not args.disable_amp, nsamples=args.nsamples, seed=args.seed,
25 changes: 7 additions & 18 deletions auto_round/autoround.py
@@ -58,24 +58,23 @@


class AutoRound(object):
"""This is Signround+ which is an advanced version of Signround. For more information,
please refer to Cheng, Wenhua, et al. "Optimize weight rounding via signed gradient descent
"""For more information, please refer to Cheng, Wenhua, et al. "Optimize weight rounding via signed gradient descent
for the quantization of llms." arXiv preprint arXiv:2309.05516 (2023).

Args:
model: The PyTorch model to be quantized.
tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied.
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether symmetric quantization is to be used (default is False).
sym (bool): Whether symmetric quantization is to be used (default is True).
layer_config (dict): Configuration for weight quantization (default is None).
layer_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 128,
'sym': False
'sym': True
'act_data_type': None,
'act_bits': 32,
'act_group_size': None,
@@ -84,7 +83,6 @@ class AutoRound(object):
}
...
}
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
device: The device to be used for tuning (default is "auto").
@@ -126,7 +124,6 @@ def __init__(
group_size: int = 128,
sym: bool = False,
layer_config: dict = None,
enable_full_range: bool = False, ##for symmetric, TODO support later
batch_size: int = 8,
amp: bool = True,
device: str = None,
@@ -207,7 +204,6 @@ def __init__(
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.enable_full_range = enable_full_range
self.lr_scheduler = lr_scheduler
self.optimizer = self.get_optimizer(None)
self.share_attention_mask_flag = None
@@ -243,7 +239,6 @@ def check_configs(self):
assert self.seqlen > 0, "seqlen must be positive"
assert self.nblocks > 0, "nblocks must be positive"
assert self.gradient_accumulate_steps > 0, "gradient accumulate step must be positive"
assert self.enable_full_range is False, "only support enable_full_range=False currently"
assert self.act_dynamic is True, "only support dynamic quantization for activation currently"
# assert self.tokenizer != None or self.dataloader != None
if self.act_bits <= 8:
@@ -1345,9 +1340,8 @@ class AutoOPTRound(AutoRound):
tokenizer: An optional tokenizer for processing input data.
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether sym to be used (default is False).
sym (bool): Whether sym to be used (default is True).
layer_config (dict): Configuration for weight quantization (default is None).
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
device: The device to be used for training (default is "auto").
@@ -1388,9 +1382,8 @@ def __init__(
tokenizer=None,
bits: int = 4,
group_size: int = 128,
sym: bool = False,
sym: bool = True,
layer_config=None,
enable_full_range: bool = False,
batch_size: int = 8,
amp: bool = True,
device=None,
@@ -1429,7 +1422,6 @@ def __init__(
group_size=group_size,
sym=sym,
layer_config=layer_config,
enable_full_range=enable_full_range,
batch_size=batch_size,
amp=amp,
device=device,
@@ -1514,9 +1506,8 @@ class AutoAdamRound(AutoOPTRound):
tokenizer: An optional tokenizer for processing input data.
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (str): Whether symmetric quantization to be used (default is False).
sym (str): Whether symmetric quantization to be used (default is True).
layer_config (dict): Configuration for weight quantization (default is None).
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
device: The device to be used for training (default is "auto").
@@ -1557,9 +1548,8 @@ def __init__(
tokenizer=None,
bits: int = 4,
group_size: int = 128,
sym: bool = False,
sym: bool = True,
layer_config=None,
enable_full_range: bool = False,
batch_size: int = 8,
amp: bool = True,
device=None,
@@ -1598,7 +1588,6 @@ def __init__(
group_size=group_size,
sym=sym,
layer_config=layer_config,
enable_full_range=enable_full_range,
batch_size=batch_size,
amp=amp,
device=device,
50 changes: 47 additions & 3 deletions auto_round/data_type/int.py
@@ -17,6 +17,49 @@
from auto_round.data_type.register import register_dtype, QUANT_FUNC_WITH_DTYPE


@register_dtype("int_sym")
def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None,
weight_max=None, q_scale_thresh=0.0, **kwargs):
"""Quantizes and dequantizes weight asymmetrically. full range, credict goes to llamacpp community

Args:
weight: Tensor containing the weight to be quantized
bits: Number of bits for quantization (e.g., 2, 3, 4, 8)
v: Rounding value perturbation
min_scale: Minimum scale coefficient for weight
max_scale: Maximum scale coefficient for weight
weight_min (Tensor, optional): Minimum weight value for quantization. Defaults to None.
weight_max (Tensor, optional): Maximum weight value for quantization. Defaults to None.

Returns:
Quantized and dequantized weight, scale, zero-point
"""
maxq = torch.tensor(2 ** (bits - 1))
if weight_min is None or weight_max is None:
wmin_tmp = torch.clamp(weight.min(-1)[0], max=0)
wmax_tmp = torch.clamp(weight.max(-1)[0], min=0)
else:
wmin_tmp = weight_min
wmax_tmp = weight_max
if isinstance(min_scale, torch.Tensor):
wmin = wmin_tmp * min_scale
wmax = wmax_tmp * max_scale
else:
wmin = wmin_tmp
wmax = wmax_tmp
max_v = (2 * (torch.abs(wmax) < torch.abs(wmin)).int() - 1) * torch.max(torch.abs(wmax), torch.abs(wmin))

scale = (max_v / maxq).to(scale_dtype)
scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
zp = torch.full_like(scale, maxq) # pylint: disable=E1130
scale = scale.unsqueeze(dim=-1)
zp = zp.unsqueeze(dim=-1)
int_w = round_ste(weight / scale + v)
q = torch.clamp(int_w + zp, 0, 2 ** bits - 1)
qdq_result = (scale * (q - zp)).to(weight.dtype)
return qdq_result, scale, zp
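
# A usage sketch (not part of the patch): calling the new full-range sym kernel directly on a
# random per-group weight tensor; the shapes and the error check below are illustrative assumptions.
#   w = torch.randn(256, 128)                      # 256 rows, one quantization group per row
#   qdq, scale, zp = quant_tensor_sym(w, bits=4)
#   print((qdq - w).abs().mean())                  # quant-dequant reconstruction error
# Note the sign trick above: max_v takes the sign that maps the largest-magnitude extreme onto
# -2**(bits-1), so the full signed integer range is used; hence the name full-range symmetric.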


@register_dtype("int_asym")
def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16,
weight_min=None, weight_max=None, q_scale_thresh=0.0, **kwargs):
@@ -49,7 +92,7 @@ def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_d
wmax = wmax_tmp
scale = ((wmax - wmin) / maxq).to(scale_dtype)
scale = torch.clamp(scale, min=q_scale_thresh)
zp = round_ste(-wmin / scale) # pylint: disable=E1130
zp = round_ste(-wmin / scale) # pylint: disable=E1130
scale = scale.unsqueeze(dim=-1)
zp = zp.unsqueeze(dim=-1)
int_w = round_ste(weight / scale + v)
@@ -58,8 +101,8 @@ def quant_tensor_asym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_d
return qdq_result, scale, zp


@register_dtype("int_sym")
def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None,
@register_dtype("int_sym_gptq")
def quant_tensor_sym_gptq(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None,
weight_max=None, q_scale_thresh=0.0, **kwargs):
"""Quantizes and dequantizes weight symmetrically.

@@ -106,6 +149,7 @@ def quant_tensor_sym(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dt
return qdq_result, scale, zp



def quant_tensor_asym_wo_round(weight, bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16,
weight_min=None, weight_max=None, q_scale_thresh=0.0, **kwargs):
"""Quantizes and dequantizes weight asymmetrically without rounding, this is mainly for tuning bias, norm.
6 changes: 5 additions & 1 deletion auto_round/export/export_to_autoround/export.py
@@ -180,8 +180,12 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
backend = "auto_round:exllamav2"
backend = backend.replace("autoround", "auto_round")
backend = backend.replace("auto-round", "auto_round")
## if using sym, we change to the gptq sym kernel to avoid compiling from the auto_round source
if (kwargs.get("sym") is None or kwargs.get("sym") == True) and ("gptq" not in backend and "awq" not in backend):
backend = backend.replace('auto_round','auto_round:gptq')
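## illustrative (assumption about a typical input): with the new default sym=True, a plain
## "auto_round" backend becomes "auto_round:gptq" here, so layers are packed with the GPTQ sym
## kernel and the checkpoint can be loaded without compiling auto_round's own CUDA kernels.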

if not ("triton" in backend or "exllamav2" in backend or "awq" in backend or "gptq" in backend):
logger.info(f"auto_round format does not support {backend}, try to pack each layer with autogptq")
logger.info(f"AutoRound format does not support {backend}, try to pack each layer with AutoGPTQ")
backend = backend.replace("auto_round", "auto_gptq")

model = kwargs["model"]
Expand Down
4 changes: 2 additions & 2 deletions auto_round/utils.py
@@ -800,7 +800,7 @@ def get_autogptq_infer_linear(backend, bits=4, group_size=128, sym=False):
disable_marlin=disable_marlin
)
else:
QuantLinear = dynamically_import_QuantLinear(
QuantLinear = dynamically_import_QuantLinear( # pylint: disable=E1123
use_triton=use_triton,
desc_act=False,
group_size=group_size,
@@ -960,7 +960,7 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
disable_marlin=disable_marlin,
)
else:
QuantLinear = dynamically_import_QuantLinear(
QuantLinear = dynamically_import_QuantLinear(# pylint: disable=E1123
use_triton=use_triton,
desc_act=False,
group_size=group_size,