Add op statistics dump for woq #1876

Merged · 6 commits · Jun 18, 2024
Changes from all commits
2 changes: 0 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/utility.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import math
-
 import torch

 from neural_compressor.torch.utils import accelerator, device_synchronize, logger
15 changes: 14 additions & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
@@ -45,7 +45,14 @@
     StaticQuantConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo
+from neural_compressor.torch.utils import (
+    dump_model_op_stats,
+    get_quantizer,
+    is_ipex_imported,
+    logger,
+    postprocess_model,
+    register_algo,
+)
 from neural_compressor.torch.utils.constants import PT2E_DYNAMIC_QUANT, PT2E_STATIC_QUANT


@@ -89,6 +96,7 @@ def rtn_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
     return model


@@ -141,6 +149,7 @@ def gptq_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)

     return model

@@ -361,6 +370,7 @@ def awq_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
     return model


@@ -415,6 +425,7 @@ def teq_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)

     return model

@@ -491,6 +502,7 @@ def autoround_quantize_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)
     return model


@@ -511,6 +523,7 @@ def hqq_entry(
     quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
     model = quantizer.execute(model, mode=mode)
     postprocess_model(model, mode, quantizer)
+    dump_model_op_stats(mode, configs_mapping)

     return model

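For context on what these entry changes do at runtime: after convert, each WOQ algorithm now logs a per-op-type precision table. A minimal sketch of triggering it through the standard prepare/convert flow; the two-layer toy model and the 4-bit/group-32 settings are illustrative, not from this PR:

import torch

from neural_compressor.torch.quantization import RTNConfig, convert, prepare

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
model = prepare(model, RTNConfig(bits=4, group_size=32))  # Mode.PREPARE: dump_model_op_stats returns early
model = convert(model)  # Mode.CONVERT: rtn_entry logs the "Mixed Precision Statistics" table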
99 changes: 99 additions & 0 deletions neural_compressor/torch/utils/utility.py
@@ -16,6 +16,7 @@
 from typing import Callable, Dict, List, Tuple, Union

 import torch
+from prettytable import PrettyTable
 from typing_extensions import TypeAlias

 from neural_compressor.common.utils import LazyImport, Mode, logger
@@ -163,3 +164,101 @@ def postprocess_model(model, mode, quantizer):
     elif mode == Mode.CONVERT or mode == Mode.QUANTIZE:
         if getattr(model, "quantizer", False):
             del model.quantizer
+
+
+class Statistics:  # pragma: no cover
+    """The statistics printer."""
+
+    def __init__(self, data, header, field_names, output_handle=logger.info):
+        """Init a Statistics object.
+
+        Args:
+            data: The statistics data
+            header: The table header
+            field_names: The field names
+            output_handle: The output logging method
+        """
+        self.field_names = field_names
+        self.header = header
+        self.data = data
+        self.output_handle = output_handle
+        self.tb = PrettyTable(min_table_width=40)
+
+    def print_stat(self):
+        """Print the statistics."""
+        valid_field_names = []
+        for index, value in enumerate(self.field_names):
+            # always keep the first two columns; keep a dtype column only if any count is non-zero
+            if index < 2:
+                valid_field_names.append(value)
+                continue
+            if any(i[index] for i in self.data):
+                valid_field_names.append(value)
+        self.tb.field_names = valid_field_names
+        for i in self.data:
+            tmp_data = []
+            for index, value in enumerate(i):
+                if self.field_names[index] in valid_field_names:
+                    tmp_data.append(value)
+            if any(tmp_data[1:]):
+                self.tb.add_row(tmp_data)
+        lines = self.tb.get_string().split("\n")
+        self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|")
+        for i in lines:
+            self.output_handle(i)
+
+
+def dump_model_op_stats(mode, tune_cfg):
+    """Dump the quantization statistics of quantizable ops to the user.
+
+    Args:
+        mode (Mode): the quantization mode; nothing is dumped in prepare mode
+        tune_cfg (dict): quantization config, keyed by (op_name, op_type) tuples
+    Returns:
+        None
+    """
+    if mode == Mode.PREPARE:
+        return
+    res = {}
+    # collect all dtype info and build empty results with existing op_type
+    dtype_set = set()
+    for op, config in tune_cfg.items():
+        op_type = op[1]
+        config = config.to_dict()
+        if config["dtype"] != "fp32":
+            num_bits = config["bits"]
+            group_size = config["group_size"]
+            dtype_str = "A32W{}G{}".format(num_bits, group_size)
+            dtype_set.add(dtype_str)
+    dtype_set.add("FP32")
+    dtype_list = list(dtype_set)
+    dtype_list.sort()
+
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        op_type = op[1]
+        if op_type not in res.keys():
+            res[op_type] = {dtype: 0 for dtype in dtype_list}
+
+    # fill in results with op_type and dtype
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        op_type = op[1]
+        if config["dtype"] == "fp32":
+            res[op_type]["FP32"] += 1
+        else:
+            num_bits = config["bits"]
+            group_size = config["group_size"]
+            dtype_str = "A32W{}G{}".format(num_bits, group_size)
+            res[op_type][dtype_str] += 1
+
+    # update stats format for dump.
+    field_names = ["Op Type", "Total"]
+    field_names.extend(dtype_list)
+    output_data = []
+    for op_type in res.keys():
+        field_results = [op_type, sum(res[op_type].values())]
+        field_results.extend([res[op_type][dtype] for dtype in dtype_list])
+        output_data.append(field_results)
+
+    Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
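
To preview the table without running a quantization pass, the new Statistics helper can be driven directly. A self-contained sketch with made-up counts; output_handle=print just routes the table to stdout instead of logger.info:

from neural_compressor.torch.utils.utility import Statistics

field_names = ["Op Type", "Total", "A32W4G32", "FP32"]
output_data = [
    ["Linear", 25, 24, 1],  # hypothetical: 24 Linears at W4G32, one kept at FP32
    ["Embedding", 1, 0, 1],
]
Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names, output_handle=print).print_stat()

Note that print_stat drops any dtype column whose counts are all zero, and any row whose counts past the first column are all zero, so the table stays compact for large models.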
1 change: 1 addition & 0 deletions requirements_pt.txt
@@ -1,5 +1,6 @@
 numpy < 2.0
 peft==0.10.0
+prettytable
 psutil
 py-cpuinfo
 pydantic