From 137fa3add2d8a0688dd0e76bd15e347b588d56a8 Mon Sep 17 00:00:00 2001
From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com>
Date: Tue, 16 Apr 2024 10:30:04 +0800
Subject: [PATCH] Add llm examples to SmoothQuant 3.x API (#1685)

Signed-off-by: Cheng, Zixuan
---
 examples/.config/model_params_pytorch.json  | 14 +++++
 .../quantization/llm/README.md              | 33 ++++++++++-
 .../quantization/llm/run_clm_no_trainer.py  | 59 ++++++++++++++++++-
 .../quantization/llm/run_quant.sh           | 18 ++++++
 .../quantization/llm/utils.py               | 49 ++++++++++++++-
 .../quantization/llm/README.md              |  2 +-
 6 files changed, 169 insertions(+), 6 deletions(-)

diff --git a/examples/.config/model_params_pytorch.json b/examples/.config/model_params_pytorch.json
index 1532affd5c1..aae501364f7 100644
--- a/examples/.config/model_params_pytorch.json
+++ b/examples/.config/model_params_pytorch.json
@@ -520,6 +520,13 @@
       "batch_size": 1,
       "main_script": "run_clm_no_trainer.py"
     },
+    "llama2_7b_ipex":{
+      "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
+      "dataset_location": "",
+      "input_model": "",
+      "main_script": "run_clm_no_trainer.py",
+      "batch_size": 1
+    },
     "llama2_7b_ipex_sq":{
       "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
       "dataset_location": "",
       "input_model": "",
@@ -548,6 +555,13 @@
       "main_script": "run_clm_no_trainer.py",
       "batch_size": 1
     },
+    "gpt_j_ipex":{
+      "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
+      "dataset_location": "",
+      "input_model": "",
+      "main_script": "run_clm_no_trainer.py",
+      "batch_size": 1
+    },
     "gpt_j_ipex_sq":{
       "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
       "dataset_location": "",
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md
index 0c9ac1ff82f..1659ae41e75 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md
@@ -21,6 +21,18 @@ Here is how to run the scripts:
 ### GPT-J-6b

 #### Quantization
+```bash
+# "--sq" is used to enable smooth quant
+python run_clm_no_trainer.py \
+    --model EleutherAI/gpt-j-6B \
+    --quantize \
+    --sq \
+    --alpha 1.0 \
+    --ipex \
+    --output_dir "saved_results"
+```
+**Notes**: Smooth quantization here is based on torch.jit. Without past key values in example_inputs, the quantized model cannot be used for text generation.
+
 ```bash
 # "--approach weight_only" is used to enable weight only quantization.
 # "--woq_algo GPTQ" is used to enable GPTQ algorithms
@@ -62,6 +74,15 @@ python run_clm_no_trainer.py \
 #### Quantization

 ```bash
+# "--sq" is used to enable smooth quant
+python run_clm_no_trainer.py \
+    --model facebook/opt-125m \
+    --quantize \
+    --sq \
+    --alpha 0.5 \
+    --ipex \
+    --output_dir "saved_results"
+
 # "--approach weight_only" is used to enable weight only quantization.
 # "--woq_algo GPTQ" is used to enable GPTQ algorithms
 # "--double_quant_type BNB_NF4" is used to enable double quant algorithms
@@ -95,10 +116,20 @@ python run_clm_no_trainer.py \
     --double_quant_type "BNB_NF4"
 ```

-### LLAMA2-7b/13b/30b
+### LLAMA2-7b/13b/70b
+>Note: LLAMA requires IPEX >= 2.1 to get better accuracy.
 #### Quantization

 ```bash
+# "--sq" is used to enable smooth quant
+python run_clm_no_trainer.py \
+    --model meta-llama/Llama-2-7b-hf \
+    --quantize \
+    --sq \
+    --alpha 0.8 \
+    --ipex \
+    --output_dir "saved_results"
+
 # "--approach weight_only" is used to enable weight only quantization.
 # "--double_quant_type BNB_NF4" is used to enable double quant algorithms
 # "--woq_algo GPTQ" is used to enable GPTQ algorithms
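The README snippets added above only cover the quantization step. Validating the resulting SmoothQuant int8 model goes through the same script; below is a minimal sketch, assuming the model was saved to `saved_results` and reusing the `--int8`, `--ipex`, and `--accuracy` switches handled in `run_clm_no_trainer.py` further down in this patch.

```bash
# Sketch: evaluate an int8 model produced by one of the "--sq" commands above.
# "saved_results" is the --output_dir used during quantization; drop "--ipex" if IPEX was not used.
python run_clm_no_trainer.py \
    --model EleutherAI/gpt-j-6B \
    --int8 \
    --ipex \
    --accuracy \
    --output_dir "saved_results"
```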
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
index 484857ddd56..613c0277579 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
@@ -331,9 +331,62 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
             model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=(dataloader_for_calibration, )
         )
     else:
-        # TODO: smooth quant
-        print("Only support WeightOnlyQuant now")
+        if args.sq:
+            from neural_compressor.torch.quantization import SmoothQuantConfig, quantize
+
+            # alpha can be a float number or a list of float numbers.
+            args.alpha = args.alpha if args.alpha == "auto" else eval(args.alpha)
+            if re.search("falcon", user_model.config.model_type):
+                quant_config = SmoothQuantConfig(alpha=args.alpha, folding=False)
+            else:
+                quant_config = SmoothQuantConfig(alpha=args.alpha, folding=True)
+
+            if re.search("gpt", user_model.config.model_type):
+                quant_config.set_local("add", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        else:
+            from neural_compressor.torch.quantization import quantize, get_default_static_config, StaticQuantConfig
+
+            quant_config = get_default_static_config()
+            if re.search("gpt", user_model.config.model_type):
+                quant_config.set_local("add", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+
+        from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
+        from tqdm import tqdm
+        def run_fn(model):
+            for batch in tqdm(calib_dataloader):
+                batch = move_input_to_device(batch, device=None)
+                try:
+                    if isinstance(batch, tuple) or isinstance(batch, list):
+                        model(batch[0])
+                    elif isinstance(batch, dict):
+                        model(**batch)
+                    else:
+                        model(batch)
+                except ValueError:
+                    pass
+            return
+
+        from utils import get_example_inputs
+        example_inputs = get_example_inputs(user_model, calib_dataloader)
+        user_model = quantize(
+            model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
+        )
+        user_model.save(args.output_dir)
+
+if args.int8 or args.int8_bf16_mixed:
+    print("load int8 model")
+
+    from neural_compressor.torch.algorithms.static_quant import load
+
+    if args.ipex:
+        user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)))
+    else:
+        # TODO: WOQ save&load
+        print("Int8 model loading does not support WeightOnlyQuant now.")
         pass
+else:
+    user_model, _ = get_user_model()
+

 if args.accuracy:
     user_model.eval()
@@ -382,4 +435,4 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
         print("Accuracy: %.5f" % acc)
         print('Throughput: %.3f samples/sec' % (samples / (end - start)))
         print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
-        print('Batch size = %d' % args.batch_size)
+        print('Batch size = %d' % args.batch_size)
\ No newline at end of file
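Stripped of the example's argument plumbing, the SmoothQuant branch added above reduces to the flow sketched below. This is a minimal illustration rather than part of the patch: the model name, tokenizer, and single-batch calibration are placeholders, and a real run calibrates over a dataloader the way `run_fn` does above.

```python
# Minimal sketch of the SmoothQuant 3.x flow exercised in run_clm_no_trainer.py above.
# Placeholders: model name, tokenizer, and the one-sample calibration batch.
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import SmoothQuantConfig, quantize

model_name = "facebook/opt-125m"
# torchscript=True matches the torch.jit based path noted in the README above.
model = AutoModelForCausalLM.from_pretrained(model_name, torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# folding controls whether smoothing scales are folded into neighboring layers
# (the example above keeps folding=True except for falcon).
quant_config = SmoothQuantConfig(alpha=0.5, folding=True)

example_inputs = tokenizer("Intel Neural Compressor is", return_tensors="pt")["input_ids"]

def run_fn(model):
    # Calibration: feed representative data through the FP32 model.
    model(example_inputs)

q_model = quantize(
    model=model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
)
q_model.save("saved_results")
```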
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh
index 32d9fba51ef..05f6d15af32 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh
@@ -56,6 +56,12 @@ function run_tuning {
         approach="weight_only"
         extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length --gptq_percdamp 0.1 --gptq_actorder"
         extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
+    elif [ "${topology}" = "opt_125m_ipex" ]; then
+        model_name_or_path="facebook/opt-125m"
+        extra_cmd=$extra_cmd" --ipex"
+    elif [ "${topology}" = "opt_125m_ipex_sq" ]; then
+        model_name_or_path="facebook/opt-125m"
+        extra_cmd=$extra_cmd" --ipex --sq --alpha 0.5"
     elif [ "${topology}" = "llama2_7b_gptq_int4" ]; then
         model_name_or_path="meta-llama/Llama-2-7b-hf"
         approach="weight_only"
@@ -70,6 +76,12 @@ function run_tuning {
         approach="weight_only"
         extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length"
         extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
+    elif [ "${topology}" = "llama2_7b_ipex" ]; then
+        model_name_or_path="meta-llama/Llama-2-7b-hf"
+        extra_cmd=$extra_cmd" --ipex"
+    elif [ "${topology}" = "llama2_7b_ipex_sq" ]; then
+        model_name_or_path="meta-llama/Llama-2-7b-hf"
+        extra_cmd=$extra_cmd" --ipex --sq --alpha 0.8"
     elif [ "${topology}" = "gpt_j_woq_rtn_int4" ]; then
         model_name_or_path="EleutherAI/gpt-j-6b"
         approach="weight_only"
@@ -98,6 +110,12 @@ function run_tuning {
         approach="weight_only"
         extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length"
         extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
+    elif [ "${topology}" = "gpt_j_ipex" ]; then
+        model_name_or_path="EleutherAI/gpt-j-6b"
+        extra_cmd=$extra_cmd" --ipex"
+    elif [ "${topology}" = "gpt_j_ipex_sq" ]; then
+        model_name_or_path="EleutherAI/gpt-j-6b"
+        extra_cmd=$extra_cmd" --ipex --sq --alpha 1.0"
     fi

     python -u run_clm_no_trainer.py \
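With these branches in place, the new configurations can be driven through the runner script. A minimal sketch, assuming the script's existing `--topology=` interface; other arguments such as the dataset and output locations follow the script's current conventions and are omitted here.

```bash
# Sketch: run the SmoothQuant topology added above for GPT-J.
# The gpt_j_ipex_sq branch appends "--ipex --sq --alpha 1.0" to the
# run_clm_no_trainer.py call issued by run_tuning.
bash run_quant.sh --topology=gpt_j_ipex_sq
```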
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/utils.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/utils.py
index 79d6e0f90df..38083129a65 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/utils.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/utils.py
@@ -1,6 +1,9 @@
 import random
 import torch
+from collections import UserDict
+from packaging.version import Version
 from neural_compressor.common import logger
+from neural_compressor.torch.utils import get_torch_version

 class DataloaderPreprocessor:
     def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128) -> None:
@@ -143,4 +146,48 @@ def obtain_first_n_samples_fulllength(self, seed=0):
             logger.warning(
                 f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \
                     but only {len(self.dataloader)} samples are found. Please use smaller 'self.max_seq_length' value."
-            )
\ No newline at end of file
+            )
+
+
+def get_example_inputs(model, dataloader):
+    version = get_torch_version()
+    from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
+
+    # It is suggested to pass the calibration dataloader (calib_dataloader) here.
+    if dataloader is None:
+        return None
+    device = next(model.parameters()).device
+    try:
+        for idx, (input, label) in enumerate(dataloader):
+            input = move_input_to_device(input, device)
+            if isinstance(input, (dict, UserDict)):  # pragma: no cover
+                assert version.release >= Version("1.12.0").release, "INC supports IPEX version >= 1.12.0"
+                if "label" in input.keys():
+                    input.pop("label")
+                if version.release <= Version("2.0.1").release:
+                    return tuple(input.values())
+                else:
+                    return dict(input)
+            if isinstance(input, (list, tuple)):
+                return tuple(input)
+            if isinstance(input, torch.Tensor):
+                return input
+            break
+    except Exception as e:  # pragma: no cover
+        for idx, input in enumerate(dataloader):
+            input = move_input_to_device(input, device)
+            if isinstance(input, (dict, UserDict)):  # pragma: no cover
+                assert version.release >= Version("1.12.0").release, "INC supports IPEX version >= 1.12.0"
+                if "label" in input.keys():
+                    input.pop("label")
+                if version.release <= Version("2.0.1").release:
+                    return tuple(input.values())
+                else:
+                    return dict(input)
+            if isinstance(input, list) or isinstance(input, tuple):
+                return tuple(input)
+            if isinstance(input, torch.Tensor):
+                return input
+            break
+    if idx == 0:
+        assert False, "Please check the example_inputs format."
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md
index 97240bea87a..8ca1f48ee49 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md
@@ -128,7 +128,7 @@ python run_clm_no_trainer.py \
 # to validate int8 model generated with `--sq`, please remove "--approach weight_only".
 # to validate the int8 model quantized with ipex, please include "--ipex".
 ```
-### LLAMA2-7b/13b/30b
+### LLAMA2-7b/13b/70b
 >Note: LLAMA requires IPEX requirements >= 2.1 to get better accuracy.

 #### Quantization
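After quantization, the saved int8 model is restored through the same `load` helper that `run_clm_no_trainer.py` uses for the `--int8` path above. A minimal sketch, assuming the model was produced by a `--quantize --sq --ipex` run with `--output_dir "saved_results"`; the input tensor is a placeholder.

```python
# Sketch: restore the IPEX/SmoothQuant int8 model saved by run_clm_no_trainer.py.
# Assumes "saved_results" holds the output of a previous "--quantize --sq --ipex" run.
import os
import torch
from neural_compressor.torch.algorithms.static_quant import load

int8_model = load(os.path.abspath(os.path.expanduser("saved_results")))
int8_model.eval()

# The restored model is a jit-traced module; call it the same way run_fn fed the
# calibration data, i.e. with a batch of token ids (placeholder values below).
input_ids = torch.randint(0, 32000, (1, 32), dtype=torch.long)
with torch.no_grad():
    outputs = int8_model(input_ids)
```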