-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: xin3he <[email protected]> Signed-off-by: Cheng, Zixuan <[email protected]> Signed-off-by: Kaihui-intel <[email protected]> Signed-off-by: zehao-intel <[email protected]> Signed-off-by: yiliu30 <[email protected]>
- Loading branch information
Showing
17 changed files
with
1,330 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
Dynamic Quantization | ||
=============== | ||
|
||
1. [Introduction](#introduction) | ||
2. [Getting Started with Dynamic Quantization](#Getting-Started-with-Dynamic-Quantization) | ||
3. [Examples](#examples) | ||
|
||
|
||
## Introduction | ||
Quantization is the process of converting floating point weights and activations to lower bitwidth tensors by multiplying the floating point values by a scale factor and rounding the results to whole numbers. Dynamic quantization determines the scale factor for activations dynamically based on the data range observed at runtime. We support W8A8 (quantizing weights and activations into 8 bits) dynamic quantization by leveraging torch's [`X86InductorQuantizer`](https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html?highlight=x86inductorquantizer). | ||
|
||
|
||
## Getting Started with Dynamic Quantization | ||
There are four steps to perform W8A8 dynamic quantization: `export`, `prepare`, `convert` and `compile`. | ||
|
||
```python | ||
import torch | ||
from neural_compressor.torch.export import export | ||
from neural_compressor.torch.quantization import DynamicQuantConfig, prepare, convert | ||
|
||
# Prepare the float model and example inputs for export model | ||
model = UserFloatModel() | ||
example_inputs = ... | ||
|
||
# Export eager model into FX graph model | ||
exported_model = export(model=model, example_inputs=example_inputs) | ||
# Quantize the model | ||
quant_config = DynamicQuantConfig() | ||
prepared_model = prepare(exported_model, quant_config=quant_config) | ||
q_model = convert(prepared_model) | ||
# Compile the quantized model and replace the Q/DQ pattern with Q-operator | ||
from torch._inductor import config | ||
|
||
config.freezing = True | ||
opt_model = torch.compile(q_model) | ||
``` | ||
|
||
> Note: The `set_local` of `DynamicQuantConfig` will be supported after the torch 2.4 release. | ||
|
||
## Examples | ||
Example will be added later. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
Microscaling Quantization | ||
=============== | ||
|
||
1. [Introduction](#introduction) | ||
2. [Get Started with Microscaling Quantization API](#get-start-with-microscaling-quantization-api) | ||
3. [Examples](#examples) | ||
4. [Reference](#reference) | ||
|
||
## Introduction | ||
|
||
Numerous breakthroughs have emerged across various fields, such as text analysis, language translation and chatbot technologies, fueled by the development of large language models (LLMs). Nevertheless, their increasing power comes with the challenge of explosive growth in parameters, posing obstacles for practical use. To balance memory limits and accuracy preservation for AI models, the Microscaling (MX) specification was promoted from the well-known Microsoft Floating Point (MSFP) data type [1, 2]: | ||
|
||
<table> | ||
<tr> | ||
<th>Format Name</th> | ||
<th>Element Data type</th> | ||
<th>Element Bits</th> | ||
<th>Scaling Block Size</th> | ||
<th>Scale Data Type</th> | ||
<th>Scale Bits</th> | ||
</tr> | ||
<tr> | ||
<td rowspan="2">MXFP8</td> | ||
<td>FP8 (E5M2)</td> | ||
<td rowspan="2">8</td> | ||
<td rowspan="2">32</td> | ||
<td rowspan="2">E8M0</td> | ||
<td rowspan="2">8</td> | ||
</tr> | ||
<tr> | ||
<td>FP8 (E4M3)</td> | ||
</tr> | ||
<tr> | ||
<td rowspan="2">MXFP6</td> | ||
<td>FP6 (E3M2)</td> | ||
<td rowspan="2">6</td> | ||
<td rowspan="2">32</td> | ||
<td rowspan="2">E8M0</td> | ||
<td rowspan="2">8</td> | ||
</tr> | ||
<tr> | ||
<td>FP6 (E2M3)</td> | ||
</tr> | ||
<tr> | ||
<td>MXFP4</td> | ||
<td>FP4 (E2M1)</td> | ||
<td>4</td> | ||
<td>32</td> | ||
<td>E8M0</td> | ||
<td>8</td> | ||
</tr> | ||
<tr> | ||
<td>MXINT8</td> | ||
<td>INT8</td> | ||
<td>8</td> | ||
<td>32</td> | ||
<td>E8M0</td> | ||
<td>8</td> | ||
</tr> | ||
</table> | ||
|
||
|
||
At an equivalent accuracy level, the MX data type demonstrates the ability to occupy a smaller area and incur lower energy costs for multiply-accumulate compared to other conventional data types on the same silicon [1]. | ||
|
||
Neural Compressor seamlessly applies the MX data type to post-training quantization, offering meticulously crafted recipes to empower users to quantize LLMs without sacrificing accuracy. The workflow is shown as below. | ||
|
||
<a target="_blank" href="./imgs/mx_workflow.png" text-align:left> | ||
<left> | ||
<img src="./imgs/mx_workflow.png" alt="Workflow of MX Quant (source [3])" height=120> | ||
</left> | ||
</a> | ||
|
||
The memory and computational limits of LLMs are more severe than other general neural networks, so our exploration focuses on LLMs first. The following table shows the basic MX quantization recipes in Neural Compressor and enumerates distinctions among various data types. The MX data type replaces general float scale with powers of two to be more hardware-friendly. It adapts a granularity falling between per-channel and per-tensor to balance accuracy and memory consumption. | ||
|
||
| | MX Format | INT8 | FP8 | | ||
|------------|--------------|------------|------------| | ||
| Scale | $2^{exp}$ | $\frac{MAX}{amax}$ | $\frac{MAX}{amax}$ | | ||
| Zero point | 0 (None) | $2^{bits - 1}$ or $-min * scale$ | 0 (None) | | ||
| Granularity | per-block (default blocksize is 32) | per-channel or per-tensor | per-channel or per-tensor | | ||
|
||
The exponent (exp) is equal to torch.floor(torch.log2(amax)), MAX is the representation range of the data type, amax is the max absolute value of per-block tensor, and rmin is the minimum value of the per-block tensor. | ||
|
||
|
||
## Get Started with Microscaling Quantization API | ||
|
||
To get a model quantized with Microscaling Data Types, users can use the Microscaling Quantization API as follows. | ||
|
||
```python | ||
from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert | ||
|
||
quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq) | ||
user_model = prepare(model=user_model, quant_config=quant_config) | ||
user_model = convert(model=user_model) | ||
``` | ||
|
||
## Examples | ||
|
||
- PyTorch [huggingface models](/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx) | ||
|
||
|
||
## Reference | ||
|
||
[1]: Darvish Rouhani, Bita, et al. "Pushing the limits of narrow precision inferencing at cloud scale with microsoft floating point." Advances in neural information processing systems 33 (2020): 10271-10281 | ||
|
||
[2]: OCP Microscaling Formats (MX) Specification | ||
|
||
[3]: Rouhani, Bita Darvish, et al. "Microscaling Data Formats for Deep Learning." arXiv preprint arXiv:2310.10537 (2023). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
PyTorch Mixed Precision | ||
======================================== | ||
|
||
1. [Introduction](#introduction) | ||
2. [Mixed Precision Support Matrix](#mixed-precision-support-matrix) | ||
3. [Get Started](#get-start) | ||
4. [Examples](#examples) | ||
|
||
## Introduction | ||
|
||
The recent growth of Deep Learning has driven the development of more complex models that require significantly more compute and memory capabilities. Several low precision numeric formats have been proposed to address the problem. Google's [bfloat16](https://cloud.google.com/tpu/docs/bfloat16) and the [FP16: IEEE](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) half-precision format are two of the most widely used sixteen bit formats. [Mixed precision](https://arxiv.org/abs/1710.03740) training and inference using low precision formats have been developed to reduce compute and bandwidth requirements. | ||
|
||
The 3rd Gen Intel® Xeon® Scalable processor (codenamed Cooper Lake), featuring Intel® Deep Learning Boost, is the first general-purpose x86 CPU to support the bfloat16 format. Specifically, three new bfloat16 instructions are added as a part of the AVX512_BF16 extension within Intel Deep Learning Boost: VCVTNE2PS2BF16, VCVTNEPS2BF16, and VDPBF16PS. The first two instructions allow converting to and from bfloat16 data type, while the last one performs a dot product of bfloat16 pairs. Further details can be found in the [hardware numerics document](https://www.intel.com/content/www/us/en/developer/articles/technical/intel-deep-learning-boost-new-instruction-bfloat16.html) published by Intel. | ||
|
||
The 4th Gen Intel® Xeon® Scalable processor supports FP16 instruction set architecture (ISA) for Intel® | ||
Advanced Vector Extensions 512 (Intel® AVX-512). The new ISA supports a wide range of general-purpose numeric | ||
operations for 16-bit half-precision IEEE-754 floating-point and complements the existing 32-bit and 64-bit floating-point instructions already available in the Intel Xeon processor based products. Further details can be found in the [hardware numerics document](https://www.intel.com/content/www/us/en/content-details/669773/intel-avx-512-fp16-instruction-set-for-intel-xeon-processor-based-products-technology-guide.html) published by Intel. | ||
|
||
<p align="center" width="100%"> | ||
<img src="./imgs/data_format.png" alt="Architecture" height=230> | ||
</p> | ||
|
||
## Mixed Precision Support Matrix | ||
|
||
<table class="center"> | ||
<thead> | ||
<tr> | ||
<th>Framework</th> | ||
<th>Backend</th> | ||
<th>Backend Library</th> | ||
<th>Backend Value</th> | ||
<th>Support Device(cpu as default)</th> | ||
<th>Support BF16</th> | ||
<th>Support FP16</th> | ||
</tr> | ||
</thead> | ||
<tbody> | ||
<tr> | ||
<td rowspan="1" align="left">PyTorch</td> | ||
<td align="left">FX</td> | ||
<td align="left">FBGEMM</td> | ||
<td align="left">"default"</td> | ||
<td align="left">cpu</td> | ||
<td align="left">✔</td> | ||
<td align="left">✔</td> | ||
</tr> | ||
</tbody> | ||
</table> | ||
|
||
|
||
### Hardware and Software requests for **BF16** | ||
- PyTorch | ||
1. Hardware: CPU supports `avx512_bf16` instruction set. | ||
2. Software: torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html). | ||
|
||
|
||
### Hardware and Software requests for **FP16** | ||
- PyTorch | ||
1. Hardware: CPU supports `avx512_fp16` instruction set. | ||
2. Software: torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html). | ||
|
||
|
||
### Accuracy-driven mixed precision | ||
BF16/FP16 conversion may lead to accuracy drop. Intel® Neural Compressor provides an accuracy-driven tuning function to reduce accuracy loss, | ||
which could fallback converted ops to FP32, if set in config, to get better accuracy. To enable this function, users only to provide | ||
`eval_fn` and `eval_args` for `autotune`. | ||
To be noticed, IPEX backend doesn't support accuracy-driven mixed precision. | ||
|
||
## Get Started with autotune API | ||
|
||
To get a bf16/fp16 model, users can use the `autotune` interface with `MixPrecisionConfig` as follows. | ||
|
||
- BF16: | ||
|
||
```python | ||
from neural_compressor.torch.quantization import MixPrecisionConfig, TuningConfig, autotune | ||
|
||
def eval_acc_fn(model): | ||
...... | ||
return acc | ||
|
||
# modules might be fallback to fp32 to get better accuracy | ||
custom_tune_config = TuningConfig(config_set=[MixPrecisionConfig(dtype=["bf16", "fp32"])], max_trials=3) | ||
best_model = autotune(model=build_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn) | ||
``` | ||
|
||
- FP16: | ||
|
||
```python | ||
from neural_compressor.torch.quantization import MixPrecisionConfig, TuningConfig, autotune | ||
|
||
def eval_acc_fn(model): | ||
...... | ||
return acc | ||
|
||
# modules might be fallback to fp32 to get better accuracy | ||
custom_tune_config = TuningConfig(config_set=[MixPrecisionConfig(dtype=["fp16", "fp32"])], max_trials=3) | ||
best_model = autotune(model=build_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn) | ||
``` | ||
|
||
## Examples | ||
|
||
Example will be added later. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
PyTorch Smooth Quantization | ||
======================================== | ||
|
||
1. [Introduction](#Introduction) | ||
2. [Usage](#Usage) | ||
3. [Validated Models](#Validated-Models) | ||
4. [Supported Framework Matrix](#Supported-Framework-Matrix) | ||
|
||
|
||
## Introduction | ||
Quantization is a common compression operation to reduce memory and accelerate inference by converting the floating point matrix to an integer matrix. For large language models (LLMs) with gigantic parameters, the systematic outliers make quantification of activations difficult. [SmoothQuant](https://arxiv.org/abs/2211.10438), a training free post-training quantization (PTQ) solution, offline migrates this difficulty from activations to weights with a mathematically equivalent transformation. | ||
|
||
|
||
## Usage | ||
### Fixed Alpha | ||
To set a fixed alpha for the entire model, users can follow this example: | ||
|
||
```python | ||
from neural_compressor.torch.quantization import SmoothQuantConfig, convert, prepare | ||
|
||
|
||
def run_fn(model): | ||
model(example_inputs) | ||
|
||
|
||
quant_config = SmoothQuantConfig(alpha=0.5) | ||
prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) | ||
run_fn(prepared_model) | ||
q_model = convert(prepared_model) | ||
``` | ||
`SmoothQuantConfig` description: | ||
|
||
`alpha`: a smooth factor to calculate the conversion per-channel scale and balance the quantization difficulty of activation and weight. Float value, default is 0.5. | ||
|
||
> **Note:** Alpha="auto" and alpha auto-tuning was supported in old API, please stay tuned for the new API's support for auto alpha. | ||
### Specify Quantization Rules | ||
Intel(R) Neural Compressor support specify quantization rules by operator type for Smooth Quantization. Users can use `set_local` to fallback op type in `SmoothQuantConfig` to achieve the above purpose. | ||
|
||
Here we don't quantize `Linear` layers. | ||
```python | ||
# fallback by op_type | ||
quant_config.set_local("Linear", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32")) | ||
prepared_model = prepare(model, quant_config=quant_config, example_inputs=example_inputs) | ||
run_fn(prepared_model) | ||
q_model = convert(prepared_model) | ||
``` | ||
|
||
To get more information, please refer to [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm). | ||
|
||
|
||
## Validated Models | ||
Neural Compressor: 2.1 | ||
|
||
IPEX (Intel Extension for PyTorch): 2.0/2.1 | ||
|
||
Dataset: lambada_openai | ||
|
||
Task: text-generation provided by [ITREX](https://github.com/intel/intel-extension-for-transformers/tree/main/examples/huggingface/pytorch/text-generation/quantization) | ||
|
||
alpha [0.4, 0.6] is sweet spot region in SmoothQuant paper. | ||
|
||
A list of models that achieved a <1% accuracy drop is shown below. | ||
|
||
| Model/Last token accuracy | FP32 Accuracy | INT8 (w/ SmoothQuant) | Notes | | ||
|:----------:|:------:|:------:|-----------------------------------| | ||
| bigscience/bloom-560m | 0.354 | 0.3542 | alpha=0.5, Ipex 2.1 | | ||
| bigscience/bloom-1b7 | 0.4634 | 0.4936 | alpha=0.5, Ipex 2.0 | | ||
| bigscience/bloom-3b | 0.518 | 0.5185 | alpha=0.8, Ipex 2.1 | | ||
| bigscience/bloom-7b1 | 0.5764 | 0.5977 | alpha=0.5, Ipex 2.0 | | ||
| bigscience/bloomz-560m | 0.3947 | 0.3930 | alpha=0.8, Ipex 2.1 | | ||
| bigscience/bloomz-1b7 | 0.4828 | 0.4906 | alpha=0.5, Ipex 2.1 | | ||
| bigscience/bloomz-3b | 0.5018 | 0.4980 | alpha=0.5, Ipex 2.1 | | ||
| bigscience/bloomz-7b1 | 0.5593 | 0.5552 | alpha=0.5, Ipex 2.1 | | ||
| facebook/opt-125m | 0.379 | 0.3757 | alpha=0.5, Ipex 2.1 | | ||
| facebook/opt-350m | 0.4516 | 0.4533 | alpha=0.8, Ipex 2.1 | | ||
| facebook/opt-1.3b | 0.5789 | 0.5742 | alpha=0.8, Ipex 2.0 | | ||
| facebook/opt-2.7b | 0.6365 | 0.6404 | alpha=0.5, Ipex 2.0 | | ||
| facebook/opt-6.7b | 0.6769 | 0.6804 | alpha=0.5, Ipex 2.0 | | ||
| facebook/opt-13b | 0.6872 | 0.6814 | alpha=0.5, Ipex 2.1 | | ||
| facebook/opt-30b | 0.7149 | 0.7128 | alpha=0.5, Ipex 2.1 | | ||
| facebook/opt-66b | 0.7398 | 0.7326 | alpha=0.5, Ipex 2.1 | | ||
| LLaMa-7b | 0.7361 | 0.7357 | alpha=0.8, Ipex 2.1 | | ||
| LLaMa-13b | 0.7627 | 0.7590 | alpha=0.7, Ipex 2.1 | | ||
| LLaMa-30b | 0.7759 | 0.7840 | alpha=0.7, Ipex 2.1 | | ||
| LLaMa-65b | 0.7908 | 0.7957 | alpha=0.9, Ipex 2.1 | | ||
| EleutherAI/gpt-j-6B* | 0.6831 | 0.6821 | alpha=1.0, Ipex 2.1 | | ||
| MBZUAI/LaMini-GPT-124m | 0.3804 | 0.3887 | alpha=0.5, Ipex 2.1 | | ||
| MBZUAI/LaMini-GPT-774m | 0.5048 | 0.5057 | alpha=0.5, Ipex 2.1 | | ||
| MBZUAI/LaMini-GPT-1.5b | 0.5443 | 0.5436 | alpha=0.5, Ipex 2.1 | | ||
| mosaicml/mpt-7b-chat | 0.655 | 0.6499 | alpha=0.7, Ipex 2.1 | | ||
| stabilityai/stablelm-base-alpha-3b | 0.4172 | 0.4149 | alpha=0.6, Ipex 2.1 | | ||
| togethercomputer/RedPajama-INCITE-Base-3B-v1 | 0.6542 | 0.6735 | alpha=0.5, Ipex 2.1 | | ||
| togethercomputer/RedPajama-INCITE-Chat-3B-v1* | 0.6718 | 0.6740 | alpha=0.5, Ipex 2.0 | | ||
| togethercomputer/RedPajama-INCITE-Instruct-3B-v1* | 0.6569 | 0.6621 | alpha=0.5, Ipex 2.0 | | ||
| togethercomputer/RedPajama-INCITE-Base-7B-v0.1* | 0.7143 | 0.7221 | alpha=0.5, Ipex 2.0 | | ||
| togethercomputer/RedPajama-INCITE-Instruct-7B-v0.1* | 0.6895 | 0.6953 | alpha=0.5, Ipex 2.0 | | ||
| databricks/dolly-v1-6b* | 0.6866 | 0.6895 | alpha=0.8, Ipex 2.1 | | ||
| databricks/dolly-v2-3b* | 0.6297 | 0.6247 | alpha=0.5, Ipex 2.1 | | ||
| tiiuae/falcon-7b-instruct | 0.6437 | 0.6392 | alpha=0.7, Pytorch | | ||
|
||
Please refer to the step-by-step [instruction](../../examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/ipex/README.md) for details. | ||
|
||
Please note that for models with asterisk(*), we have set all add ops to FP32 during quantization step to achieve desirable results. | ||
|
||
|
||
## Supported Framework Matrix | ||
|
||
| Framework | Alpha | Folding | | ||
|:---------:|--------------|------------| | ||
| PyTorch | [0-1] | False | | ||
| IPEX | [0-1] | True / False(Version>2.1) | |
Oops, something went wrong.