From 5ba9efe51b000d5a75cf4e1c91a4dcc1c0108983 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Sat, 23 Sep 2023 09:38:02 +0800 Subject: [PATCH] Improve op wise coverage for ORT WOQ (#1270) * Enhance ORT WOQ Signed-off-by: Mengni Wang * bug fix Signed-off-by: Mengni Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update onnxrt.py * Update test_weight_only_adaptor.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update weight_only.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update onnxrt.py * Update test_weight_only_adaptor.py * Update test_weight_only_adaptor.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Mengni Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- neural_compressor/adaptor/onnxrt.py | 63 ++--- neural_compressor/adaptor/onnxrt.yaml | 1 - .../adaptor/ox_utils/weight_only.py | 217 +++++++++--------- neural_compressor/model/onnx_model.py | 33 --- .../test_weight_only_adaptor.py | 54 ++++- 5 files changed, 195 insertions(+), 173 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 8ff0673955a..eaac00caf39 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -1663,7 +1663,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True) enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True) - n_blocks = self.recipes.get("awq_args", {}).get("n_blocks", 5) calib_sampling_size = tune_cfg.get("calib_sampling_size", 1) model = awq_quantize( model, @@ -1672,7 +1671,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): n_samples=calib_sampling_size, enable_auto_scale=enable_auto_scale, enable_mse_search=enable_mse_search, - n_blocks=n_blocks, ) elif "RTN" in algos: from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize @@ -1684,33 +1682,42 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): return model def _dump_model_op_stats(self, model, tune_cfg): + import re + + fp32_op_list = self.query_handler.get_op_types_by_precision(precision="weight_only_integer") + res = {} - # collect all dtype info and build empty results with existing op_type + for optype in fp32_op_list: + res[optype] = {} + dtype_set = set() - for op, config in tune_cfg["op"].items(): - op_type = op[1] - if not config["weight"]["dtype"] == "fp32": - num_bits = config["weight"]["bits"] - group_size = config["weight"]["group_size"] - dtype_str = "A32W{}G{}".format(num_bits, group_size) - dtype_set.add(dtype_str) - dtype_set.add("FP32") - dtype_list = list(dtype_set) - dtype_list.sort() - for op, config in tune_cfg["op"].items(): - op_type = op[1] - if op_type not in res.keys(): - res[op_type] = {dtype: 0 for dtype in dtype_list} - - # fill in results with op_type and dtype - for op, config in tune_cfg["op"].items(): - if config["weight"]["dtype"] == "fp32": - res[op_type]["FP32"] += 1 + for node in model.nodes(): + if node.op_type == "MatMulWithQuantWeight": + optype = "MatMul" else: - num_bits = config["weight"]["bits"] - group_size = config["weight"]["group_size"] - dtype_str = "A32W{}G{}".format(num_bits, group_size) - res[op_type][dtype_str] += 1 + optype = 
node.op_type + + if optype not in res: + continue + if re.fullmatch("^.*_Q\d*G\d*", node.input[1]): + search_out = re.search("_Q\d*", node.input[1]) + dtype = "A32W{}G{}".format( + node.input[1][search_out.start() + 2 : search_out.end()], node.input[1][search_out.end() + 1 :] + ) + else: + dtype = "FP32" + dtype_set.add(dtype) + + if dtype in res[optype]: + res[optype][dtype] += 1 + else: + res[optype][dtype] = 1 + + dtype_list = list(dtype_set) + for dtype in dtype_list: + for optype in res.keys(): + if dtype not in res[optype]: + res[optype][dtype] = 0 # update stats format for dump. field_names = ["Op Type", "Total"] @@ -1760,7 +1767,7 @@ def query_fw_capability(self, model): precisions = query.get_precisions() for precision in precisions: - if precision != "weight_only_integer": + if precision not in ["weight_only_integer", "fp32"]: continue # get supported optype for target precision optypes = ( @@ -1785,7 +1792,7 @@ def query_fw_capability(self, model): continue else: op_capability = copy.deepcopy(configs[op]) - op_capability["activation"]["quant_mode"] = "weight_only" + op_capability["activation"]["quant_mode"] = "weight_only" if op not in optype_wise.keys(): optype_wise[op] = [op_capability] elif op_capability not in optype_wise[op]: diff --git a/neural_compressor/adaptor/onnxrt.yaml b/neural_compressor/adaptor/onnxrt.yaml index 36bbf3069d3..a55250e9da9 100644 --- a/neural_compressor/adaptor/onnxrt.yaml +++ b/neural_compressor/adaptor/onnxrt.yaml @@ -30,7 +30,6 @@ 'dtype': ['fp32'] } }, - 'Attention': *cap_weight_only_matmul } int8: &ref_1_6 { 'static': &ref_1_6_static { diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py index 00992461537..5f7683341a7 100644 --- a/neural_compressor/adaptor/ox_utils/weight_only.py +++ b/neural_compressor/adaptor/ox_utils/weight_only.py @@ -102,10 +102,14 @@ def make_matmul_weight_only_node( scale = np.reshape(scale, (-1, k_blocks)).astype("float32") q_weight_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_Q" + str(num_bits), data_type=2, dims=packed.shape, vals=packed.tostring(), raw=True + name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), + data_type=2, + dims=packed.shape, + vals=packed.tobytes(), + raw=True, ) scale_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_scale", data_type=1, dims=scale.shape, vals=scale.tostring(), raw=True + name=node.input[1] + "_scale", data_type=1, dims=scale.shape, vals=scale.tobytes(), raw=True ) input_names = [node.input[0], q_weight_tensor.name, scale_tensor.name] new_inits = [q_weight_tensor, scale_tensor] @@ -113,7 +117,7 @@ def make_matmul_weight_only_node( if zero_point is not None: zero_point = np.reshape(zero_point, (-1, k_blocks)).astype("uint8") zp_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_zp", data_type=2, dims=zero_point.shape, vals=zero_point.tostring(), raw=True + name=node.input[1] + "_zp", data_type=2, dims=zero_point.shape, vals=zero_point.tobytes(), raw=True ) input_names.append(zp_tensor.name) new_inits.append(zp_tensor) @@ -311,7 +315,11 @@ def rtn_quantize( q_weight = np.transpose(q_weight) q_weight = q_weight[: org_w_shape[0], :].astype(weight.dtype) q_weight_tensor = onnx.helper.make_tensor( - node.input[1] + "_Q" + str(num_bits), 1, weight.shape, q_weight.tostring(), raw=True + name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), + data_type=1, + dims=weight.shape, + vals=q_weight.tobytes(), + raw=True, ) model.add_initializer(q_weight_tensor) 
node.input[1] = q_weight_tensor.name @@ -404,13 +412,9 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, best_scale = scales for node in nodes: - if weight_config.get(node.name, {}) == "fp32": - continue - - if node.name in weight_config: - weight_config[node.name]["bits"] = num_bits - weight_config[node.name]["group_size"] = group_size - weight_config[node.name]["scheme"] = scheme + weight_config.setdefault(node.name, {}).update({"bits": num_bits}) + weight_config.setdefault(node.name, {}).update({"group_size": group_size}) + weight_config.setdefault(node.name, {}).update({"scheme": scheme}) init_share_num = model.get_initializer_share_num(node.input[1]) weight_tensor = model.get_initializer(node.input[1]) @@ -419,9 +423,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, tensor = tensor.T * best_scale tensor = (tensor.T).astype("float32") - new_tensor = onnx.helper.make_tensor( - node.input[1] + "_scaled", 1, tensor.shape, tensor.tostring(), raw=True - ) + new_tensor = onnx.helper.make_tensor(node.input[1] + "_scaled", 1, tensor.shape, tensor.tobytes(), raw=True) model.add_initializer(new_tensor) node.input[1] = new_tensor.name @@ -432,9 +434,9 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, if parent.name in updated_nodes: continue - if parent.op_type in ["LayerNormalization", "BatchNormalization", "InstanceNormalization"] and all( - [weight_config.get(node.name, {}) != "fp32" for node in nodes] - ): + if parent.op_type in ["LayerNormalization", "BatchNormalization", "InstanceNormalization"] and len( + model.input_name_to_nodes[nodes[0].input[0]] + ) == len(nodes): for idx in [1, 2]: tensor = numpy_helper.to_array( model.get_initializer(parent.input[idx]), os.path.dirname(model.model_path) @@ -447,7 +449,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, elif ( parent.op_type in ["SimplifiedLayerNormalization", "MatMul", "Gemm", "Mul"] and not all([model.get_initializer(inp) is None for inp in parent.input]) - and all([weight_config.get(node.name, {}) != "fp32" for node in nodes]) + and len(model.input_name_to_nodes[nodes[0].input[0]]) == len(nodes) ): # pragma: no cover for inp in parent.input: if model.get_initializer(inp) is not None: @@ -457,8 +459,8 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, updated_nodes.append(parent.name) output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1)) - elif parent.op_type in ["Conv", "FusedConv"] and all( - [weight_config.get(node.name, {}) != "fp32" for node in nodes] + elif parent.op_type in ["Conv", "FusedConv"] and len(model.input_name_to_nodes[nodes[0].input[0]]) == len( + nodes ): # pragma: no cover tensor = numpy_helper.to_array(model.get_initializer(parent.input[2]), os.path.dirname(model.model_path)) new_tensor = tensor / np.reshape(best_scale, (1, -1)) @@ -484,8 +486,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, ) new_added_mul_nodes.append(mul_node) for node in nodes: - if weight_config.get(node.name, {}) != "fp32": - replace_input.append([node, node.input[0], mul_node.output[0]]) + replace_input.append([node, node.input[0], mul_node.output[0]]) updated_nodes.append(parent.name) output_dicts[mul_node.output[0]] = output_dicts[mul_node.input[0]] / np.reshape(best_scale, (1, -1)) @@ -512,20 +513,13 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g inp = 
np.concatenate(output_dicts[nodes[0].input[0]], axis=0) for node in nodes: - if weight_config.get(node.name, {}) == "fp32": - continue - if node.name in weight_config: num_bits = weight_config[node.name]["bits"] group_size = weight_config[node.name]["group_size"] scheme = weight_config[node.name]["scheme"] - org_weight = ( - numpy_helper.to_array(model.get_initializer(node.input[1]), base_dir=os.path.dirname(model.model_path)) - if model.get_initializer(node.input[1]) is not None - else numpy_helper.to_array( - model.get_initializer(node.input[1].split("_Q")[0]), base_dir=os.path.dirname(model.model_path) - ) + org_weight = numpy_helper.to_array( + model.get_initializer(node.input[1]), base_dir=os.path.dirname(model.model_path) ) org_w_shape = org_weight.shape # ic, oc group_size = group_size if group_size != -1 else org_w_shape[0] @@ -606,6 +600,8 @@ def prepare_inputs(model, n_samples, dataloader): if isinstance(data[0], dict): inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()])) + elif isinstance(data[0], np.ndarray): + inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]])])) else: inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0])])) return inputs, so @@ -621,7 +617,6 @@ def awq_quantize( n_samples=128, enable_auto_scale=True, enable_mse_search=True, - n_blocks=5, ): """Quant the model with Activation-aware Weight quantization(AWQ) method. @@ -645,7 +640,6 @@ def awq_quantize( n_samples (int, optional): calibration sample number. enable_auto_scale (bool, optional): whether enable scale for salient weight. Defaults to True. enable_mse_search (bool, optional): whether enable clip for weight by checking mse. Defaults to True. - n_blocks (int, optional): split model into block number to avoid OOM. 
Returns: model: fake quantized ONNXModel @@ -656,25 +650,21 @@ def awq_quantize( full_ratio = {} if enable_mse_search or enable_mse_search: - absorb_pairs = model.get_absorb_pairs(["MatMul"]) - - n_samples = 1 inputs, so = prepare_inputs(model, n_samples, dataloader) del dataloader org_output = copy.deepcopy(model.model.graph.output) model.remove_tensors_from_outputs([i.name for i in org_output]) - num_block = math.ceil(len(absorb_pairs) / n_blocks) - dump_pairs = {} output_names = [] - for _, nodes in absorb_pairs.items(): - for node in nodes: - if ( - weight_config.get(node.name, {}) != "fp32" - and weight_config.get(node.name, {}).get("algorithm", "AWQ") == "AWQ" - ): - output_names.append(node.input[0]) + for node in model.nodes(): + if ( + node.op_type in ["MatMul"] + and weight_config.get(node.name, {}) != "fp32" + and weight_config.get(node.name, {}).get("algorithm", "AWQ") == "AWQ" + ): + output_names.append(node.input[0]) + output_names = list(set(output_names)) model.add_tensors_to_outputs(output_names) if model.is_large_model: onnx.save_model( @@ -691,45 +681,49 @@ def awq_quantize( else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=["CPUExecutionProvider"]) ) - for idx, parent in enumerate(absorb_pairs): - if (idx + 1) % num_block == 0 or (idx + 1) == len(absorb_pairs): - dump_pairs[parent] = absorb_pairs[parent] - output_dicts = {} - dump_tensor = list( - set([i.input[0] for nodes in dump_pairs.values() for i in nodes if i.input[0] in output_names]) - ) - if len(dump_tensor) == 0: - continue + for input_name in output_names: + parent = model.output_name_to_node[input_name] + dump_pairs = {parent.name: []} - for inp in inputs: - for output_idx, output in enumerate(session.run(dump_tensor, inp)): - output_dicts.setdefault(dump_tensor[output_idx], []).append(output) - - if enable_auto_scale: - model, output_dicts = apply_awq_scale( - model, - weight_config, - dump_pairs, - output_dicts, - num_bits, - group_size, - scheme, - ) - if enable_mse_search: - ratios = apply_awq_clip( - model, - weight_config, - dump_pairs, - output_dicts, - num_bits, - group_size, - scheme, - ) - del output_dicts - dump_pairs = {} - full_ratio.update(ratios) - else: - dump_pairs[parent] = absorb_pairs[parent] + for node in model.input_name_to_nodes[input_name]: + if ( + node.op_type in ["MatMul"] + and weight_config.get(node.name, {}) != "fp32" + and weight_config.get(node.name, {}).get("algorithm", "AWQ") == "AWQ" + ): + dump_pairs[parent.name].append(model.get_node(node.name)) + + if len(dump_pairs[parent.name]) == 0: + continue + + output_dicts = {} + for inp in inputs: + output = session.run([input_name], inp) + output_dicts.setdefault(input_name, []).append(output) + + if enable_auto_scale: + model, output_dicts = apply_awq_scale( + model, + weight_config, + dump_pairs, + output_dicts, + num_bits, + group_size, + scheme, + ) + if enable_mse_search: + ratios = apply_awq_clip( + model, + weight_config, + dump_pairs, + output_dicts, + num_bits, + group_size, + scheme, + ) + del output_dicts + del dump_pairs + full_ratio.update(ratios) model.remove_tensors_from_outputs(output_names) model.model.graph.output.MergeFrom(org_output) @@ -927,20 +921,20 @@ def gptq_quantize( check_op_support_status() model = model if isinstance(model, BaseModel) else ONNXModel(model) output_dicts = {} - absorb_pairs = model.get_absorb_pairs(["MatMul"]) inputs, so = prepare_inputs(model, n_samples, dataloader) del dataloader org_output = copy.deepcopy(model.model.graph.output) 
model.remove_tensors_from_outputs([i.name for i in org_output]) output_names = [] - for _, nodes in absorb_pairs.items(): - for node in nodes: - if ( - weight_config.get(node.name, {}) != "fp32" - and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ" - ): - output_names.append(node.input[0]) + for node in model.nodes(): + if ( + node.op_type in ["MatMul"] + and weight_config.get(node.name, {}) != "fp32" + and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ" + ): + output_names.append(node.input[0]) + output_names = list(set(output_names)) model.add_tensors_to_outputs(output_names) if model.is_large_model: onnx.save_model( @@ -960,32 +954,32 @@ def gptq_quantize( new_nodes = [] remove_nodes = [] - for parent, nodes in absorb_pairs.items(): - node_list = [ - node for node in nodes if node.input[0] in output_names and weight_config.get(node.name, {}) != "fp32" - ] - dump_tensor = list(set([i.input[0] for i in nodes if i in node_list])) - if len(dump_tensor) == 0: - continue - + for input_name in output_names: + node_list = [] weights = [] - for node in nodes: - if node in node_list: + + for node in model.input_name_to_nodes[input_name]: + if ( + node.op_type in ["MatMul"] + and weight_config.get(node.name, {}) != "fp32" + and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ" + ): weight = numpy_helper.to_array( - model.get_initializer(node.input[1]), os.path.dirname(model.model_path) + model.get_initializer(model.get_node(node.name).input[1]), os.path.dirname(model.model_path) ).copy() if len(weight.shape) != 2: continue + weights.append(weight) + node_list.append(model.get_node(node.name)) + + if len(weights) == 0: + continue Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] nsamples = 0 - for inp in inputs: - output_dicts = {} - for output_idx, output in enumerate(session.run(dump_tensor, inp)): - output_dicts.setdefault(dump_tensor[output_idx], []).append(output) - - inp = output_dicts[dump_tensor[0]][0] + for data in inputs: + inp = session.run([input_name], data)[0] tmp = inp.shape[0] inp = np.reshape(inp, (-1, inp.shape[-1])) Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] @@ -1002,6 +996,7 @@ def gptq_quantize( num_bits = weight_config[node.name]["bits"] group_size = weight_config[node.name]["group_size"] scheme = weight_config[node.name]["scheme"] + group_size = group_size if group_size != -1 else weight.shape[0] q_weight = gptq( weight, @@ -1040,7 +1035,11 @@ def gptq_quantize( model.add_node(q_matmul_node) else: q_weight_tensor = onnx.helper.make_tensor( - node.input[1] + "_Q" + str(num_bits), 1, q_weight.shape, q_weight.tostring(), raw=True + name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), + data_type=1, + dims=q_weight.shape, + vals=q_weight.tobytes(), + raw=True, ) model.add_initializer(q_weight_tensor) node.input[1] = q_weight_tensor.name diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index 3d85a7a9423..18704777949 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -823,39 +823,6 @@ def match_parent( return None - def get_absorb_pairs(self, target_optype): - """Find absorbable nodes based on parent op_type and their own input status. - - Args: - target_optype (list): target absorbable optype. - - Returns: - absorb_pairs (dict): a dict of absorb pairs {parent: list of absorbable children}. 
- """ - absorbable_optypes = [ - "LayerNormalization", - "BatchNormalization", - "InstanceNormalization", - "Conv", - "SimplifiedLayerNormalization", - "MatMul", - "Gemm", - "Mul", - "FusedConv", - ] - absorb_pairs = {} - for node in self.nodes(): - if node.op_type in target_optype and self.get_initializer(node.input[1]) is not None: - parent = self.get_parent(node, 0) - if ( - parent is None - or parent.op_type not in absorbable_optypes - or self.get_initializer(parent.input[1]) is None - ): - continue - absorb_pairs.setdefault(parent.name, []).append(node) - return absorb_pairs - def match_parent_path( self, node, diff --git a/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py b/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py index cd5a3d9be57..eb71711900b 100644 --- a/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py +++ b/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py @@ -74,7 +74,7 @@ def test_RTN_quant(self): }, }, ) - q_model = quantization.fit(self.model, conf, calib_dataloader=self.dataloader) + q_model = quantization.fit(self.model, conf) for data, _ in self.dataloader: q_out = Inference(q_model.model, data) org_out = Inference(self.model, data) @@ -151,6 +151,31 @@ def test_AWQ_quant(self): for q, org in zip(q_out, org_out): self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + awq_op_names = [ + i.name for i in q_model.nodes() if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q4G32") + ] + conf = PostTrainingQuantConfig( + approach="weight_only", + op_type_dict={ + ".*": { # re.match + "weight": { + "bits": 4, # 1-8 bits + "group_size": 32, + "scheme": "asym", + "algorithm": "RTN", + }, + }, + }, + op_name_dict={ + awq_op_names[0]: {"activation": {"dtype": ["fp32"]}, "weight": {"dtype": ["fp32"]}}, + }, + ) + q_model = quantization.fit(self.model, conf) + rtn_op_names = [ + i.name for i in q_model.nodes() if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q4G32") + ] + self.assertTrue(len(rtn_op_names) + 1, len(awq_op_names)) + def test_GPTQ_quant(self): conf = PostTrainingQuantConfig( approach="weight_only", @@ -198,7 +223,7 @@ def test_GPTQ_quant(self): ".*": { # re.match "weight": { "bits": 4, # 1-8 bits - "group_size": -1, # -1 (per-channel) + "group_size": 32, "scheme": "sym", "algorithm": "GPTQ", }, @@ -215,6 +240,31 @@ def test_GPTQ_quant(self): for q, org in zip(q_out, org_out): self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + gptq_op_names = [ + i.name for i in q_model.nodes() if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q4G32") + ] + conf = PostTrainingQuantConfig( + approach="weight_only", + op_type_dict={ + ".*": { # re.match + "weight": { + "bits": 4, # 1-8 bits + "group_size": 32, + "scheme": "asym", + "algorithm": "RTN", + }, + }, + }, + op_name_dict={ + gptq_op_names[0]: {"activation": {"dtype": ["fp32"]}, "weight": {"dtype": ["fp32"]}}, + }, + ) + q_model = quantization.fit(self.model, conf) + rtn_op_names = [ + i.name for i in q_model.nodes() if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q4G32") + ] + self.assertTrue(len(rtn_op_names) + 1, len(gptq_op_names)) + if __name__ == "__main__": unittest.main()
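
Note on the quantized-weight naming scheme used throughout this patch: weight initializers produced by the RTN/AWQ/GPTQ paths are renamed to `<original_name>_Q{bits}G{group_size}` (for example `fc1.weight_Q4G32`), and the reworked `_dump_model_op_stats` recovers the per-op dtype tag (`A32W4G32`, or `FP32` for untouched weights) by matching that suffix on `node.input[1]` instead of reading `tune_cfg`. Below is a minimal sketch of that parsing, using a capture-group regex rather than the `re.search` plus slicing in the patch; the helper name is illustrative, not part of the codebase.

```python
import re

# Quantized weight initializers follow "<orig>_Q{bits}G{group_size}", e.g. "fc1.weight_Q4G32".
_WOQ_SUFFIX = re.compile(r"_Q(\d+)G(\d+)$")


def woq_dtype_from_initializer(name: str) -> str:
    """Return a dtype tag such as 'A32W4G32', or 'FP32' if the weight was left unquantized."""
    match = _WOQ_SUFFIX.search(name)
    if match is None:
        return "FP32"
    bits, group_size = match.groups()
    return "A32W{}G{}".format(bits, group_size)


assert woq_dtype_from_initializer("fc1.weight_Q4G32") == "A32W4G32"
assert woq_dtype_from_initializer("fc2.weight") == "FP32"
```

The same suffix is what the updated tests key on (`i.input[1].endswith("_Q4G32")`) to count which MatMuls were actually quantized and to verify that ops forced to fp32 via `op_name_dict` are skipped.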
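
Note on the initializer construction: the helpers now serialize arrays with `ndarray.tobytes()` (replacing the deprecated `tostring()` alias) and pass `raw=True` to `onnx.helper.make_tensor`, storing the packed weight as UINT8 (`data_type=2`) and the scales as FLOAT (`data_type=1`). A hedged sketch of building such a pair of initializers follows; tensor names and shapes are illustrative, not taken from the patch.

```python
import numpy as np
from onnx import TensorProto, helper

packed = np.random.randint(0, 255, size=(64, 16), dtype=np.uint8)  # packed low-bit weight
scale = np.random.rand(64, 1).astype(np.float32)                   # per-group scales

q_weight_tensor = helper.make_tensor(
    name="fc1.weight_Q4G32",
    data_type=TensorProto.UINT8,  # == 2, as in the patch
    dims=packed.shape,
    vals=packed.tobytes(),        # tobytes() instead of the deprecated tostring()
    raw=True,
)
scale_tensor = helper.make_tensor(
    name="fc1.weight_scale",
    data_type=TensorProto.FLOAT,  # == 1
    dims=scale.shape,
    vals=scale.tobytes(),
    raw=True,
)
```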
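
Note on the GPTQ calibration loop: for each unique MatMul input tensor, the reworked `gptq_quantize` runs the calibration samples through the augmented model and folds every captured activation into a running Hessian estimate per weight; the hunk shows the decay step `Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs]`. The sketch below illustrates that running update for a single weight. The steps after the decay are reconstructed from the standard GPTQ formulation rather than shown in this diff, so treat them as an assumption about the surrounding code.

```python
import numpy as np


def accumulate_hessian(H: np.ndarray, nsamples: int, inp: np.ndarray):
    """Fold one calibration activation into a running GPTQ-style Hessian estimate.

    H        : (ic, ic) current estimate for one MatMul weight (ic = input channels)
    nsamples : number of samples accumulated so far
    inp      : (batch, ..., ic) activation captured at the node's input
    """
    tmp = inp.shape[0]                          # batch size of this sample, as in the hunk
    inp = np.reshape(inp, (-1, inp.shape[-1]))  # flatten to (rows, ic)
    H = H * (nsamples / (nsamples + tmp))       # decay the old estimate (shown in the hunk)
    # The remaining steps follow the GPTQ reference update and are not visible in this diff.
    nsamples += tmp
    inp = np.sqrt(2.0 / nsamples) * inp
    H = H + inp.T @ inp
    return H, nsamples
```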