From 2bb257e71353d87414ff7e410ca35bce5cc3dbc7 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Thu, 10 Oct 2024 19:27:11 +0800 Subject: [PATCH] Add woq examples (#1982) Signed-off-by: Kaihui-intel Signed-off-by: Sun, Xuehao Co-authored-by: Sun, Xuehao --- examples/.config/model_params_pytorch_3x.json | 28 +++ .../quantization/weight_only/README.md | 73 +++++-- .../quantization/weight_only/run_benchmark.sh | 59 +++--- .../weight_only/run_clm_no_trainer.py | 181 +++++++++++++++++- .../quantization/weight_only/run_quant.sh | 13 ++ 5 files changed, 309 insertions(+), 45 deletions(-) diff --git a/examples/.config/model_params_pytorch_3x.json b/examples/.config/model_params_pytorch_3x.json index c3ae3f6b5be..809b898d5e3 100644 --- a/examples/.config/model_params_pytorch_3x.json +++ b/examples/.config/model_params_pytorch_3x.json @@ -84,6 +84,34 @@ "main_script": "run_clm_no_trainer.py", "batch_size": 8 }, + "gpt_j_woq_awq_int4":{ + "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/weight_only", + "dataset_location": "", + "input_model": "", + "main_script": "run_clm_no_trainer.py", + "batch_size": 1 + }, + "opt_125m_woq_awq_int4":{ + "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/weight_only", + "dataset_location": "", + "input_model": "", + "main_script": "run_clm_no_trainer.py", + "batch_size": 1 + }, + "opt_125m_woq_autoround_int4":{ + "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/weight_only", + "dataset_location": "", + "input_model": "", + "main_script": "run_clm_no_trainer.py", + "batch_size": 1 + }, + "opt_125m_woq_autotune_int4":{ + "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/weight_only", + "dataset_location": "", + "input_model": "", + "main_script": "run_clm_no_trainer.py", + "batch_size": 1 + }, "gpt_j_ipex":{ "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/static_quant/ipex", "dataset_location": "", diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/README.md b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/README.md index 889d7b42682..0519b490ff7 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/README.md +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/README.md @@ -35,9 +35,8 @@ python run_clm_no_trainer.py \ --woq_group_size 128 \ --gptq_max_seq_length 2048 \ --gptq_use_max_length \ - --accuracy \ - --tasks "lambada_openai" \ - --double_quant_type "BNB_NF4" + --double_quant_type "BNB_NF4" \ + --output_dir saved_results # "--woq_algo RTN" is used to enable RTN algorithms python run_clm_no_trainer.py \ @@ -48,9 +47,38 @@ python run_clm_no_trainer.py \ --woq_bits 4 \ --woq_scheme asym \ --woq_group_size 128 \ + --double_quant_type "BNB_NF4" + --output_dir saved_results + +# "--woq_algo AWQ" is used to enable AWQ algorithms +python run_clm_no_trainer.py \ + --model EleutherAI/gpt-j-6B \ + --dataset NeelNanda/pile-10k \ + --quantize \ + --woq_algo AWQ \ + --woq_bits 4 \ + --woq_scheme asym \ + --woq_group_size 128 \ + --calib_iters 128 + +# "--woq_algo AutoRound" is used to enable AutoRound algorithms +python run_clm_no_trainer.py \ + --model EleutherAI/gpt-j-6B \ + --dataset NeelNanda/pile-10k \ + --quantize \ + --woq_algo AutoRound \ + --woq_bits 4 \ + --woq_scheme asym \ + --woq_group_size 128 + +# "--accuracy" for eval +python run_clm_no_trainer.py \ + --model 
EleutherAI/gpt-j-6B \
+    --dataset NeelNanda/pile-10k \
+    --int8 \
     --accuracy \
     --tasks "lambada_openai" \
-    --double_quant_type "BNB_NF4"
+    --output_dir saved_results
 ```

 **Notes**: Weight-only quantization based on fake quantization is supported as a preview feature and covers the RTN, GPTQ[1], AWQ[2], and TEQ algorithms. For more details, please refer to [link](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md). Our GPTQ API supports various CLMs including GPTJ, OPTs, Blooms, Llamas, Falcons, MPTs, ChatGLMs, etc. Simply replace the "--model" argument with another model to quantize a different CLM with GPTQ.

@@ -72,8 +100,6 @@ python run_clm_no_trainer.py \
     --woq_group_size 128 \
     --gptq_max_seq_length 2048 \
     --gptq_use_max_length \
-    --accuracy \
-    --tasks "lambada_openai" \
     --double_quant_type "BNB_NF4"

 # "--woq_algo RTN" is used to enable RTN algorithms
 python run_clm_no_trainer.py \
@@ -85,13 +111,40 @@ python run_clm_no_trainer.py \
     --model facebook/opt-125m \
     --dataset NeelNanda/pile-10k \
     --quantize \
     --woq_algo RTN \
     --woq_bits 4 \
     --woq_scheme asym \
     --woq_group_size 128 \
+    --double_quant_type "BNB_NF4"
+
+# "--woq_algo AWQ" is used to enable AWQ algorithms
+python run_clm_no_trainer.py \
+    --model facebook/opt-125m \
+    --dataset NeelNanda/pile-10k \
+    --quantize \
+    --woq_algo AWQ \
+    --woq_bits 4 \
+    --woq_scheme asym \
+    --woq_group_size 128 \
+    --calib_iters 128
+
+# "--woq_algo AutoRound" is used to enable AutoRound algorithms
+python run_clm_no_trainer.py \
+    --model facebook/opt-125m \
+    --dataset NeelNanda/pile-10k \
+    --quantize \
+    --woq_algo AutoRound \
+    --woq_bits 4 \
+    --woq_scheme asym \
+    --woq_group_size 128
+
+# "--accuracy" for eval
+python run_clm_no_trainer.py \
+    --model facebook/opt-125m \
+    --dataset NeelNanda/pile-10k \
+    --int8 \
     --accuracy \
     --tasks "lambada_openai" \
-    --double_quant_type "BNB_NF4"
+    --output_dir saved_results
 ```

 ### LLAMA2-7b/13b/70b
->Note: LLAMA requires IPEX requirements >= 2.1 to get better accuracy.
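+
+> The `meta-llama/Llama-2-7b-hf` checkpoint used below is gated on the Hugging Face Hub, so the commands assume an already-authenticated environment (or a local copy of the weights). A minimal way to set that up, if your environment needs it:
+
+```bash
+# one-time authentication for the gated Llama-2 weights
+huggingface-cli login
+# or point --model at a local snapshot instead of the Hub id, e.g.
+#   python run_clm_no_trainer.py --model /path/to/Llama-2-7b-hf ...
+```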
#### Quantization ```bash @@ -107,8 +160,6 @@ python run_clm_no_trainer.py \ --woq_group_size 128 \ --gptq_max_seq_length 2048 \ --gptq_use_max_length \ - --accuracy \ - --tasks "lambada_openai" \ --double_quant_type "BNB_NF4" # "--woq_algo RTN" is used to enable RTN algorithms @@ -120,8 +171,6 @@ python run_clm_no_trainer.py \ --woq_bits 4 \ --woq_scheme asym \ --woq_group_size 128 \ - --accuracy \ - --tasks "lambada_openai" \ --double_quant_type "BNB_NF4" ``` diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_benchmark.sh b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_benchmark.sh index 9e1d766128e..6c84e27ce88 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_benchmark.sh +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_benchmark.sh @@ -70,58 +70,59 @@ function run_benchmark { fi echo $extra_cmd - if [ "${topology}" = "opt_125m_woq_gptq_int4" ]; then + if [ "${topology}" = "opt_125m_woq_gptq_int4" ]; then model_name_or_path="facebook/opt-125m" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" elif [ "${topology}" = "opt_125m_woq_gptq_int4_dq_bnb" ]; then model_name_or_path="facebook/opt-125m" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" - extra_cmd=$extra_cmd" --double_quant_type BNB_NF4" elif [ "${topology}" = "opt_125m_woq_gptq_int4_dq_ggml" ]; then model_name_or_path="facebook/opt-125m" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length --gptq_percdamp 0.1 --gptq_actorder" - extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K" elif [ "${topology}" = "llama2_7b_gptq_int4" ]; then model_name_or_path="meta-llama/Llama-2-7b-hf" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" elif [ "${topology}" = "llama2_7b_gptq_int4_dq_bnb" ]; then model_name_or_path="meta-llama/Llama-2-7b-hf" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" - extra_cmd=$extra_cmd" --double_quant_type BNB_NF4" elif [ "${topology}" = "llama2_7b_gptq_int4_dq_ggml" ]; then model_name_or_path="meta-llama/Llama-2-7b-hf" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" - extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K" elif [ "${topology}" = "gpt_j_woq_rtn_int4" ]; then model_name_or_path="EleutherAI/gpt-j-6b" - extra_cmd=$extra_cmd" --woq_algo RTN --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search" elif [ "${topology}" = "gpt_j_woq_rtn_int4_dq_bnb" ]; then - model_name_or_path="EleutherAI/gpt-j-6b"\ - extra_cmd=$extra_cmd" --woq_algo RTN --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search" - extra_cmd=$extra_cmd" --double_quant_type BNB_NF4" + model_name_or_path="EleutherAI/gpt-j-6b" elif [ "${topology}" = "gpt_j_woq_rtn_int4_dq_ggml" ]; then - model_name_or_path="EleutherAI/gpt-j-6b"\ - extra_cmd=$extra_cmd" --woq_algo RTN --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search" - extra_cmd=$extra_cmd" --double_quant_type 
GGML_TYPE_Q4_K" + model_name_or_path="EleutherAI/gpt-j-6b" elif [ "${topology}" = "gpt_j_woq_gptq_int4" ]; then model_name_or_path="EleutherAI/gpt-j-6b" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" elif [ "${topology}" = "gpt_j_woq_gptq_int4_dq_bnb" ]; then model_name_or_path="EleutherAI/gpt-j-6b" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" - extra_cmd=$extra_cmd" --double_quant_type BNB_NF4" elif [ "${topology}" = "gpt_j_woq_gptq_int4_dq_ggml" ]; then model_name_or_path="EleutherAI/gpt-j-6b" - extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" - extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K" + elif [ "${topology}" = "gpt_j_woq_awq_int4" ]; then + model_name_or_path="EleutherAI/gpt-j-6b" + elif [ "${topology}" = "opt_125m_woq_awq_int4" ]; then + model_name_or_path="facebook/opt-125m" + elif [ "${topology}" = "opt_125m_woq_autoround_int4" ]; then + model_name_or_path="facebook/opt-125m" + extra_cmd=$extra_cmd" --woq_algo AutoRound" + elif [ "${topology}" = "opt_125m_woq_autotune_int4" ]; then + model_name_or_path="facebook/opt-125m" fi - python -u run_clm_no_trainer.py \ - --model ${model_name_or_path} \ - --output_dir ${tuned_checkpoint} \ - --task ${task} \ - --batch_size ${batch_size} \ - ${extra_cmd} ${mode_cmd} + if [[ ${mode} == "accuracy" ]]; then + python -u run_clm_no_trainer.py \ + --model ${model_name_or_path} \ + --output_dir ${tuned_checkpoint} \ + --task ${task} \ + --batch_size ${batch_size} \ + ${extra_cmd} ${mode_cmd} + elif [[ ${mode} == "performance" ]]; then + incbench --num_cores_per_instance 4 run_clm_no_trainer.py \ + --model ${model_name_or_path} \ + --batch_size ${batch_size} \ + --output_dir ${tuned_checkpoint} \ + ${extra_cmd} ${mode_cmd} + else + echo "Error: No such mode: ${mode}" + exit 1 + fi + } main "$@" diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py index 02329bd9e15..51be2900ba7 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py @@ -53,7 +53,7 @@ type=str, help="tasks for accuracy validation") parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model") # ============WeightOnly configs=============== -parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'AWQ', 'TEQ', 'GPTQ'], +parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'AWQ', 'TEQ', 'GPTQ', 'AutoRound', 'AutoTune'], help="Weight-only parameter.") parser.add_argument("--woq_bits", type=int, default=8) parser.add_argument("--woq_dtype", type=str, default="int") @@ -62,6 +62,7 @@ parser.add_argument("--woq_scheme", default="sym") parser.add_argument("--woq_use_mse_search", action="store_true") parser.add_argument("--woq_use_full_range", action="store_true") +parser.add_argument("--quant_lm_head", action="store_true", help="whether to quant the lm_head layer in transformers") # =============GPTQ configs==================== parser.add_argument("--gptq_actorder", action="store_true", 
help="Whether to apply the activation order GPTQ heuristic.") @@ -78,6 +79,35 @@ help='Calibration dataset sequence max length, ' 'this should align with your model config, ' 'and your dataset builder args: args.pad_max_length') +# =============AWQ configs==================== +parser.add_argument("--use_auto_scale", action="store_true", + help="Enables best scales search based on activation distribution.") +parser.add_argument("--use_auto_clip", action="store_true", + help="Enables clip range searchc.") +parser.add_argument("--folding", action="store_true", + help="Allow insert mul before linear when the scale cannot be absorbed by last layer for TEQ/AWQ.") +parser.add_argument('--absorb_layer_dict', type=dict, default={}, + help="The layer dict that scale can be absorbed for TEQ/AWQ.") +# ============AUTOROUND configs============== +parser.add_argument( + "--lr", + type=float, + default=None, + help="learning rate, if None, it will be set to 1.0/iters automatically", +) +parser.add_argument( + "--minmax_lr", + type=float, + default=None, + help="minmax learning rate, if None,it will beset to be the same with lr", +) +parser.add_argument("--autoround_iters", default=200, type=int, help="num iters for autoround calibration.") +parser.add_argument("--autoround_nsamples", default=128, type=int, help="num samples for autoround calibration.") +parser.add_argument( + "--disable_quanted_input", + action="store_true", + help="whether to use the output of quantized block to tune the next block", +) # =============DoubleQuant configs==================== parser.add_argument("--double_quant_type", @@ -196,6 +226,8 @@ def get_user_model(): ) tokenizer = AutoTokenizer.from_pretrained(args.model) user_model = user_model.float() + if args.woq_algo == 'AutoRound': + user_model.to(torch.float32) # Set model's seq_len when GPTQ calibration is enabled. if args.woq_algo == 'GPTQ': @@ -210,6 +242,31 @@ def get_user_model(): user_model.eval() return user_model, tokenizer +def eval_fn(user_model=None): + user_model.eval() + from neural_compressor.evaluation.lm_eval import evaluate, LMEvalParser + import time + + samples = args.iters * args.batch_size + eval_args = LMEvalParser( + model="hf", + user_model=user_model, + tokenizer=tokenizer, + batch_size=args.batch_size, + tasks=args.tasks, + limit=samples, + device="hpu" if is_hpex_available() else "cpu", + ) + start = time.time() + results = evaluate(eval_args) + end = time.time() + for task_name in args.tasks.split(","): + if task_name == "wikitext": + acc = results["results"][task_name]["word_perplexity,none"] + else: + acc = results["results"][task_name]["acc,none"] + print("Accuracy: %.5f" % acc) + return acc if args.quantize: # dataset @@ -224,9 +281,25 @@ def get_user_model(): shuffle=False, collate_fn=calib_evaluator.collate_batch, ) + def calib_func(prepared_model): + for i, calib_input in enumerate(calib_dataloader): + if i > args.calib_iters: + break + prepared_model(calib_input[0]) # 3.x api - from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, prepare, convert, quantize + from neural_compressor.torch.quantization import ( + RTNConfig, + GPTQConfig, + AWQConfig, + AutoRoundConfig, + TEQConfig, + TuningConfig, + autotune, + get_woq_tuning_config, + prepare, + convert + ) from neural_compressor.torch.utils import get_double_quant_config_dict weight_sym = True if args.woq_scheme == "sym" else False if args.double_quant_type is not None: @@ -239,6 +312,7 @@ def get_user_model(): # TODO: add group_dim into double quant config? 
"use_full_range": args.woq_use_full_range, "use_mse_search": args.woq_use_mse_search, + "quant_lm_head": args.quant_lm_head, } ) quant_config = RTNConfig.from_dict(double_quant_config_dict) @@ -256,8 +330,8 @@ def get_user_model(): double_quant_dtype=args.double_quant_dtype, double_quant_use_sym=args.double_quant_use_sym, double_quant_group_size=args.double_quant_group_size, + quant_lm_head=args.quant_lm_head, ) - quant_config.set_local("lm_head", RTNConfig(dtype="fp32")) user_model = prepare(model=user_model, quant_config=quant_config) user_model = convert(model=user_model) elif args.woq_algo == "GPTQ": @@ -288,6 +362,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args): "act_order": args.gptq_actorder, "block_size": args.gptq_block_size, "static_groups": args.gptq_static_groups, + "quant_lm_head": args.quant_lm_head, } ) quant_config = GPTQConfig.from_dict(double_quant_config_dict) @@ -307,11 +382,109 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args): double_quant_dtype=args.double_quant_dtype, double_quant_use_sym=args.double_quant_use_sym, double_quant_group_size=args.double_quant_group_size, + quant_lm_head=args.quant_lm_head, ) - quant_config.set_local("lm_head", GPTQConfig(dtype="fp32")) user_model = prepare(model=user_model, quant_config=quant_config) run_fn_for_gptq(user_model, dataloader_for_calibration) user_model = convert(user_model) + elif args.woq_algo == "AWQ": + quant_config = AWQConfig( + dtype=args.woq_dtype, + bits=args.woq_bits, + use_sym=weight_sym, + group_size=args.woq_group_size, + group_dim=args.woq_group_dim, + use_auto_scale=args.use_auto_scale, + use_auto_clip=args.use_auto_clip, + folding=args.folding, + absorb_layer_dict=args.absorb_layer_dict, + quant_lm_head=args.quant_lm_head, + ) + example_inputs = torch.ones([1, args.pad_max_length], dtype=torch.long) + run_fn = calib_func + user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(user_model) + user_model = convert(user_model) + elif args.woq_algo == "TEQ": + quant_config = TEQConfig( + dtype=args.woq_dtype, + bits=args.woq_bits, + use_sym=weight_sym, + group_size=args.woq_group_size, + group_dim=args.woq_group_dim, + folding=args.folding, + quant_lm_head=args.quant_lm_head, + ) + example_inputs = torch.ones([1, args.pad_max_length], dtype=torch.long) + run_fn = calib_func + user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(user_model) + user_model = convert(user_model) + elif args.woq_algo == "AutoRound": + quant_config = AutoRoundConfig( + dtype=args.woq_dtype, + bits=args.woq_bits, + use_sym=weight_sym, + group_size=args.woq_group_size, + enable_quanted_input=not args.disable_quanted_input, + lr=args.lr, + minmax_lr=args.minmax_lr, + seqlen=args.pad_max_length, + nsamples=args.autoround_nsamples, + iters=args.autoround_iters, + ) + quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) + from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader + dataloader = get_dataloader(tokenizer=tokenizer, + seqlen=args.pad_max_length, + dataset_name=datasets, + seed=args.seed, + bs=args.batch_size, + nsamples=args.autoround_nsamples) + @torch.no_grad() + def run_fn_for_autoround(model, dataloader): + for data in dataloader: + if isinstance(data, tuple) or isinstance(data, list): + model(*data) + elif isinstance(data, dict): + model(**data) + else: + model(data) + run_fn = run_fn_for_autoround + run_args = (dataloader,) + user_model = 
prepare(model=user_model, quant_config=quant_config) + run_fn(user_model, *run_args) + user_model = convert(user_model) + elif args.woq_algo == "AutoTune": + from utils import DataloaderPreprocessor + dataloaderPreprocessor = DataloaderPreprocessor( + dataloader_original=calib_dataloader, + use_max_length=args.gptq_use_max_length, + max_seq_length=args.gptq_max_seq_length, + ) + dataloader = dataloaderPreprocessor.get_prepared_dataloader() + custom_tune_config = TuningConfig(config_set=get_woq_tuning_config()) + from neural_compressor.torch.algorithms.weight_only.utility import move_input_to_device + from tqdm import tqdm + def run_fn_for_gptq(model, dataloader_for_calibration, *args): + for batch in tqdm(dataloader_for_calibration): + batch = move_input_to_device(batch, device=None) + if isinstance(batch, tuple) or isinstance(batch, list): + model(batch[0]) + elif isinstance(batch, dict): + model(**batch) + else: + model(batch) + return + example_inputs = torch.ones([1, args.pad_max_length], dtype=torch.long) + user_model = autotune( + model=user_model, + tune_config=custom_tune_config, + eval_fn=eval_fn, + run_fn=run_fn_for_gptq, + run_args=(dataloader, True), # run_args should be a tuple, + example_inputs=example_inputs, + ) user_model.save(args.output_dir) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_quant.sh b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_quant.sh index a860712b697..ed4ee705726 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_quant.sh +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_quant.sh @@ -85,6 +85,19 @@ function run_tuning { model_name_or_path="EleutherAI/gpt-j-6b" extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length" extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K" + elif [ "${topology}" = "gpt_j_woq_awq_int4" ]; then + model_name_or_path="EleutherAI/gpt-j-6b" + extra_cmd=$extra_cmd" --woq_algo AWQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --calib_iters 128" + extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K" + elif [ "${topology}" = "opt_125m_woq_awq_int4" ]; then + model_name_or_path="facebook/opt-125m" + extra_cmd=$extra_cmd" --woq_algo AWQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --calib_iters 128" + elif [ "${topology}" = "opt_125m_woq_autoround_int4" ]; then + model_name_or_path="facebook/opt-125m" + extra_cmd=$extra_cmd" --woq_algo AutoRound --woq_bits 4 --woq_group_size 128 --woq_scheme asym --autoround_iters 200 --autoround_nsamples 500" + elif [ "${topology}" = "opt_125m_woq_autotune_int4" ]; then + model_name_or_path="facebook/opt-125m" + extra_cmd=$extra_cmd" --woq_algo AutoTune --woq_bits 4" fi python -u run_clm_no_trainer.py \