Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable auto_round format export #2002

Merged
merged 29 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
db16753
enable auto_round format export
WeiweiZhang1 Sep 12, 2024
1eceb6d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
26fe175
Update auto_round dependency to commit 5dd16fc34a974a8c2f5a4288ce72e6…
XuehaoSun Sep 12, 2024
2e67cd5
fix docscan issues
WeiweiZhang1 Sep 12, 2024
b99140c
Merge branch 'enable_autoround_format_quantization' of https://github…
WeiweiZhang1 Sep 12, 2024
a7d1431
fixtypos
WeiweiZhang1 Sep 12, 2024
8e78efc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
0adc4ef
fix self.quantization_config
Kaihui-intel Sep 12, 2024
73d8c2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
dc49120
Merge branch 'master' into enable_autoround_format_quantization
xin3he Sep 13, 2024
27b4f43
rm ar ut
Kaihui-intel Sep 13, 2024
46f3c76
fixtypos
WeiweiZhang1 Sep 13, 2024
28e4878
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
8bb25c9
Merge branch 'enable_autoround_format_quantization' of https://github…
Kaihui-intel Sep 13, 2024
c744130
revert ar ut
WeiweiZhang1 Sep 14, 2024
39d66e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
79f44f4
refine UT
WeiweiZhang1 Sep 14, 2024
16a296e
refine UT
WeiweiZhang1 Sep 14, 2024
91f7985
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
01136d7
fix unit test
XuehaoSun Sep 14, 2024
07ae762
against code coverage issue
WeiweiZhang1 Sep 14, 2024
d3c3f39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
461379a
fixtypo
WeiweiZhang1 Sep 14, 2024
7fbf186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
41bfca5
fixtypo
WeiweiZhang1 Sep 14, 2024
7a72f52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
f3bf7fb
fixtypo
WeiweiZhang1 Sep 14, 2024
a280b10
fixtypo
WeiweiZhang1 Sep 14, 2024
7f41ff0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .azure-pipelines/scripts/ut/3x/run_3x_pt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ rm -rf torch/quantization/fp8_quant
LOG_DIR=/neural-compressor/log_dir
mkdir -p ${LOG_DIR}
ut_log_name=${LOG_DIR}/ut_3x_pt.log
pytest --cov="${inc_path}" -vs --disable-warnings --html=report.html --self-contained-html . 2>&1 | tee -a ${ut_log_name}

find . -name "test*.py" | sed "s,\.\/,python -m pytest --cov=\"${inc_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run.sh
cat run.sh
bash run.sh 2>&1 | tee ${ut_log_name}

cp report.html ${LOG_DIR}/

Expand Down
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/ut/env_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
fi

if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
pip install git+https://github.com/intel/auto-round.git@e24b9074af6cdb099e31c92eb81b7f5e9a4a244e
pip install git+https://github.com/intel/auto-round.git@5dd16fc34a974a8c2f5a4288ce72e61ec3b1410f
fi

# test deps
Expand Down
8 changes: 7 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
act_sym: bool = None,
act_dynamic: bool = True,
low_cpu_mem_usage: bool = False,
export_format: str = "itrex",
**kwargs,
):
"""Init a AutQRoundQuantizer object.
Expand Down Expand Up @@ -152,6 +153,7 @@ def __init__(
self.act_sym = act_sym
self.act_dynamic = act_dynamic
self.low_cpu_mem_usage = low_cpu_mem_usage
self.export_format = export_format

def prepare(self, model: torch.nn.Module, *args, **kwargs):
"""Prepares a given model for quantization.
Expand Down Expand Up @@ -211,7 +213,11 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
)
model, weight_config = rounder.quantize()
model.autoround_config = weight_config
model = pack_model(model, weight_config, device=self.device, inplace=True)
if "itrex" in self.export_format:
model = pack_model(model, weight_config, device=self.device, inplace=True)
else: # pragma: no cover
model = rounder.save_quantized(output_dir=None, format=self.export_format, device=self.device, inplace=True)

return model


Expand Down
33 changes: 29 additions & 4 deletions neural_compressor/torch/algorithms/weight_only/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,32 @@
device_woqlinear_mapping = {"cpu": INCWeightOnlyLinear, "hpu": HPUWeightOnlyLinear}


def save(model, output_dir="./saved_results"):
def save(model, output_dir="./saved_results", format=LoadFormat.DEFAULT, **kwargs):
"""Save the quantized model and config to the output path.

Args:
model (torch.nn.module): raw fp32 model or prepared model.
output_dir (str, optional): output path to save.
format (str, optional): The format in which to save the model. Options include "default" and "huggingface". Defaults to "default".
kwargs: Additional arguments for specific formats. For example:
- safe_serialization (bool): Whether to use safe serialization when saving (only applicable for 'huggingface' format). Defaults to True.
- tokenizer (Tokenizer, optional): The tokenizer to be saved along with the model (only applicable for 'huggingface' format).
- max_shard_size (str, optional): The maximum size for each shard (only applicable for 'huggingface' format). Defaults to "5GB".
"""
os.makedirs(output_dir, exist_ok=True)
if format == LoadFormat.HUGGINGFACE: # pragma: no cover
config = model.config
quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
if "backend" in quantization_config and "auto_round" in quantization_config["backend"]:
safe_serialization = kwargs.get("safe_serialization", True)
tokenizer = kwargs.get("tokenizer", None)
max_shard_size = kwargs.get("max_shard_size", "5GB")
if tokenizer is not None:
tokenizer.save_pretrained(output_dir)
del model.save
model.save_pretrained(output_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
return

qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
# saving process
Expand Down Expand Up @@ -203,8 +221,15 @@ def load_hf_format_woq_model(self):

# get model class and config
model_class, config = self._get_model_class_and_config()
self.quantization_config = config.quantization_config

self.quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
if (
"backend" in self.quantization_config and "auto_round" in self.quantization_config["backend"]
): # # pragma: no cover
# load autoround format quantized model
from auto_round import AutoRoundConfig

model = model_class.from_pretrained(self.model_name_or_path)
return model
# get loaded state_dict
self.loaded_state_dict = self._get_loaded_state_dict(config)
self.loaded_state_dict_keys = list(set(self.loaded_state_dict.keys()))
Expand Down Expand Up @@ -400,7 +425,7 @@ def _get_model_class_and_config(self):
trust_remote_code = self.kwargs.pop("trust_remote_code", None)
kwarg_attn_imp = self.kwargs.pop("attn_implementation", None)

config = AutoConfig.from_pretrained(self.model_name_or_path)
config = AutoConfig.from_pretrained(self.model_name_or_path, trust_remote_code=trust_remote_code)
# quantization_config = config.quantization_config

if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: # pragma: no cover
Expand Down
2 changes: 2 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ def autoround_quantize_entry(
scale_dtype = quant_config.scale_dtype
quant_block_list = quant_config.quant_block_list
low_cpu_mem_usage = quant_config.use_layer_wise
export_format = quant_config.export_format

kwargs.pop("example_inputs")

Expand Down Expand Up @@ -636,6 +637,7 @@ def autoround_quantize_entry(
scale_dtype=scale_dtype,
quant_block_list=quant_block_list,
low_cpu_mem_usage=low_cpu_mem_usage,
export_format=export_format,
)
model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
model.qconfig = configs_mapping
Expand Down
3 changes: 3 additions & 0 deletions neural_compressor/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,7 @@ def __init__(
scale_dtype: str = "fp16",
use_layer_wise: bool = False,
quant_block_list: list = None,
export_format: str = "itrex",
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init AUTOROUND weight-only quantization config.
Expand Down Expand Up @@ -973,6 +974,7 @@ def __init__(
have different choices.
use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
quant_block_list (list): A list whose elements are list of block's layer names to be quantized.
export_format (str, optional): The format used for exporting the quantized model. Defaults to "itrex".
white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types.
Default is DEFAULT_WHITE_LIST.
"""
Expand Down Expand Up @@ -1005,6 +1007,7 @@ def __init__(
self.scale_dtype = scale_dtype
self.use_layer_wise = use_layer_wise
self.quant_block_list = quant_block_list
self.export_format = export_format
self._post_init()

@classmethod
Expand Down
17 changes: 17 additions & 0 deletions test/3x/torch/quantization/weight_only/test_autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def run_fn(model, dataloader):

@pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
class TestAutoRound:
@classmethod
def setup_class(self):
self.gptj = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
Expand All @@ -52,6 +53,7 @@ def setup_class(self):
self.dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10)
self.label = self.gptj(self.inp)[0]

@classmethod
def teardown_class(self):
shutil.rmtree("saved_results", ignore_errors=True)

Expand Down Expand Up @@ -159,3 +161,18 @@ def test_conv1d(self):
out2 = q_model(**encoded_input)[0]
assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
assert isinstance(q_model.h[0].attn.c_attn, WeightOnlyLinear), "loading compressed model failed."

# def test_autoround_format_export(self):
# from neural_compressor.torch.quantization import load
# from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear
# gpt_j_model = copy.deepcopy(self.gptj)
# quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32", export_format="auto_round:gptq")
# logger.info(f"Test AutoRound with config {quant_config}")
# model = prepare(model=gpt_j_model, quant_config=quant_config)
# run_fn(model, self.dataloader)
# q_model = convert(model)
# out = q_model(self.inp)[0]
# assert torch.allclose(out, self.label, atol=1e-1)
# assert isinstance(q_model.transformer.h[0].attn.k_proj, QuantLinear), "packing model failed."
# q_model.save(output_dir="saved_results_tiny-random-GPTJForCausalLM", format="huggingface")
# loaded_model = load("saved_results_tiny-random-GPTJForCausalLM", format="huggingface", trust_remote_code=True)
2 changes: 1 addition & 1 deletion test/3x/torch/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
auto_round @ git+https://github.com/intel/auto-round.git@e24b9074af6cdb099e31c92eb81b7f5e9a4a244e
auto_round @ git+https://github.com/intel/auto-round.git@5dd16fc34a974a8c2f5a4288ce72e61ec3b1410f
expecttest
intel_extension_for_pytorch
numpy
Expand Down
2 changes: 1 addition & 1 deletion test/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
--find-links https://download.pytorch.org/whl/torch_stable.html
accelerate==0.21.0
auto-round @ git+https://github.com/intel/auto-round.git@e24b9074af6cdb099e31c92eb81b7f5e9a4a244e
auto-round @ git+https://github.com/intel/auto-round.git@5dd16fc34a974a8c2f5a4288ce72e61ec3b1410f
dynast==1.6.0rc1
horovod
intel-extension-for-pytorch
Expand Down
Loading