[ACL 2024] BitDistiller: Unleashing the Potential of Sub-4-Bit LLMs via Self-Distillation [paper]
Implementing efficient sub-4-bit weight quantization (3 / 2 bits) in LLMs through advanced QAT-based Self-Distillation techniques.
- [2024/05] 🔥 BitDistiller has been accepted to ACL main 2024!
- Python 3.9, PyTorch >= 1.13
- pip install -r requirement.txt
  (You may need to change the transformers version to match the model config.)
Our results are produced by following these steps:
- Determine the type of quantization: use nf3 for 3 bits and int for 2 bits. Set w_bit and quant_type accordingly.
- Perform clipping before training and save the clipping values using dump_clip (see quantization/autoclip.py).
  The clipping step can match or surpass the low-bit PTQ results of GPTQ and AWQ.
- For QAT, create data using the Teacher Model (BF16). The data varies depending on the model (see data/generation).
- Detailed procedure available in train/
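To make w_bit and q_group_size concrete, here is a minimal sketch of group-wise asymmetric round-to-nearest quantization with optional clipping bounds, written in plain Python. It is an illustration under stated assumptions, not the code in quantization/autoclip.py: the names quantize_group and quantize_row and the clip_min / clip_max parameters are hypothetical, and it covers only the uniform int type. The nf3 type is not shown.

```python
def quantize_group(weights, w_bit, clip_min=None, clip_max=None):
    """Quantize-dequantize one group of weights to 2**w_bit uniform levels.

    clip_min / clip_max are hypothetical stand-ins for saved clipping
    values; when omitted, the group's own min and max are used.
    """
    lo = min(weights) if clip_min is None else clip_min
    hi = max(weights) if clip_max is None else clip_max
    levels = 2 ** w_bit - 1                      # e.g. 3 steps for 2 bits
    scale = (hi - lo) / levels if hi > lo else 1.0
    out = []
    for w in weights:
        w = min(max(w, lo), hi)                  # clamp to the clipping range
        q = round((w - lo) / scale)              # integer code in [0, levels]
        out.append(q * scale + lo)               # map the code back to a float
    return out


def quantize_row(row, w_bit=2, q_group_size=128):
    """Apply quantize_group independently to consecutive groups of a row."""
    out = []
    for start in range(0, len(row), q_group_size):
        out.extend(quantize_group(row[start:start + q_group_size], w_bit))
    return out


if __name__ == "__main__":
    row = [0.05 * i - 0.4 for i in range(16)]
    print(quantize_row(row, w_bit=2, q_group_size=8))
```

Each group gets its own scale and offset, so a smaller q_group_size tracks the weights more closely at the cost of storing more scales. Tightening clip_min / clip_max trades accuracy on outliers for finer resolution on the remaining weights, which is the trade-off a clipping search tunes.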
LLaMA-2
- Get the clipping result
    cd BitDistiller/quantization
    CUDA_VISIBLE_DEVICES=0 python autoclip.py --model_path <model_path> --calib_dataset pile --quant_type int --w_bit 2 --q_group_size 128 --run_clip --dump_clip ./clip_cache/hf-llama2-7b/int2-g128.pt
- Get the Teacher Generation Data (using vllm is much faster)
    # vllm
    python generate_vllm.py --base_model <model_path> --dataset_name wikitext --out_path ./datasets/hf-llama-2-7b/ --max_sample 3000
    python generate_vllm.py --base_model <model_path> --dataset_name alpaca --out_path ./datasets/hf-llama-2-7b/ --max_sample 5000
    # update the paths in mix_data.py before running it
    python mix_data.py

    # torchrun
    cd BitDistiller/data/generation
    bash generate.sh <model_path> wikitext ../datasets/hf-llama-2-7b/ 16 3000
    bash generate.sh <model_path> alpaca ../datasets/hf-llama-2-7b/ 16 5000
    # update the paths in mix_data.py before running it
    python mix_data.py
- Run KD-based QAT
    # Specify the pre-trained model path
    # Specify num_gpus and batch_size according to your GPU devices
    # Pass the clipping cache path via --clip
    cd train
    bash train.sh ../data/datasets/hf-llama-2-7b/mix_wiki_alpaca_8000.json ./ckpts/hf-llama-2-7b/int2-g128/ ./logs/hf-llama-2-7b/int2-g128/ 4
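As background for the KD-based QAT step, the sketch below shows a generic knowledge-distillation objective: the forward KL divergence between the teacher's and the student's next-token distributions, averaged over positions. This is an assumption-laden illustration in plain Python, not the loss implemented in train/, which may use a different divergence or weighting. The function names softmax, forward_kl, and kd_loss and the temperature parameter are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over one position's logits."""
    m = max(logits)
    exps = [math.exp((x - m) / temperature) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward_kl(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) for a single token position."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(teacher_seq, student_seq, temperature=1.0):
    """Mean per-position KL over a sequence of logit vectors."""
    kls = [forward_kl(t, s, temperature) for t, s in zip(teacher_seq, student_seq)]
    return sum(kls) / len(kls)


if __name__ == "__main__":
    teacher = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]   # full-precision model logits
    student = [[1.5, 0.7, -0.5], [0.0, 0.5, 2.0]]   # quantized model logits
    print(kd_loss(teacher, student))
```

In a self-distillation setup the teacher is the full-precision version of the same model, so the loss is zero only when the quantized student reproduces the teacher's output distribution at every position.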
WizardCoder
- Get the clipping result
    cd BitDistiller/quantization
    CUDA_VISIBLE_DEVICES=0 python autoclip.py --model_path <model_path> --calib_dataset code --quant_type int --w_bit 2 --q_group_size 128 --run_clip --dump_clip ./clip_cache/WizardCoder-7B/int2-g128.pt
- Get the Teacher Generation Data
    # vllm
    python generate_vllm.py --base_model <model_path> --dataset_name code --out_path ./datasets/WizardCoder-7b/ --max_sample 3000

    cd BitDistiller/data/generation
    bash generate.sh /root/WizardCoder-Python-7B/ code ../datasets/WizardCoder-7b/ 16 3000
- Run KD-based QAT
    # Specify the pre-trained model path
    # Specify num_gpus and batch_size according to your GPU devices
    # Pass the clipping cache path via --clip
    cd train
    bash train.sh ../data/datasets/WizardCoder-7b/code_T0.7_N1024_S42_3000.json ./ckpts/WizardCoder-7b/int2-g128/ ./logs/WizardCoder-7b/int2-g128/ 2
MetaMath
- Get the clipping result
    cd BitDistiller/quantization
    CUDA_VISIBLE_DEVICES=0 python autoclip.py --model_path <model_path> --calib_dataset gsm8k --quant_type int --w_bit 2 --q_group_size 128 --run_clip --dump_clip ./clip_cache/MetaMath-7B/int2-g128.pt
- Get the Teacher Generation Data
    # vllm
    python generate_vllm.py --base_model <model_path> --dataset_name math --out_path ./datasets/MetaMath-7B/ --max_sample 3000

    cd BitDistiller/data/generation
    bash generate.sh /root/MetaMath-7B-V1.0/ math ../datasets/MetaMath-7B/ 16 3000
- Run KD-based QAT
    # Specify the pre-trained model path
    # Specify num_gpus and batch_size according to your GPU devices
    # Pass the clipping cache path via --clip
    cd train
    bash train.sh ../data/datasets/MetaMath-7B/math_T0.7_N1024_S42_3000.json ./ckpts/MetaMath-7b/int2-g128/ ./logs/MetaMath-7b/int2-g128/ 2
LLaMA-2
- Test PPL on WikiText-2
    cd test/general
    python wiki_ppl.py --model ../../train/ckpts/hf-llama-2-7b/int2-g128/checkpoint-200/ --quant_type int --bits 2 --group_size 128
- Test MMLU
    CUDA_VISIBLE_DEVICES=0 python llm_eval.py --model ../../train/ckpts/hf-llama-2-7b/int2-g128/checkpoint-200/ --eval_tasks hendrycksTest-* --test_set --bits 2 --group_size 128 --quant_type int --num_fewshot 5
- Test Common-sense QA Tasks
    CUDA_VISIBLE_DEVICES=0 python llm_eval.py --model ../../train/ckpts/hf-llama-2-7b/int2-g128/checkpoint-200/ --eval_tasks arc_challenge,winogrande,hellaswag,piqa --test_set --bits 2 --group_size 128 --quant_type int --num_fewshot 0
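For readers unfamiliar with the PPL metric, perplexity is the exponential of the mean per-token negative log-likelihood. The sketch below shows only that arithmetic. The helper name perplexity is hypothetical and is not taken from wiki_ppl.py, whose windowing and tokenization details are not described here.

```python
import math


def perplexity(token_nlls):
    """Perplexity from per-token negative log-likelihoods (natural log)."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


if __name__ == "__main__":
    # A model that assigns probability 1/4 to every observed token
    # has a per-token NLL of log(4).
    nlls = [math.log(4.0)] * 10
    print(perplexity(nlls))
```

Lower is better: a model that spreads its probability uniformly over k candidate tokens at every step scores a perplexity of k.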
WizardCoder
- Install the environment according to the HumanEval instructions.
- Example script:
    cd test/humaneval
    bash gen_preds.sh [checkpoint_path] ./preds/7b/int2-g128/
MetaMath
- Example script:
    cd test/gsm8k
    bash test.sh ../../train/ckpts/MetaMath-7b/int2-g128/ ./preds/7b/int2-g128/
For inference, please see inference/.
If you find BitDistiller useful or relevant to your research, please kindly cite our paper:
@misc{du2024bitdistiller,
title={BitDistiller: Unleashing the Potential of Sub-4-Bit LLMs via Self-Distillation},
author={Dayou Du and Yijia Zhang and Shijie Cao and Jiaqi Guo and Ting Cao and Xiaowen Chu and Ningyi Xu},
year={2024},
eprint={2402.10631},
archivePrefix={arXiv},
primaryClass={cs.CL}
}