Add llm examples to SmoothQuant 3.x API #1685

Merged · 22 commits · Apr 16, 2024
@@ -21,6 +21,18 @@ Here is how to run the scripts:
### GPT-J-6b

#### Quantization
```bash
# "--sq" is used to enable smooth quant
python run_clm_no_trainer.py \
--model EleutherAI/gpt-j-6B \
--quantize \
--sq \
--alpha 1.0 \
--ipex \
--output_dir "saved_results"
```
**Note**: Smooth quantization here is based on torch.jit. Without past key values in `example_inputs`, the quantized model cannot be used for text generation.
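
For illustration only, here is a minimal sketch of `example_inputs` that include `past_key_values`; the shapes and tensor layout are assumptions for a GPT-J-style model, not part of this PR:

```python
import torch

# Hypothetical example_inputs for GPT-J-6B (28 layers, 16 heads, head_dim 256 assumed).
# Zero-length past_key_values keep the traced graph usable for incremental decoding.
num_layers, num_heads, head_dim = 28, 16, 256
past_key_values = tuple(
    (torch.zeros(1, num_heads, 0, head_dim), torch.zeros(1, num_heads, 0, head_dim))
    for _ in range(num_layers)
)
example_inputs = {
    "input_ids": torch.ones(1, 32, dtype=torch.long),
    "attention_mask": torch.ones(1, 32, dtype=torch.long),
    "past_key_values": past_key_values,
}
```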

```bash
# "--approach weight_only" is used to enable weight only quantization.
# "--woq_algo GPTQ" is used to enable GPTQ algorithms
@@ -62,6 +74,15 @@ python run_clm_no_trainer.py \
#### Quantization

```bash
# "--sq" is used to enable smooth quant
python run_clm_no_trainer.py \
--model facebook/opt-125m \
--quantize \
--sq \
--alpha 0.5 \
--ipex \
--output_dir "saved_results"

# "--approach weight_only" is used to enable weight only quantization.
# "--woq_algo GPTQ" is used to enable GPTQ algorithms
# "--double_quant_type BNB_NF4" is used to enable double quant algorithms
@@ -96,9 +117,19 @@ python run_clm_no_trainer.py \
```
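
For reference, the `--sq` flag maps to the 3.x `SmoothQuantConfig` API added in this PR. Below is a minimal sketch, assuming `user_model`, `calib_dataloader`, and `example_inputs` are prepared as in `run_clm_no_trainer.py`; the alpha value is illustrative:

```python
from neural_compressor.torch.quantization import SmoothQuantConfig, quantize

# alpha mirrors the command-line value (e.g. 0.5 for opt-125m);
# folding=True matches the non-Falcon branch in run_clm_no_trainer.py.
quant_config = SmoothQuantConfig(alpha=0.5, folding=True)

def run_fn(model):
    # Simplified calibration loop; the script also handles tuple/list batches.
    for batch in calib_dataloader:
        model(**batch)

user_model = quantize(
    model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
)
user_model.save("saved_results")
```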

### LLAMA2-7b/13b/30b
>Note: LLAMA requires IPEX >= 2.1 to get better accuracy.
#### Quantization

```bash
# "--sq" is used to enable smooth quant
python run_clm_no_trainer.py \
--model meta-llama/Llama-2-7b-hf \
--quantize \
--sq \
--alpha 0.8 \
--ipex \
--output_dir "saved_results"

# "--approach weight_only" is used to enable weight only quantization.
# "--double_quant_type BNB_NF4" is used to enable double quant algorithms
# "--woq_algo GPTQ" is used to enable GPTQ algorithms
@@ -331,9 +331,60 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=(dataloader_for_calibration, )
)
else:
# TODO: smooth quant
print("Only support WeightOnlyQuant now")
if args.sq:
from neural_compressor.torch.quantization import SmoothQuantConfig, quantize

# alpha can be a float number or a list of float numbers.
args.alpha = args.alpha if args.alpha == "auto" else eval(args.alpha)
if re.search("falcon", user_model.config.model_type):
quant_config = SmoothQuantConfig(alpha=args.alpha, folding=False)
else:
quant_config = SmoothQuantConfig(alpha=args.alpha, folding=True)
else:
from neural_compressor.torch.quantization import quantize, get_default_static_config

quant_config = get_default_static_config()

from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
from tqdm import tqdm
def run_fn(model):
for batch in tqdm(calib_dataloader):
batch = move_input_to_device(batch, device=None)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
model(batch[0])
elif isinstance(batch, dict):
model(**batch)
else:
model(batch)
except ValueError:
pass
return

if re.search("gpt", user_model.config.model_type):
quant_config.set_local("add", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
from utils import get_example_inputs
example_inputs = get_example_inputs(user_model, calib_dataloader)
user_model = quantize(
model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
)

user_model.save(args.output_dir)

if args.int8 or args.int8_bf16_mixed:
print("load int8 model")

from neural_compressor.torch.algorithms.static_quant import load

if args.ipex:
user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)))
else:
# TODO: WOQ save&load
print("Int8 model loading does not support WeightOnlyQuant now.")
pass
else:
user_model, _ = get_user_model()


if args.accuracy:
user_model.eval()
Expand Down Expand Up @@ -382,4 +433,4 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
print("Accuracy: %.5f" % acc)
print('Throughput: %.3f samples/sec' % (samples / (end - start)))
print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
print('Batch size = %d' % args.batch_size)
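
After quantization, the saved jit model can be reloaded through the same static-quant `load` helper used above for the `--int8` path. A minimal sketch, assuming the default `saved_results` output directory:

```python
import os

from neural_compressor.torch.algorithms.static_quant import load

# Reload the jit-based int8 model produced by user_model.save(args.output_dir).
int8_model = load(os.path.abspath(os.path.expanduser("saved_results")))
int8_model.eval()
```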
@@ -56,6 +56,12 @@ function run_tuning {
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length --gptq_percdamp 0.1 --gptq_actorder"
extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
elif [ "${topology}" = "opt_125m_ipex" ]; then
model_name_or_path="facebook/opt-125m"
extra_cmd=$extra_cmd" --ipex"
elif [ "${topology}" = "opt_125m_sq" ]; then
model_name_or_path="facebook/opt-125m"
extra_cmd=$extra_cmd" --ipex --sq --alpha 0.5"
elif [ "${topology}" = "llama2_7b_gptq_int4" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
approach="weight_only"
@@ -70,6 +76,12 @@ function run_tuning {
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length"
extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
elif [ "${topology}" = "llama2_7b_ipex" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
extra_cmd=$extra_cmd" --ipex"
elif [ "${topology}" = "llama2_7b_sq" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
extra_cmd=$extra_cmd" --ipex --sq --alpha 0.8"
elif [ "${topology}" = "gpt_j_woq_rtn_int4" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
approach="weight_only"
@@ -98,6 +110,12 @@ function run_tuning {
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length"
extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
elif [ "${topology}" = "gpt_j_ipex" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
extra_cmd=$extra_cmd" --ipex"
elif [ "${topology}" = "gpt_j_sq" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
extra_cmd=$extra_cmd" --ipex --sq --alpha 1.0"
fi

python -u run_clm_no_trainer.py \
@@ -1,6 +1,9 @@
import random
import torch
from collections import UserDict
from packaging.version import Version
from neural_compressor.common import logger
from neural_compressor.torch.utils import get_torch_version

class DataloaderPreprocessor:
def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128) -> None:
@@ -143,4 +146,48 @@ def obtain_first_n_samples_fulllength(self, seed=0):
logger.warning(
f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \
but only {len(self.dataloader)} samples are found. Please use smaller 'self.max_seq_length' value."
)


def get_example_inputs(model, dataloader):
version = get_torch_version()
from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device

# Suggest passing a dataloader such as calib_dataloader
if dataloader is None:
return None
device = next(model.parameters()).device
try:
for idx, (input, label) in enumerate(dataloader):
input = move_input_to_device(input, device)
if isinstance(input, (dict, UserDict)): # pragma: no cover
assert version.release >= Version("1.12.0").release, "INC support IPEX version >= 1.12.0"
if "label" in input.keys():
input.pop("label")
if version.release <= Version("2.0.1").release:
return tuple(input.values())
else:
return dict(input)
if isinstance(input, (list, tuple)):
return tuple(input)
if isinstance(input, torch.Tensor):
return input
break
except Exception as e: # pragma: no cover
for idx, input in enumerate(dataloader):
input = move_input_to_device(input, device)
if isinstance(input, (dict, UserDict)): # pragma: no cover
assert version.release >= Version("1.12.0").release, "INC support IPEX version >= 1.12.0"
if "label" in input.keys():
input.pop("label")
if version.release <= Version("2.0.1").release:
return tuple(input.values())
else:
return dict(input)
if isinstance(input, (list, tuple)):
return tuple(input)
if isinstance(input, torch.Tensor):
return input
break
if idx == 0:
assert False, "Please check the example_inputs format."
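
Usage mirrors the call in `run_clm_no_trainer.py`; a minimal sketch, assuming `user_model` and `calib_dataloader` are already built as in that script:

```python
from utils import get_example_inputs

# Derive example_inputs for quantization from the calibration dataloader.
example_inputs = get_example_inputs(user_model, calib_dataloader)
```
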
7 changes: 5 additions & 2 deletions neural_compressor/torch/quantization/config.py
@@ -16,6 +16,7 @@
# limitations under the License.
# pylint:disable=import-error

import copy
from collections import OrderedDict
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

@@ -818,7 +819,8 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively

model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
ori_model = copy.deepcopy(model)
model_info, _, _, _ = get_quantizable_ops_recursively(ori_model, example_inputs=example_inputs)
return model_info

@classmethod
@@ -923,7 +925,8 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively

model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
ori_model = copy.deepcopy(model)
model_info, _, _, _ = get_quantizable_ops_recursively(ori_model, example_inputs=example_inputs)
return model_info

@classmethod