
# Quantization Variation

This repository contains the code for reproducing our work "Quantization Variation: A New Perspective on Training Transformers with Low-Bit Precision", published in Transactions on Machine Learning Research (TMLR).

## Abstract

In this paper, we attribute the difficulty of low-bit quantization-aware training for transformers to their unique variation behaviors, which differ significantly from those of ConvNets. Based on a comprehensive quantitative analysis, we observe variation at three levels: differing quantization sensitivities across modules, outliers in the static weight and activation distributions, and oscillation in dynamic parameter fluctuations. These variations make quantization-aware training (QAT) unstable and degrade performance. We explore best practices for alleviating the influence of variation during low-bit transformer QAT and propose a variation-aware quantization scheme. Extensive experiments show that our scheme alleviates the variation and improves the performance of transformers across various models and tasks. In particular, our solution substantially improves 2-bit Swin-T, achieving a 3.35% accuracy improvement over previous state-of-the-art methods on ImageNet-1K.

## Citation

If you find our code useful for your research, please consider citing:

```
@article{
    huang2024quantization,
    title={Quantization Variation: A New Perspective on Training Transformers with Low-Bit Precision},
    author={Xijie Huang and Zhiqiang Shen and Pingcheng Dong and Kwang-Ting Cheng},
    journal={Transactions on Machine Learning Research},
    year={2024},
    url={https://openreview.net/forum?id=MHfoA0Qf6g}
}
```

## Preparation

### Requirements

- PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models (timm) 0.3.2

```bash
conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
```
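
To confirm the environment matches these requirements, the installed versions can be printed with a quick sanity check (not part of the original instructions, just a convenience):

```bash
# Print the installed versions; expect torch >= 1.7.0, torchvision >= 0.8.1, timm == 0.3.2.
python -c "import torch, torchvision, timm; print(torch.__version__, torchvision.__version__, timm.__version__)"
```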

### Data and Soft Label
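
The download links for ImageNet-1K and the FKD soft labels are not reproduced here. The commands below assume an ImageNet-1K folder containing `train` and `val` subfolders (passed via `--data`) and the FKD soft-label directory (passed via `--softlabel_path`); the paths in this sketch are illustrative only:

```bash
# Illustrative paths only; substitute your own locations.
ls /path/to/imagenet                                  # expected to contain: train/  val/
ls ./FKD_soft_label_500_crops_marginal_smoothing_k_5  # FKD soft-label directory
```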

## Run

### Preparing the full-precision baseline model

- Download the full-precision pre-trained weights via the links provided in Models.
- (Optional) To train your own full-precision baseline model, please check `./fp_pretrained`.

### Quantization-aware training

- W4A4 DeiT-T quantization with multi-processing distributed training on a single node with multiple GPUs:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_VVTQ.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
--model deit_tiny_patch16_224_quant --batch-size 512 --lr 5e-4 \
--warmup-epochs 0 --min-lr 0 --wbits 4 --abits 4 --reg \
--softlabel_path ./FKD_soft_label_500_crops_marginal_smoothing_k_5 \
--finetune [path to full precision baseline model] \
--save_checkpoint_path ./DeiT-T-4bit --log ./log/DeiT-T-4bit.log \
--data [imagenet-folder with train and val folders]
```
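
The target precision is controlled by `--wbits` and `--abits`. As a sketch, a W3A3 DeiT-T run would mirror the W4A4 command above with only the bit-width flags and output paths changed (the remaining hyperparameters are assumed unchanged and may need tuning):

```bash
# W3A3 DeiT-T: identical to the W4A4 command above except for the bit-width flags
# and the checkpoint/log paths.
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_VVTQ.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
--model deit_tiny_patch16_224_quant --batch-size 512 --lr 5e-4 \
--warmup-epochs 0 --min-lr 0 --wbits 3 --abits 3 --reg \
--softlabel_path ./FKD_soft_label_500_crops_marginal_smoothing_k_5 \
--finetune [path to full precision baseline model] \
--save_checkpoint_path ./DeiT-T-3bit --log ./log/DeiT-T-3bit.log \
--data [imagenet-folder with train and val folders]
```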

## Evaluation

```bash
CUDA_VISIBLE_DEVICES=0 python train_VVTQ.py \
--model deit_tiny_patch16_224_quant --batch-size 512 --wbits 4 --abits 4 \
--resume [path to W4A4 DeiT-T ckpt] --evaluate --log ./log/DeiT-T-W4A4.log \
--data [imagenet-folder with train and val folders]
```

## Models

| Model  | W bits | A bits | Top-1 accuracy (%) | Weights | Logs |
|--------|--------|--------|--------------------|---------|------|
| DeiT-T | 32     | 32     | 73.75              | link    | -    |
| DeiT-T | 4      | 4      | 74.71              | link    | link |
| DeiT-T | 3      | 3      | 71.22              | link    | link |
| DeiT-T | 2      | 2      | 59.73              | link    | link |
| SReT-T | 32     | 32     | 75.81              | link    | -    |
| SReT-T | 4      | 4      | 76.99              | link    | link |
| SReT-T | 3      | 3      | 75.40              | link    | link |
| SReT-T | 2      | 2      | 67.53              | link    | link |
| Swin-T | 32     | 32     | 81.0               | link    | -    |
| Swin-T | 4      | 4      | 82.42              | link    | link |
| Swin-T | 3      | 3      | 81.37              | link    | link |
| Swin-T | 2      | 2      | 77.66              | link    | link |

## Acknowledgement

This repo benefits from FKD and LSQuantization. Thanks for their wonderful work!

If you have any questions, feel free to contact Xijie Huang (xhuangbs AT connect.ust.hk or huangxijie1108 AT gmail.com).