
Commit

Add op statistics dump for woq (#1876)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel authored Jun 18, 2024
1 parent 5a0374e commit 503d9ef
Showing 4 changed files with 114 additions and 3 deletions.
2 changes: 0 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/utility.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-
 import torch
 
 from neural_compressor.torch.utils import accelerator, device_synchronize, logger
15 changes: 14 additions & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
@@ -45,7 +45,14 @@
     StaticQuantConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo
+from neural_compressor.torch.utils import (
+    dump_model_op_stats,
+    get_quantizer,
+    is_ipex_imported,
+    logger,
+    postprocess_model,
+    register_algo,
+)
 from neural_compressor.torch.utils.constants import PT2E_DYNAMIC_QUANT, PT2E_STATIC_QUANT


@@ -89,6 +96,7 @@ def rtn_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
     return model


@@ -141,6 +149,7 @@ def gptq_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model

@@ -361,6 +370,7 @@ def awq_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
     return model


@@ -415,6 +425,7 @@ def teq_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model

@@ -491,6 +502,7 @@ def autoround_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
     return model


@@ -511,6 +523,7 @@ def hqq_entry(
     quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
     model = quantizer.execute(model, mode=mode)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
 
     return model

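Every weight-only entry above (RTN, GPTQ, AWQ, TEQ, AutoRound, HQQ) now reports its op-level dtype statistics right after conversion. A minimal sketch of triggering the dump through the public API (the toy two-layer model below is an assumption for illustration; RTNConfig, prepare, and convert are existing neural_compressor.torch.quantization entry points):

    # Illustrative sketch, not part of this commit.
    import torch

    from neural_compressor.torch.quantization import RTNConfig, convert, prepare

    toy = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))  # assumed toy model
    toy = prepare(toy, RTNConfig(bits=4, group_size=32))
    toy = convert(toy)  # rtn_entry runs and dump_model_op_stats logs the mixed-precision table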
99 changes: 99 additions & 0 deletions neural_compressor/torch/utils/utility.py
@@ -16,6 +16,7 @@
 from typing import Callable, Dict, List, Tuple, Union
 
 import torch
+from prettytable import PrettyTable
 from typing_extensions import TypeAlias
 
 from neural_compressor.common.utils import LazyImport, Mode, logger
@@ -163,3 +164,101 @@ def postprocess_model(model, mode, quantizer):
     elif mode == Mode.CONVERT or mode == Mode.QUANTIZE:
         if getattr(model, "quantizer", False):
             del model.quantizer
+
+
+class Statistics:  # pragma: no cover
+    """The statistics printer."""
+
+    def __init__(self, data, header, field_names, output_handle=logger.info):
+        """Init a Statistics object.
+        Args:
+            data: The statistics data
+            header: The table header
+            field_names: The field names
+            output_handle: The output logging method
+        """
+        self.field_names = field_names
+        self.header = header
+        self.data = data
+        self.output_handle = output_handle
+        self.tb = PrettyTable(min_table_width=40)
+
+    def print_stat(self):
+        """Print the statistics."""
+        valid_field_names = []
+        for index, value in enumerate(self.field_names):
+            # always keep the first two columns: op type and total count
+            if index < 2:
+                valid_field_names.append(value)
+                continue
+            # keep a dtype column only if at least one row has a nonzero count
+            if any(i[index] for i in self.data):
+                valid_field_names.append(value)
+        self.tb.field_names = valid_field_names
+        for i in self.data:
+            tmp_data = []
+            for index, value in enumerate(i):
+                if self.field_names[index] in valid_field_names:
+                    tmp_data.append(value)
+            if any(tmp_data[1:]):
+                self.tb.add_row(tmp_data)
+        lines = self.tb.get_string().split("\n")
+        self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|")
+        for i in lines:
+            self.output_handle(i)
+
+
+def dump_model_op_stats(mode, tune_cfg):
+    """Dump quantizable op statistics of the model to the user.
+    Args:
+        mode (Mode): quantization mode; statistics are skipped during prepare
+        tune_cfg (dict): quantization config mapping (op_name, op_type) to an op-level config
+    Returns:
+        None
+    """
+    if mode == Mode.PREPARE:
+        return
+    res = {}
+    # collect all dtype info and build empty results with existing op_type
+    dtype_set = set()
+    for op, config in tune_cfg.items():
+        op_type = op[1]
+        config = config.to_dict()
+        if config["dtype"] != "fp32":
+            num_bits = config["bits"]
+            group_size = config["group_size"]
+            dtype_str = "A32W{}G{}".format(num_bits, group_size)
+            dtype_set.add(dtype_str)
+    dtype_set.add("FP32")
+    dtype_list = list(dtype_set)
+    dtype_list.sort()
+
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        op_type = op[1]
+        if op_type not in res.keys():
+            res[op_type] = {dtype: 0 for dtype in dtype_list}
+
+    # fill in results with op_type and dtype
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        op_type = op[1]
+        if config["dtype"] == "fp32":
+            res[op_type]["FP32"] += 1
+        else:
+            num_bits = config["bits"]
+            group_size = config["group_size"]
+            dtype_str = "A32W{}G{}".format(num_bits, group_size)
+            res[op_type][dtype_str] += 1
+
+    # update stats format for dump.
+    field_names = ["Op Type", "Total"]
+    field_names.extend(dtype_list)
+    output_data = []
+    for op_type in res.keys():
+        field_results = [op_type, sum(res[op_type].values())]
+        field_results.extend([res[op_type][dtype] for dtype in dtype_list])
+        output_data.append(field_results)
+
+    Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
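For reference, a self-contained sketch of what dump_model_op_stats does with a hand-rolled config; the FakeConfig class and op names are invented for illustration, while real callers pass the configs_mapping produced by the quantization entries above:

    # Illustrative sketch, not part of this commit: tune_cfg keys are
    # (op_name, op_type) tuples and values expose .to_dict().
    from neural_compressor.common.utils import Mode
    from neural_compressor.torch.utils import dump_model_op_stats

    class FakeConfig:  # invented stand-in for an op-level quant config
        def __init__(self, dtype, bits=4, group_size=32):
            self._d = {"dtype": dtype, "bits": bits, "group_size": group_size}

        def to_dict(self):
            return self._d

    tune_cfg = {
        ("model.fc1", "Linear"): FakeConfig("int"),
        ("model.fc2", "Linear"): FakeConfig("fp32"),
    }
    dump_model_op_stats(Mode.CONVERT, tune_cfg)
    # Logs a PrettyTable roughly like:
    # |*******Mixed Precision Statistics*******|
    # +---------+-------+----------+------+
    # | Op Type | Total | A32W4G32 | FP32 |
    # +---------+-------+----------+------+
    # |  Linear |   2   |    1     |  1   |
    # +---------+-------+----------+------+

Note that during Mode.PREPARE the function returns immediately, so the table only appears on the convert/quantize pass.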
1 change: 1 addition & 0 deletions requirements_pt.txt
@@ -1,5 +1,6 @@
 numpy < 2.0
 peft==0.10.0
+prettytable
 psutil
 py-cpuinfo
 pydantic
