diff --git a/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh b/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh
index fba15ce6c4e..8489a218b79 100644
--- a/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh
+++ b/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh
@@ -21,7 +21,10 @@ rm -rf torch/quantization/fp8_quant
 
 LOG_DIR=/neural-compressor/log_dir
 mkdir -p ${LOG_DIR}
 ut_log_name=${LOG_DIR}/ut_3x_pt.log
-pytest --cov="${inc_path}" -vs --disable-warnings --html=report.html --self-contained-html . 2>&1 | tee -a ${ut_log_name}
+
+find . -name "test*.py" | sed "s,\.\/,python -m pytest --cov=\"${inc_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run.sh
+cat run.sh
+bash run.sh 2>&1 | tee ${ut_log_name}
 
 cp report.html ${LOG_DIR}/
diff --git a/.azure-pipelines/scripts/ut/env_setup.sh b/.azure-pipelines/scripts/ut/env_setup.sh
index 3715c485631..fadd60240da 100644
--- a/.azure-pipelines/scripts/ut/env_setup.sh
+++ b/.azure-pipelines/scripts/ut/env_setup.sh
@@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
 fi
 
 if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-    pip install git+https://github.com/intel/auto-round.git@e24b9074af6cdb099e31c92eb81b7f5e9a4a244e
+    pip install git+https://github.com/intel/auto-round.git@5dd16fc34a974a8c2f5a4288ce72e61ec3b1410f
 fi
 
 # test deps
diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py
index 9931a9e87b3..d806afca1fc 100644
--- a/neural_compressor/torch/algorithms/weight_only/autoround.py
+++ b/neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -61,6 +61,7 @@ def __init__(
         act_sym: bool = None,
         act_dynamic: bool = True,
         low_cpu_mem_usage: bool = False,
+        export_format: str = "itrex",
         **kwargs,
     ):
         """Init a AutQRoundQuantizer object.
@@ -152,6 +153,7 @@ def __init__(
         self.act_sym = act_sym
         self.act_dynamic = act_dynamic
         self.low_cpu_mem_usage = low_cpu_mem_usage
+        self.export_format = export_format
 
     def prepare(self, model: torch.nn.Module, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -211,7 +213,11 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         )
         model, weight_config = rounder.quantize()
         model.autoround_config = weight_config
-        model = pack_model(model, weight_config, device=self.device, inplace=True)
+        if "itrex" in self.export_format:
+            model = pack_model(model, weight_config, device=self.device, inplace=True)
+        else:  # pragma: no cover
+            model = rounder.save_quantized(output_dir=None, format=self.export_format, device=self.device, inplace=True)
+
         return model
diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py
index feb4b907b7e..8d1259cad00 100644
--- a/neural_compressor/torch/algorithms/weight_only/save_load.py
+++ b/neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -40,14 +40,32 @@
 device_woqlinear_mapping = {"cpu": INCWeightOnlyLinear, "hpu": HPUWeightOnlyLinear}
 
 
-def save(model, output_dir="./saved_results"):
+def save(model, output_dir="./saved_results", format=LoadFormat.DEFAULT, **kwargs):
     """Save the quantized model and config to the output path.
 
     Args:
         model (torch.nn.module): raw fp32 model or prepared model.
         output_dir (str, optional): output path to save.
+        format (str, optional): The format in which to save the model. Options include "default" and "huggingface". Defaults to "default".
+        kwargs: Additional arguments for specific formats. For example:
+            - safe_serialization (bool): Whether to use safe serialization when saving (only applicable for 'huggingface' format). Defaults to True.
+            - tokenizer (Tokenizer, optional): The tokenizer to be saved along with the model (only applicable for 'huggingface' format).
+            - max_shard_size (str, optional): The maximum size for each shard (only applicable for 'huggingface' format). Defaults to "5GB".
     """
     os.makedirs(output_dir, exist_ok=True)
+    if format == LoadFormat.HUGGINGFACE:  # pragma: no cover
+        config = model.config
+        quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
+        if "backend" in quantization_config and "auto_round" in quantization_config["backend"]:
+            safe_serialization = kwargs.get("safe_serialization", True)
+            tokenizer = kwargs.get("tokenizer", None)
+            max_shard_size = kwargs.get("max_shard_size", "5GB")
+            if tokenizer is not None:
+                tokenizer.save_pretrained(output_dir)
+            del model.save
+            model.save_pretrained(output_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
+            return
+
     qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
     # saving process
@@ -203,8 +221,15 @@ def load_hf_format_woq_model(self):
 
         # get model class and config
         model_class, config = self._get_model_class_and_config()
-        self.quantization_config = config.quantization_config
-
+        self.quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
+        if (
+            "backend" in self.quantization_config and "auto_round" in self.quantization_config["backend"]
+        ):  # pragma: no cover
+            # load autoround format quantized model
+            from auto_round import AutoRoundConfig
+
+            model = model_class.from_pretrained(self.model_name_or_path)
+            return model
         # get loaded state_dict
         self.loaded_state_dict = self._get_loaded_state_dict(config)
         self.loaded_state_dict_keys = list(set(self.loaded_state_dict.keys()))
@@ -400,7 +425,7 @@ def _get_model_class_and_config(self):
         trust_remote_code = self.kwargs.pop("trust_remote_code", None)
         kwarg_attn_imp = self.kwargs.pop("attn_implementation", None)
 
-        config = AutoConfig.from_pretrained(self.model_name_or_path)
+        config = AutoConfig.from_pretrained(self.model_name_or_path, trust_remote_code=trust_remote_code)
         # quantization_config = config.quantization_config
 
         if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:  # pragma: no cover
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 3a009d1aa65..1ce289921c8 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -609,6 +609,7 @@ def autoround_quantize_entry(
     scale_dtype = quant_config.scale_dtype
     quant_block_list = quant_config.quant_block_list
     low_cpu_mem_usage = quant_config.use_layer_wise
+    export_format = quant_config.export_format
 
     kwargs.pop("example_inputs")
 
@@ -636,6 +637,7 @@ def autoround_quantize_entry(
         scale_dtype=scale_dtype,
         quant_block_list=quant_block_list,
         low_cpu_mem_usage=low_cpu_mem_usage,
+        export_format=export_format,
     )
     model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
     model.qconfig = configs_mapping
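
Note: the snippet below is a minimal, illustrative usage sketch (not part of the patch) of how the new export_format option travels from AutoRoundConfig through autoround_quantize_entry into AutoRoundQuantizer via the public prepare/convert flow. The checkpoint name and the one-shot calibration call are borrowed from the unit test further down and stand in for a real calibration dataloader.

    import torch
    import transformers
    from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

    # Any causal LM works; this tiny GPT-J checkpoint is the one used in the unit test.
    fp32_model = transformers.AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-GPTJForCausalLM"
    )

    # export_format defaults to "itrex" (the existing pack_model path); any other value,
    # e.g. "auto_round:gptq", is forwarded to rounder.save_quantized() inside convert().
    quant_config = AutoRoundConfig(
        nsamples=32, seqlen=10, iters=10, scale_dtype="fp32", export_format="auto_round:gptq"
    )

    model = prepare(model=fp32_model, quant_config=quant_config)
    # Minimal calibration step; a real run would loop over a calibration dataloader.
    model(torch.ones([1, 10], dtype=torch.long))
    q_model = convert(model)
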
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index c7b19683882..cb3b1758529 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -939,6 +939,7 @@ def __init__(
         scale_dtype: str = "fp16",
         use_layer_wise: bool = False,
         quant_block_list: list = None,
+        export_format: str = "itrex",
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init AUTOROUND weight-only quantization config.
@@ -973,6 +974,7 @@ def __init__(
                 have different choices.
             use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
             quant_block_list (list): A list whose elements are list of block's layer names to be quantized.
+            export_format (str, optional): The format used for exporting the quantized model. Defaults to "itrex".
             white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types.
                 Default is DEFAULT_WHITE_LIST.
         """
@@ -1005,6 +1007,7 @@ def __init__(
         self.scale_dtype = scale_dtype
         self.use_layer_wise = use_layer_wise
         self.quant_block_list = quant_block_list
+        self.export_format = export_format
         self._post_init()
 
     @classmethod
diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py
index 88cae7e9384..8a3942e3f98 100644
--- a/test/3x/torch/quantization/weight_only/test_autoround.py
+++ b/test/3x/torch/quantization/weight_only/test_autoround.py
@@ -40,6 +40,7 @@ def run_fn(model, dataloader):
 
 @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
 class TestAutoRound:
+    @classmethod
     def setup_class(self):
         self.gptj = transformers.AutoModelForCausalLM.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM",
@@ -52,6 +53,7 @@ def setup_class(self):
         self.dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10)
         self.label = self.gptj(self.inp)[0]
 
+    @classmethod
     def teardown_class(self):
         shutil.rmtree("saved_results", ignore_errors=True)
 
@@ -159,3 +161,18 @@ def test_conv1d(self):
         out2 = q_model(**encoded_input)[0]
         assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
         assert isinstance(q_model.h[0].attn.c_attn, WeightOnlyLinear), "loading compressed model failed."
+
+    # def test_autoround_format_export(self):
+    #     from neural_compressor.torch.quantization import load
+    #     from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear
+    #     gpt_j_model = copy.deepcopy(self.gptj)
+    #     quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32", export_format="auto_round:gptq")
+    #     logger.info(f"Test AutoRound with config {quant_config}")
+    #     model = prepare(model=gpt_j_model, quant_config=quant_config)
+    #     run_fn(model, self.dataloader)
+    #     q_model = convert(model)
+    #     out = q_model(self.inp)[0]
+    #     assert torch.allclose(out, self.label, atol=1e-1)
+    #     assert isinstance(q_model.transformer.h[0].attn.k_proj, QuantLinear), "packing model failed."
+    #     q_model.save(output_dir="saved_results_tiny-random-GPTJForCausalLM", format="huggingface")
+    #     loaded_model = load("saved_results_tiny-random-GPTJForCausalLM", format="huggingface", trust_remote_code=True)
diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt
index c17e22d6f77..d2167904cac 100644
--- a/test/3x/torch/requirements.txt
+++ b/test/3x/torch/requirements.txt
@@ -1,4 +1,4 @@
-auto_round @ git+https://github.com/intel/auto-round.git@e24b9074af6cdb099e31c92eb81b7f5e9a4a244e
+auto_round @ git+https://github.com/intel/auto-round.git@5dd16fc34a974a8c2f5a4288ce72e61ec3b1410f
 expecttest
 intel_extension_for_pytorch
 numpy
diff --git a/test/requirements.txt b/test/requirements.txt
index 1999f21e668..4d2908986dd 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -1,6 +1,6 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
 accelerate==0.21.0
-auto-round @ git+https://github.com/intel/auto-round.git@e24b9074af6cdb099e31c92eb81b7f5e9a4a244e
+auto-round @ git+https://github.com/intel/auto-round.git@5dd16fc34a974a8c2f5a4288ce72e61ec3b1410f
 dynast==1.6.0rc1
 horovod
 intel-extension-for-pytorch
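
Note: a hedged companion sketch of the save/load round trip that the save_load.py changes (and the commented-out test above) target. It continues the q_model and the output directory name from the previous sketch; the max_shard_size and safe_serialization values shown are simply the documented defaults, and the tokenizer keyword is optional.

    from neural_compressor.torch.quantization import load

    # With format="huggingface", save() now detects an auto_round backend in
    # config.quantization_config and delegates to model.save_pretrained(),
    # optionally saving a tokenizer and honoring max_shard_size/safe_serialization.
    q_model.save(
        output_dir="saved_results_tiny-random-GPTJForCausalLM",
        format="huggingface",
        max_shard_size="5GB",       # default
        safe_serialization=True,    # default
    )

    # load() with format="huggingface" takes the auto_round branch added in
    # load_hf_format_woq_model() and reloads via model_class.from_pretrained().
    loaded_model = load(
        "saved_results_tiny-random-GPTJForCausalLM",
        format="huggingface",
        trust_remote_code=True,
    )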