Add llm examples to SmoothQuant 3.x API (#1685)
Signed-off-by: Cheng, Zixuan <[email protected]>
violetch24 authored Apr 16, 2024
1 parent 3bb284c commit 137fa3a
Showing 6 changed files with 169 additions and 6 deletions.
14 changes: 14 additions & 0 deletions examples/.config/model_params_pytorch.json
@@ -520,6 +520,13 @@
"batch_size": 1,
"main_script": "run_clm_no_trainer.py"
},
"llama2_7b_ipex":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
"input_model": "",
"main_script": "run_clm_no_trainer.py",
"batch_size": 1
},
"llama2_7b_ipex_sq":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
@@ -548,6 +555,13 @@
"main_script": "run_clm_no_trainer.py",
"batch_size": 1
},
"gpt_j_ipex":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
"input_model": "",
"main_script": "run_clm_no_trainer.py",
"batch_size": 1
},
"gpt_j_ipex_sq":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
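Each new `*_ipex` topology entry above points the test harness at the shared LLM example directory and its `run_clm_no_trainer.py` entry script; the `run_tuning` shell function further below maps the same topology names to concrete models and quantization flags.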
@@ -21,6 +21,18 @@ Here is how to run the scripts:
### GPT-J-6b

#### Quantization
```bash
# "--sq" is used to enable smooth quant
python run_clm_no_trainer.py \
--model EleutherAI/gpt-j-6B \
--quantize \
--sq \
--alpha 1.0 \
--ipex \
--output_dir "saved_results"
```
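For reference, SmoothQuant's `--alpha` balances how much quantization difficulty is migrated from activations to weights: each input channel is scaled by roughly `s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)`, so larger values smooth activations more aggressively. That is why these examples pick different values per model (1.0 for GPT-J, 0.5 for OPT-125m, 0.8 for LLAMA2-7b).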
**Note**: Smooth quantization here is based on `torch.jit`. Without past key values in `example_inputs`, the quantized model cannot be used for text generation.
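For illustration, a minimal sketch of `example_inputs` that includes past key values so the traced model can still generate; the layer/head counts and shapes below are assumptions for GPT-J-6B, not values taken from this example:

```python
import torch

num_layers, num_heads, head_dim = 28, 16, 256  # assumed GPT-J-6B geometry
past_key_values = tuple(
    (torch.zeros(1, num_heads, 1, head_dim),   # cached keys
     torch.zeros(1, num_heads, 1, head_dim))   # cached values
    for _ in range(num_layers)
)
example_inputs = {
    "input_ids": torch.ones(1, 32, dtype=torch.long),
    "attention_mask": torch.ones(1, 33, dtype=torch.long),  # 1 cached + 32 new tokens
    "past_key_values": past_key_values,
}
```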

```bash
# "--approach weight_only" is used to enable weight only quantization.
# "--woq_algo GPTQ" is used to enable GPTQ algorithms
@@ -62,6 +74,15 @@ python run_clm_no_trainer.py \
#### Quantization

```bash
# "--sq" is used to enable smooth quant
python run_clm_no_trainer.py \
--model facebook/opt-125m \
--quantize \
--sq \
--alpha 0.5 \
--ipex \
--output_dir "saved_results"

# "--approach weight_only" is used to enable weight only quantization.
# "--woq_algo GPTQ" is used to enable GPTQ algorithms
# "--double_quant_type BNB_NF4" is used to enable double quant algorithms
@@ -95,10 +116,20 @@ python run_clm_no_trainer.py \
--double_quant_type "BNB_NF4"
```
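Besides a fixed float, the updated `run_clm_no_trainer.py` also accepts `--alpha auto` (see the `args.alpha == "auto"` branch in the Python diff below), which lets the algorithm determine alpha instead of a hand-picked value. A hypothetical invocation:

```bash
# hypothetical: auto-tuned alpha instead of a fixed 0.5
python run_clm_no_trainer.py \
    --model facebook/opt-125m \
    --quantize \
    --sq \
    --alpha auto \
    --ipex \
    --output_dir "saved_results"
```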

### LLAMA2-7b/13b/70b
>Note: LLAMA requires IPEX >= 2.1 for better accuracy.
#### Quantization

```bash
# "--sq" is used to enable smooth quant
python run_clm_no_trainer.py \
--model meta-llama/Llama-2-7b-hf \
--quantize \
--sq \
--alpha 0.8 \
--ipex \
--output_dir "saved_results"

# "--approach weight_only" is used to enable weight only quantization.
# "--double_quant_type BNB_NF4" is used to enable double quant algorithms
# "--woq_algo GPTQ" is used to enable GPTQ algorithms
@@ -331,9 +331,62 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
            model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=(dataloader_for_calibration, )
        )
    else:
        if args.sq:
            from neural_compressor.torch.quantization import SmoothQuantConfig, quantize

            # alpha can be a float number or a list of float numbers.
            args.alpha = args.alpha if args.alpha == "auto" else eval(args.alpha)
            # folding=True fuses the smoothing mul into the previous layer where supported.
            if re.search("falcon", user_model.config.model_type):
                quant_config = SmoothQuantConfig(alpha=args.alpha, folding=False)
            else:
                quant_config = SmoothQuantConfig(alpha=args.alpha, folding=True)

            # Keep the elementwise "add" in fp32 for GPT-style models.
            if re.search("gpt", user_model.config.model_type):
                quant_config.set_local("add", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
        else:
            from neural_compressor.torch.quantization import quantize, get_default_static_config, StaticQuantConfig

            quant_config = get_default_static_config()
            if re.search("gpt", user_model.config.model_type):
                quant_config.set_local("add", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

        from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
        from tqdm import tqdm

        # Calibration loop: run the calibration batches through the model once.
        def run_fn(model):
            for batch in tqdm(calib_dataloader):
                batch = move_input_to_device(batch, device=None)
                try:
                    if isinstance(batch, (tuple, list)):
                        model(batch[0])
                    elif isinstance(batch, dict):
                        model(**batch)
                    else:
                        model(batch)
                except ValueError:
                    pass
            return

        from utils import get_example_inputs
        example_inputs = get_example_inputs(user_model, calib_dataloader)
        user_model = quantize(
            model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
        )
    user_model.save(args.output_dir)

if args.int8 or args.int8_bf16_mixed:
    print("load int8 model")

    from neural_compressor.torch.algorithms.static_quant import load

    if args.ipex:
        user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)))
    else:
        # TODO: WOQ save&load
        print("Int8 model loading does not support WeightOnlyQuant now.")
        pass
else:
    user_model, _ = get_user_model()


if args.accuracy:
    user_model.eval()
@@ -382,4 +435,4 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
print("Accuracy: %.5f" % acc)
print('Throughput: %.3f samples/sec' % (samples / (end - start)))
print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
print('Batch size = %d' % args.batch_size)
print('Batch size = %d' % args.batch_size)
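A hypothetical reload sketch for the model saved above via `user_model.save(args.output_dir)`; it mirrors the `--int8 --ipex` branch, with the output path assumed to be the default `saved_results`:

```python
import os

from neural_compressor.torch.algorithms.static_quant import load

# Reload the IPEX static/smooth-quant model saved by run_clm_no_trainer.py.
qmodel = load(os.path.abspath(os.path.expanduser("saved_results")))
```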
@@ -56,6 +56,12 @@ function run_tuning {
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length --gptq_percdamp 0.1 --gptq_actorder"
extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
elif [ "${topology}" = "opt_125m_ipex" ]; then
model_name_or_path="facebook/opt-125m"
extra_cmd=$extra_cmd" --ipex"
elif [ "${topology}" = "opt_125m_ipex_sq" ]; then
model_name_or_path="facebook/opt-125m"
extra_cmd=$extra_cmd" --ipex --sq --alpha 0.5"
elif [ "${topology}" = "llama2_7b_gptq_int4" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
approach="weight_only"
@@ -70,6 +76,12 @@
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length"
extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
elif [ "${topology}" = "llama2_7b_ipex" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
extra_cmd=$extra_cmd" --ipex"
elif [ "${topology}" = "llama2_7b_ipex_sq" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
extra_cmd=$extra_cmd" --ipex --sq --alpha 0.8"
elif [ "${topology}" = "gpt_j_woq_rtn_int4" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
approach="weight_only"
@@ -98,6 +110,12 @@
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length"
extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
elif [ "${topology}" = "gpt_j_ipex" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
extra_cmd=$extra_cmd" --ipex"
elif [ "${topology}" = "gpt_j_ipex_sq" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
extra_cmd=$extra_cmd" --ipex --sq --alpha 1.0"
fi

python -u run_clm_no_trainer.py \
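With these branches in place, a topology name is all the harness needs. A hypothetical invocation (the wrapper script name and flag spelling are assumptions based on the `run_tuning` function shown):

```bash
# hypothetical: quantize LLAMA2-7b with SmoothQuant (alpha 0.8) via the tuning harness
bash run_quant.sh --topology=llama2_7b_ipex_sq --input_model=meta-llama/Llama-2-7b-hf
```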
@@ -1,6 +1,9 @@
import random
import torch
from collections import UserDict
from packaging.version import Version
from neural_compressor.common import logger
from neural_compressor.torch.utils import get_torch_version

class DataloaderPreprocessor:
    def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128) -> None:
@@ -143,4 +146,48 @@ def obtain_first_n_samples_fulllength(self, seed=0):
            logger.warning(
                f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, "
                f"but only {len(self.dataloader)} samples are found. Please use a smaller 'self.max_seq_length' value."
            )


def get_example_inputs(model, dataloader):
    version = get_torch_version()
    from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device

    # A calibration dataloader (e.g. calib_dataloader) is the suggested input here.
    if dataloader is None:
        return None
    device = next(model.parameters()).device
    try:
        # Common case: the dataloader yields (input, label) pairs.
        for idx, (input, label) in enumerate(dataloader):
            input = move_input_to_device(input, device)
            if isinstance(input, (dict, UserDict)):  # pragma: no cover
                assert version.release >= Version("1.12.0").release, "INC supports IPEX version >= 1.12.0"
                if "label" in input.keys():
                    input.pop("label")
                if version.release <= Version("2.0.1").release:
                    # Older torch/IPEX traces with positional inputs only.
                    return tuple(input.values())
                else:
                    return dict(input)
            if isinstance(input, (list, tuple)):
                return tuple(input)
            if isinstance(input, torch.Tensor):
                return input
            break
    except Exception:  # pragma: no cover
        # Fallback: the dataloader yields bare batches without labels.
        for idx, input in enumerate(dataloader):
            input = move_input_to_device(input, device)
            if isinstance(input, (dict, UserDict)):  # pragma: no cover
                assert version.release >= Version("1.12.0").release, "INC supports IPEX version >= 1.12.0"
                if "label" in input.keys():
                    input.pop("label")
                if version.release <= Version("2.0.1").release:
                    return tuple(input.values())
                else:
                    return dict(input)
            if isinstance(input, (list, tuple)):
                return tuple(input)
            if isinstance(input, torch.Tensor):
                return input
            break
    if idx == 0:
        assert False, "Please check the example_inputs format."
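A hypothetical illustration of what `get_example_inputs` returns for a dict-style calibration batch (tensor names and shapes assumed):

```python
import torch

batch = {
    "input_ids": torch.ones(1, 32, dtype=torch.long),
    "attention_mask": torch.ones(1, 32, dtype=torch.long),
    "label": torch.zeros(1, dtype=torch.long),  # popped before returning
}
# torch <= 2.0.1: tuple(batch.values()) without "label" -- positional inputs for jit tracing
# torch >  2.0.1: the dict itself without "label"
```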
@@ -128,7 +128,7 @@ python run_clm_no_trainer.py \
# to validate int8 model generated with `--sq`, please remove "--approach weight_only".
# to validate the int8 model quantized with ipex, please include "--ipex".
```
### LLAMA2-7b/13b/70b
>Note: LLAMA requires IPEX >= 2.1 for better accuracy.
#### Quantization
