Improve op wise coverage for ORT WOQ (#1270)
* Enhance ORT WOQ

Signed-off-by: Mengni Wang <[email protected]>

* bug fix

Signed-off-by: Mengni Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update onnxrt.py

* Update test_weight_only_adaptor.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update weight_only.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update onnxrt.py

* Update test_weight_only_adaptor.py

* Update test_weight_only_adaptor.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Mengni Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
mengniwang95 and pre-commit-ci[bot] committed Nov 10, 2023
1 parent 35f9461 commit 5ba9efe
Showing 5 changed files with 195 additions and 173 deletions.
63 changes: 35 additions & 28 deletions neural_compressor/adaptor/onnxrt.py
@@ -1663,7 +1663,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
 
             enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
             enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
-            n_blocks = self.recipes.get("awq_args", {}).get("n_blocks", 5)
             calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
             model = awq_quantize(
                 model,
@@ -1672,7 +1671,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                 n_samples=calib_sampling_size,
                 enable_auto_scale=enable_auto_scale,
                 enable_mse_search=enable_mse_search,
-                n_blocks=n_blocks,
             )
         elif "RTN" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
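
For context, the awq_args read above are supplied through the quantization config's recipes. A minimal usage sketch, assuming the neural_compressor 2.x PostTrainingQuantConfig API; `model` and `dataloader` are hypothetical placeholders, and after this change n_blocks is no longer consumed:

from neural_compressor import PostTrainingQuantConfig, quantization

# `model` (an ONNX model or path) and `dataloader` are assumed to exist.
conf = PostTrainingQuantConfig(
    approach="weight_only",
    recipes={
        "awq_args": {
            "enable_auto_scale": True,  # read via self.recipes.get("awq_args", {})
            "enable_mse_search": True,  # an n_blocks entry would now be ignored
        }
    },
)
q_model = quantization.fit(model, conf, calib_dataloader=dataloader)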
@@ -1684,33 +1682,42 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         return model
 
     def _dump_model_op_stats(self, model, tune_cfg):
+        import re
+
         fp32_op_list = self.query_handler.get_op_types_by_precision(precision="weight_only_integer")
+
         res = {}
+        # collect all dtype info and build empty results with existing op_type
+        for optype in fp32_op_list:
+            res[optype] = {}
+
         dtype_set = set()
-        for op, config in tune_cfg["op"].items():
-            op_type = op[1]
-            if not config["weight"]["dtype"] == "fp32":
-                num_bits = config["weight"]["bits"]
-                group_size = config["weight"]["group_size"]
-                dtype_str = "A32W{}G{}".format(num_bits, group_size)
-                dtype_set.add(dtype_str)
-        dtype_set.add("FP32")
-        dtype_list = list(dtype_set)
-        dtype_list.sort()
-        for op, config in tune_cfg["op"].items():
-            op_type = op[1]
-            if op_type not in res.keys():
-                res[op_type] = {dtype: 0 for dtype in dtype_list}
-
-        # fill in results with op_type and dtype
-        for op, config in tune_cfg["op"].items():
-            if config["weight"]["dtype"] == "fp32":
-                res[op_type]["FP32"] += 1
+        for node in model.nodes():
+            if node.op_type == "MatMulWithQuantWeight":
+                optype = "MatMul"
             else:
-                num_bits = config["weight"]["bits"]
-                group_size = config["weight"]["group_size"]
-                dtype_str = "A32W{}G{}".format(num_bits, group_size)
-                res[op_type][dtype_str] += 1
+                optype = node.op_type
+
+            if optype not in res:
+                continue
+            if re.fullmatch("^.*_Q\d*G\d*", node.input[1]):
+                search_out = re.search("_Q\d*", node.input[1])
+                dtype = "A32W{}G{}".format(
+                    node.input[1][search_out.start() + 2 : search_out.end()], node.input[1][search_out.end() + 1 :]
+                )
+            else:
+                dtype = "FP32"
+            dtype_set.add(dtype)
+
+            if dtype in res[optype]:
+                res[optype][dtype] += 1
+            else:
+                res[optype][dtype] = 1
+
+        dtype_list = list(dtype_set)
+        for dtype in dtype_list:
+            for optype in res.keys():
+                if dtype not in res[optype]:
+                    res[optype][dtype] = 0
 
         # update stats format for dump.
         field_names = ["Op Type", "Total"]
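
The rewritten stats pass infers each op's weight-only dtype from the quantized initializer's name suffix (_Q<bits>G<group_size>) rather than from tune_cfg. A standalone sketch of that parsing logic, with made-up tensor names and raw-string patterns equivalent to the escapes above:

import re

def parse_woq_dtype(tensor_name):
    # Map a name ending in _Q<bits>G<group_size> to an A32W<bits>G<group_size>
    # label, mirroring the branch above; anything else counts as FP32.
    if re.fullmatch(r"^.*_Q\d*G\d*", tensor_name):
        search_out = re.search(r"_Q\d*", tensor_name)
        bits = tensor_name[search_out.start() + 2 : search_out.end()]
        group_size = tensor_name[search_out.end() + 1 :]
        return "A32W{}G{}".format(bits, group_size)
    return "FP32"

print(parse_woq_dtype("fc1.weight_Q4G32"))  # -> A32W4G32
print(parse_woq_dtype("fc1.weight"))        # -> FP32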
@@ -1760,7 +1767,7 @@ def query_fw_capability(self, model):
         precisions = query.get_precisions()
 
         for precision in precisions:
-            if precision != "weight_only_integer":
+            if precision not in ["weight_only_integer", "fp32"]:
                 continue
             # get supported optype for target precision
             optypes = (
@@ -1785,7 +1792,7 @@ def query_fw_capability(self, model):
                         continue
                 else:
                     op_capability = copy.deepcopy(configs[op])
-                    op_capability["activation"]["quant_mode"] = "weight_only"
+                op_capability["activation"]["quant_mode"] = "weight_only"
                 if op not in optype_wise.keys():
                     optype_wise[op] = [op_capability]
                 elif op_capability not in optype_wise[op]:
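
The widened precision filter means op types whose config is plain fp32 now also enter the capability table, so each op can carry an explicit fp32 fallback entry beside its weight-only one. An illustrative sketch of the resulting structure, with made-up values standing in for what onnxrt.yaml provides:

optype_wise = {
    "MatMul": [
        {   # weight_only_integer capability from onnxrt.yaml
            "weight": {"dtype": ["int"], "bits": [4, 8], "group_size": [32, -1]},
            "activation": {"dtype": ["fp32"], "quant_mode": "weight_only"},
        },
        {   # fp32 fallback, now admitted by the widened precision check
            "weight": {"dtype": ["fp32"]},
            "activation": {"dtype": ["fp32"], "quant_mode": "weight_only"},
        },
    ],
}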
1 change: 0 additions & 1 deletion neural_compressor/adaptor/onnxrt.yaml
@@ -30,7 +30,6 @@
         'dtype': ['fp32']
       }
     },
-    'Attention': *cap_weight_only_matmul
   }
   int8: &ref_1_6 {
     'static': &ref_1_6_static {