diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py
index ce13990c00f..207dc212dcf 100644
--- a/neural_compressor/torch/algorithms/weight_only/utility.py
+++ b/neural_compressor/torch/algorithms/weight_only/utility.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-
 import torch
 
 from neural_compressor.torch.utils import accelerator, device_synchronize, logger
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index cf429c2118f..733e4409b91 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -45,7 +45,14 @@
     StaticQuantConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo
+from neural_compressor.torch.utils import (
+    dump_model_op_stats,
+    get_quantizer,
+    is_ipex_imported,
+    logger,
+    postprocess_model,
+    register_algo,
+)
 from neural_compressor.torch.utils.constants import PT2E_DYNAMIC_QUANT, PT2E_STATIC_QUANT
 
 
@@ -89,6 +96,7 @@ def rtn_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model
 
@@ -141,6 +149,7 @@ def gptq_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model
 
@@ -361,6 +370,7 @@ def awq_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model
 
@@ -415,6 +425,7 @@ def teq_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model
 
@@ -491,6 +502,7 @@ def autoround_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model
 
@@ -511,6 +523,7 @@ def hqq_entry(
     quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
     model = quantizer.execute(model, mode=mode)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model
 
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index b7855d506e6..e1c869dca45 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -16,6 +16,7 @@
 from typing import Callable, Dict, List, Tuple, Union
 
 import torch
+from prettytable import PrettyTable
 from typing_extensions import TypeAlias
 
 from neural_compressor.common.utils import LazyImport, Mode, logger
@@ -163,3 +164,101 @@ def postprocess_model(model, mode, quantizer):
     elif mode == Mode.CONVERT or mode == Mode.QUANTIZE:
         if getattr(model, "quantizer", False):
             del model.quantizer
+
+
+class Statistics:  # pragma: no cover
+    """The statistics printer."""
+
+    def __init__(self, data, header, field_names, output_handle=logger.info):
+        """Init a Statistics object.
+
+        Args:
+            data: The statistics data
+            header: The table header
+            field_names: The field names
+            output_handle: The output logging method
+        """
+        self.field_names = field_names
+        self.header = header
+        self.data = data
+        self.output_handle = output_handle
+        self.tb = PrettyTable(min_table_width=40)
+
+    def print_stat(self):
+        """Print the statistics."""
+        # keep the first two columns; drop any dtype column whose entries are all zero
+        valid_field_names = []
+        for index, value in enumerate(self.field_names):
+            if index < 2:
+                valid_field_names.append(value)
+                continue
+
+            if any(i[index] for i in self.data):
+                valid_field_names.append(value)
+        self.tb.field_names = valid_field_names
+        for i in self.data:
+            tmp_data = []
+            for index, value in enumerate(i):
+                if self.field_names[index] in valid_field_names:
+                    tmp_data.append(value)
+            if any(tmp_data[1:]):
+                self.tb.add_row(tmp_data)
+        lines = self.tb.get_string().split("\n")
+        self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|")
+        for i in lines:
+            self.output_handle(i)
+
+
+def dump_model_op_stats(mode, tune_cfg):
+    """Dump the quantization statistics of the model's quantizable ops to the user.
+
+    Args:
+        mode (Mode): the quantization mode; stats are only dumped for convert/quantize
+        tune_cfg (dict): quantization config, mapping (op_name, op_type) to the op's config
+    Returns:
+        None
+    """
+    if mode == Mode.PREPARE:
+        return
+    res = {}
+    # collect all dtype info and build empty results with existing op_type
+    dtype_set = set()
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        if config["dtype"] != "fp32":
+            num_bits = config["bits"]
+            group_size = config["group_size"]
+            dtype_str = "A32W{}G{}".format(num_bits, group_size)
+            dtype_set.add(dtype_str)
+    dtype_set.add("FP32")
+    dtype_list = list(dtype_set)
+    dtype_list.sort()
+
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        op_type = op[1]
+        if op_type not in res.keys():
+            res[op_type] = {dtype: 0 for dtype in dtype_list}
+
+    # fill in results with op_type and dtype
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        op_type = op[1]
+        if config["dtype"] == "fp32":
+            res[op_type]["FP32"] += 1
+        else:
+            num_bits = config["bits"]
+            group_size = config["group_size"]
+            dtype_str = "A32W{}G{}".format(num_bits, group_size)
+            res[op_type][dtype_str] += 1
+
+    # update stats format for dump.
+    field_names = ["Op Type", "Total"]
+    field_names.extend(dtype_list)
+    output_data = []
+    for op_type in res.keys():
+        field_results = [op_type, sum(res[op_type].values())]
+        field_results.extend([res[op_type][dtype] for dtype in dtype_list])
+        output_data.append(field_results)
+
+    Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
diff --git a/requirements_pt.txt b/requirements_pt.txt
index a164be3d24f..94667b64665 100644
--- a/requirements_pt.txt
+++ b/requirements_pt.txt
@@ -1,5 +1,6 @@
 numpy < 2.0
 peft==0.10.0
+prettytable
 psutil
 py-cpuinfo
 pydantic
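
For reviewers, a minimal sketch of how the new helper is expected to behave. The _FakeOpConfig class below is a hypothetical stand-in for the real op configs (e.g. RTNConfig), assuming only that to_dict() exposes the "dtype", "bits", and "group_size" keys this patch reads; the import paths follow the ones used in the patch itself.

    from neural_compressor.common.utils import Mode
    from neural_compressor.torch.utils import dump_model_op_stats

    class _FakeOpConfig:
        """Hypothetical op config exposing the keys dump_model_op_stats reads."""

        def __init__(self, dtype="int", bits=4, group_size=32):
            self._d = {"dtype": dtype, "bits": bits, "group_size": group_size}

        def to_dict(self):
            return self._d

    # tune_cfg maps (op_name, op_type) pairs to per-op configs.
    tune_cfg = {
        ("model.fc1", "Linear"): _FakeOpConfig(bits=4, group_size=32),
        ("model.fc2", "Linear"): _FakeOpConfig(dtype="fp32"),
    }

    # PREPARE is a no-op; CONVERT/QUANTIZE log a table roughly like:
    # |********Mixed Precision Statistics*******|
    # +---------+-------+----------+------+
    # | Op Type | Total | A32W4G32 | FP32 |
    # +---------+-------+----------+------+
    # |  Linear |   2   |    1     |  1   |
    # +---------+-------+----------+------+
    dump_model_op_stats(Mode.CONVERT, tune_cfg)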