This repository contains the code for reproducing our work "Quantization Variation: A New Perspective on Training Transformers with Low-Bit Precision", published in Transactions on Machine Learning Research (TMLR).
In this paper, we attribute the difficulty of low-bit quantization-aware training (QAT) for transformers to their unique variation behaviors, which differ significantly from those of ConvNets. Through comprehensive quantitative analysis, we observe variation at three levels: quantization sensitivity that differs across modules, outliers in the static weight and activation distributions, and oscillation in the dynamic parameter updates. These variations destabilize QAT and degrade performance. We explore best practices to alleviate their influence during low-bit transformer QAT and propose a variation-aware quantization scheme. Extensive experiments show that our scheme alleviates the variation and improves performance across various transformer models and tasks. In particular, it substantially improves 2-bit Swin-T, achieving a 3.35% accuracy improvement over previous state-of-the-art methods on ImageNet-1K.
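As background for readers new to QAT, the sketch below shows the standard learned-step-size fake quantization with a straight-through estimator that low-bit QAT builds on (cf. the LSQuantization code acknowledged at the end of this README). The `FakeQuantize` module is our illustrative example only, not the variation-aware scheme proposed in the paper.

```python
# Illustrative b-bit fake quantizer with a learned step size (LSQ-style) and a
# straight-through estimator. Minimal background sketch only; it is NOT the
# variation-aware quantization scheme proposed in the paper.
import torch
import torch.nn as nn


class FakeQuantize(nn.Module):
    def __init__(self, bits: int = 4, signed: bool = True):
        super().__init__()
        self.qmin = -(2 ** (bits - 1)) if signed else 0
        self.qmax = 2 ** (bits - 1) - 1 if signed else 2 ** bits - 1
        # Learned step size; in practice it is initialized from data statistics.
        self.scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Map to the integer grid and clamp to the b-bit range.
        q = torch.clamp(x / self.scale, self.qmin, self.qmax)
        # Straight-through estimator: forward uses round(q), backward treats it as identity.
        q = q + (q.round() - q).detach()
        return q * self.scale


if __name__ == "__main__":
    quant = FakeQuantize(bits=2)
    w = torch.randn(8, 8, requires_grad=True)
    quant(w).sum().backward()
    print(w.grad.shape, quant.scale.grad)  # gradients reach both the weights and the step size
```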
If you find our code useful for your research, please consider citing:
```bibtex
@article{huang2024quantization,
  title={Quantization Variation: A New Perspective on Training Transformers with Low-Bit Precision},
  author={Xijie Huang and Zhiqiang Shen and Pingcheng Dong and Kwang-Ting Cheng},
  journal={Transactions on Machine Learning Research},
  year={2024},
  url={https://openreview.net/forum?id=MHfoA0Qf6g}
}
```
- PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models (timm) 0.3.2 (a quick sanity check of the environment and dataset layout is sketched after this preparation list):

```
conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
```
- Install PyTorch and prepare the ImageNet dataset following the official PyTorch ImageNet training code.
- Download the soft labels following FKD and unzip them. Multiple types of soft labels are provided; we recommend Marginal Smoothing Top-5 (500-crop).
- Download the full-precision pre-trained weights via the links in the model table below.
- (Optional) To train your own full-precision baseline model, check `./fp_pretrained`.
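Before launching training, the following minimal sketch (ours, not part of the repo; `data_root` is a placeholder) can confirm that the pinned library versions are installed and that the folder you will pass via `--data` has the `train`/`val` layout expected by torchvision's `ImageFolder`:

```python
# Optional sanity check (not part of this repo). Verifies the pinned library
# versions and that the ImageNet folder contains train/ and val/ subfolders
# with one subdirectory per class (torchvision ImageFolder layout).
import os

import timm
import torch
import torchvision
from torchvision.datasets import ImageFolder

print("torch:", torch.__version__)              # expected: 1.7.0 or newer
print("torchvision:", torchvision.__version__)  # expected: 0.8.1 or newer
print("timm:", timm.__version__)                # expected: 0.3.2 (pinned above)

data_root = "/path/to/imagenet"  # placeholder: the folder you will pass via --data
for split in ("train", "val"):
    ds = ImageFolder(os.path.join(data_root, split))
    print(f"{split}: {len(ds)} images, {len(ds.classes)} classes")  # ImageNet-1K: 1000 classes
```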
- W4A4 quantization of DeiT-T with multi-processing distributed training on a single node with multiple GPUs:

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_VVTQ.py \
    --dist-url 'tcp://127.0.0.1:10001' \
    --dist-backend 'nccl' \
    --multiprocessing-distributed --world-size 1 --rank 0 \
    --model deit_tiny_patch16_224_quant --batch-size 512 --lr 5e-4 \
    --warmup-epochs 0 --min-lr 0 --wbits 4 --abits 4 --reg \
    --softlabel_path ./FKD_soft_label_500_crops_marginal_smoothing_k_5 \
    --finetune [path to full precision baseline model] \
    --save_checkpoint_path ./DeiT-T-4bit --log ./log/DeiT-T-4bit.log \
    --data [imagenet-folder with train and val folders]
```
- Evaluate a quantized checkpoint (e.g., W4A4 DeiT-T):

```
CUDA_VISIBLE_DEVICES=0 python train_VVTQ.py \
    --model deit_tiny_patch16_224_quant --batch-size 512 --wbits 4 --abits 4 \
    --resume [path to W4A4 DeiT-T ckpt] --evaluate --log ./log/DeiT-T-W4A4.log \
    --data [imagenet-folder with train and val folders]
```
| Model | W bits | A bits | Top-1 accuracy (%) | weights | logs |
|---|---|---|---|---|---|
| DeiT-T | 32 | 32 | 73.75 | link | - |
| DeiT-T | 4 | 4 | 74.71 | link | link |
| DeiT-T | 3 | 3 | 71.22 | link | link |
| DeiT-T | 2 | 2 | 59.73 | link | link |
| SReT-T | 32 | 32 | 75.81 | link | - |
| SReT-T | 4 | 4 | 76.99 | link | link |
| SReT-T | 3 | 3 | 75.40 | link | link |
| SReT-T | 2 | 2 | 67.53 | link | link |
| Swin-T | 32 | 32 | 81.0 | link | - |
| Swin-T | 4 | 4 | 82.42 | link | link |
| Swin-T | 3 | 3 | 81.37 | link | link |
| Swin-T | 2 | 2 | 77.66 | link | link |
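To inspect a downloaded checkpoint before passing it to `--finetune` or `--resume`, a rough sketch follows; the file name and the `state_dict`/`epoch`/`best_acc1` keys are assumptions based on the official PyTorch ImageNet training code rather than a documented format of this repo:

```python
# Hypothetical checkpoint inspection; the file name and dictionary keys are
# assumed (official PyTorch ImageNet training style), not guaranteed by this repo.
import torch

ckpt = torch.load("DeiT-T-4bit/checkpoint.pth.tar", map_location="cpu")
state = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
if isinstance(ckpt, dict):
    print("keys:", list(ckpt.keys()))  # e.g. 'state_dict', 'epoch', 'best_acc1'
print("tensors:", len(state))
print("params:", sum(v.numel() for v in state.values() if torch.is_tensor(v)))
```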
This repo benefits from FKD and LSQuantization. Thanks for their wonderful work!
If you have any questions, feel free to contact Xijie Huang (xhuangbs AT connect.ust.hk, huangxijie1108 AT gmail.com).