Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
Unify the woq config weight_dtype for int4 and fp4 on different devices (#1594)
Browse files Browse the repository at this point in the history

Signed-off-by: Cheng, Penghui <[email protected]>
  • Loading branch information
PenghuiCheng authored Jun 7, 2024
1 parent a5d9129 commit 8722443
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 75 deletions.
2 changes: 1 addition & 1 deletion docs/weightonlyquant.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ quantization_config = AutoRoundConfig(
max_input_length=2048,
compute_dtype="fp16",
scale_dtype="fp16",
weight_dtype="int4_fullrange",
weight_dtype="int4", # int4 == int4_clip
calib_iters=2,
calib_len=32,
nsamples=2,
Expand Down
6 changes: 3 additions & 3 deletions examples/.config/pytorch_optimize.json
Original file line number Diff line number Diff line change
Expand Up @@ -1744,7 +1744,7 @@
"params": {
"model": "mistralai/Mistral-7B-v0.1",
"output_dir": "saved_results",
"weight_dtype": "int4_fullrange"
"weight_dtype": "int4"
}
},
"benchmark": {
Expand All @@ -1764,7 +1764,7 @@
"params": {
"model": "meta-llama/Llama-2-7b-hf",
"output_dir": "saved_results",
"weight_dtype": "int4_fullrange"
"weight_dtype": "int4"
}
},
"benchmark": {
Expand All @@ -1784,7 +1784,7 @@
"params": {
"model": "Qwen/Qwen-7B-Chat",
"output_dir": "saved_results",
"weight_dtype": "int4_fullrange"
"weight_dtype": "int4"
}
},
"benchmark": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,13 @@
default="int8",
choices=[
"int8",
"int4", # int4 == int4_clip
"int4_clip",
"int4_fullrange",
"fp4", # fp4 == fp4_e2m1
"fp4_e2m1_bnb",
"fp4_e2m1",
"nf4",
"fp8", # fp8 == fp8_e4m3
"fp8_e5m2",
"fp8_e4m3",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -96,7 +96,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -135,7 +135,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -149,7 +149,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -188,7 +188,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -202,7 +202,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -244,7 +244,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -258,7 +258,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -300,7 +300,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -314,7 +314,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -353,7 +353,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -367,7 +367,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -406,7 +406,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -420,7 +420,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -459,7 +459,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--max_input_length 2048 \
--scheme sym \
--group_size 32 \
Expand All @@ -472,7 +472,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -511,7 +511,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -525,7 +525,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -564,7 +564,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -578,7 +578,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -618,7 +618,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -632,7 +632,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -671,7 +671,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--max_input_length 2048 \
--scheme asym \
--group_size 32 \
Expand All @@ -684,7 +684,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -723,7 +723,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--max_input_length 2048 \
--scheme sym \
--group_size 32 \
Expand All @@ -737,7 +737,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -776,7 +776,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -790,7 +790,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -829,7 +829,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--max_input_length 2048 \
--scheme asym \
--group_size 32 \
Expand All @@ -842,7 +842,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -881,7 +881,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--desc_act \
--max_input_length 2048 \
--scheme sym \
Expand All @@ -895,7 +895,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down Expand Up @@ -934,7 +934,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo GPTQ \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--max_input_length 2048 \
--scheme sym \
--group_size 128 \
Expand All @@ -947,7 +947,7 @@ python run_generation_cpu_woq.py \
--woq \
--woq_algo AutoRound \
--bits 4 \
--weight_dtype int4_clip \
--weight_dtype int4 \
--calib_iters 200 \
--scheme asym \
--group_size 128 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@
default="int8",
choices=[
"int8",
"int4", # int4 == int4_clip
"int4_clip",
"int4_fullrange",
"fp4", # fp4 == fp4_e2m1
"fp4_e2m1_bnb",
"fp4_e2m1",
"nf4",
"fp8", # fp8 == fp8_e4m3
"fp8_e5m2",
"fp8_e4m3",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@
parser.add_argument("--woq", action="store_true")
parser.add_argument("--woq_algo", default="Rtn", choices=['Rtn', 'GPTQ', 'AutoRound'],
help="Weight-only parameter.")
parser.add_argument("--weight_dtype", type=str, default="int4_fullrange",
choices=["int4_fullrange"])
parser.add_argument("--weight_dtype", type=str, default="int4",
choices=[
"int4", # int4 == int4_fullrange
"int4_fullrange",
]
)
parser.add_argument("--group_size", type=int, default=128)
parser.add_argument("--scheme", default="sym")
parser.add_argument("--woq_enable_mse_search", action="store_true")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function init_params {
approach="PostTrainingStatic"
script="run_generation_sq.py"
alpha=0.5
weight_dtype="int4_clip"
weight_dtype="int4"
scheme="asym"
for var in "$@"
do
Expand Down
Loading

0 comments on commit 8722443

Please sign in to comment.