Skip to content

Commit

Permalink
Improve UT Branch Coverage for TF 3x (#1867)
Browse files Browse the repository at this point in the history
Signed-off-by: zehao-intel <[email protected]>
  • Loading branch information
zehao-intel authored Jun 14, 2024
1 parent b99a79d commit a141512
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
tensorflow==2.11.0
tensorflow
neural-compressor
44 changes: 22 additions & 22 deletions neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, framework_specific_info):
cfg_yaml_name = "{}.yaml".format(self.__class__.__name__[: -len("Adaptor")].lower())
self.itex_mode = self.backend == "itex" or cfg_yaml_name == "tensorflow_itex.yaml"

if self.itex_mode:
if self.itex_mode: # pragma: no cover
self._check_itex()

self.query_handler = TensorflowQuery(
Expand All @@ -109,7 +109,7 @@ def __init__(self, framework_specific_info):

self._last_dequantize_ops = None

def _check_itex(self):
def _check_itex(self): # pragma: no cover
try:
import intel_extension_for_tensorflow
except:
Expand All @@ -133,7 +133,7 @@ def _tuning_cfg_to_fw(self, tuning_cfg):

invalid_op_names = [i for i in self.quantize_config["op_wise_config"] if i not in dispatched_op_names]

for op_name in invalid_op_names:
for op_name in invalid_op_names: # pragma: no cover
self.quantize_config["op_wise_config"].pop(op_name)

for each_op_info in tuning_cfg["op"]:
Expand All @@ -144,7 +144,7 @@ def _tuning_cfg_to_fw(self, tuning_cfg):
self.quantize_config["op_wise_config"].pop(op_name)
if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "fp32":
fp32_ops.append(op_name)
if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16":
if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16": # pragma: no cover
bf16_ops.append(op_name)
continue

Expand Down Expand Up @@ -342,7 +342,7 @@ def _dump_model_op_stats(self, model_graphdef):
res[origin_op_type]["INT8"] += 1

if i.op in fp32_op_list:
if "T" not in i.attr and i.op != "Cast":
if "T" not in i.attr and i.op != "Cast": # pragma: no cover
continue
if i.op == "Cast":
if i.attr["DstT"].type == dtypes.bfloat16:
Expand Down Expand Up @@ -432,7 +432,7 @@ def _query_quantizable_ops(self, matched_nodes):
) and len(first_conv_or_matmul_node) == 0:
first_conv_or_matmul_node.append((node_name, self.unify_op_type_mapping[node_op]))
self.recipes_ops["first_conv_or_matmul_quantization"] = first_conv_or_matmul_node
if exclude_first_quantizable_op and (
if exclude_first_quantizable_op and ( # pragma: no cover
self.unify_op_type_mapping[node_op].find("conv2d") != -1
or self.unify_op_type_mapping[node_op].find("matmul") != -1
):
Expand Down Expand Up @@ -493,7 +493,7 @@ def _filter_unquantizable_concat(self, matched_nodes):
concat_nodes = g.query_fusion_pattern_nodes([["ConcatV2"]])
for i in concat_nodes:
concat_node_name = i[0]
if concat_node_name not in target_concat_nodes:
if concat_node_name not in target_concat_nodes: # pragma: no cover
continue
input_positive_status = []
for index in range(graph_info[concat_node_name].node.attr["N"].i):
Expand All @@ -507,7 +507,7 @@ def _filter_unquantizable_concat(self, matched_nodes):
else:
positive_input = g.has_positive_input(each_input_node.name)
input_positive_status.append(positive_input)
if not any(input_positive_status):
if not any(input_positive_status): # pragma: no cover
matched_nodes.remove(i)

def _filter_unquantizable_concat_performance_only(self, matched_nodes):
Expand All @@ -522,7 +522,7 @@ def _filter_unquantizable_concat_performance_only(self, matched_nodes):
concat_nodes = g.query_fusion_pattern_nodes([["ConcatV2"]])
for i in concat_nodes:
concat_node_name = i[0]
if concat_node_name not in target_concat_nodes:
if concat_node_name not in target_concat_nodes: # pragma: no cover
continue
input_positive_status = []
control_flow = False
Expand All @@ -531,9 +531,9 @@ def _filter_unquantizable_concat_performance_only(self, matched_nodes):
graph_info[concat_node_name].node.input[index]
)
each_input_node = graph_info[each_input_name].node
if each_input_node.op in ("Switch"):
if each_input_node.op in ("Switch"): # pragma: no cover
control_flow = True
if control_flow:
if control_flow: # pragma: no cover
matched_nodes.remove(i)

def parse_quant_config(self, quant_config, model, calib_iteration):
Expand Down Expand Up @@ -588,7 +588,7 @@ def _query_fw_capability(self, model):

def check_match(patterns, input_pattern):
for i in patterns:
if input_pattern == [i for i in i.replace("+", " ").strip().split(" ") if i]:
if input_pattern == [i for i in i.replace("+", " ").strip().split(" ") if i]: # pragma: no cover
return True
return False

Expand Down Expand Up @@ -641,7 +641,7 @@ def quantize_input(self, model):
"""
scale = None
# quantize input only support tensorflow version > 2.1.0
if version1_lt_version2(tf.version.VERSION, "2.1.0"):
if version1_lt_version2(tf.version.VERSION, "2.1.0"): # pragma: no cover
logger.warning("Quantize input needs tensorflow 2.1.0 and newer.")
return model, scale

Expand Down Expand Up @@ -872,7 +872,7 @@ def precisions(self):
return self._precisions

@precisions.setter
def precisions(self, precisions):
def precisions(self, precisions): # pragma: no cover
"""Set precision."""
if not isinstance(precisions, list):
precisions = [precisions]
Expand All @@ -881,7 +881,7 @@ def precisions(self, precisions):
self._precisions = precisions

@staticmethod
def check_value(name, src, supported_type, supported_value=[]):
def check_value(name, src, supported_type, supported_value=[]): # pragma: no cover
"""Check if the given object is the given supported type and in the given supported value.
Example::
Expand Down Expand Up @@ -946,7 +946,7 @@ def _get_specified_version_cfg(self, data):
config = None

def _compare(version1, version2):
if parse_version(version1) == parse_version(version2):
if parse_version(version1) == parse_version(version2): # pragma: no cover
return 0
elif parse_version(version1) < parse_version(version2):
return -1
Expand Down Expand Up @@ -979,7 +979,7 @@ def _compare(version1, version2):
# convention. Replacing them with dot for version comparison.
sorted_list = [i.replace("-up", ".") for i in sorted_list]
sorted_list = sorted(sorted_list, key=cmp_to_key(_compare), reverse=True)
else:
else: # pragma: no cover
assert isinstance(sorted_list, str)
sorted_list = list(sorted_list.replace("-up", ".").split())
for i in sorted_list:
Expand Down Expand Up @@ -1025,7 +1025,7 @@ def _one_shot_query(self):
def _update_cfg_with_usr_definition(self):
"""Add user defined precision configuration."""
tensorflow_config = TensorFlowConfig()
if tensorflow_config.precisions is not None:
if tensorflow_config.precisions is not None: # pragma: no cover
self.cur_config["precisions"]["names"] = ",".join(tensorflow_config.precisions)

def get_version(self):
Expand Down Expand Up @@ -1288,7 +1288,7 @@ def get_fuse_patterns(self):
elif version1_gte_version2(tf.version.VERSION, "2.1.0"):
patterns["int8"] = tf_int8_pattern_list
patterns["uint8"] = tf_uint8_pattern_list
if self.itex_mode:
if self.itex_mode: # pragma: no cover
patterns["int8"].append("FusedBatchNormV3 + Relu")
patterns["int8"].append("FusedBatchNormV3 + LeakyRelu")
elif version1_eq_version2(tf.version.VERSION, "1.15.0-up3"): # pragma: no cover
Expand Down Expand Up @@ -1340,23 +1340,23 @@ def get_op_types_by_precision(self, precision):
tf.version.VERSION, "1.15.0-up3"
):
return ["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool"]
return ["MatMul", "ConcatV2", "MaxPool", "AvgPool"]
return ["MatMul", "ConcatV2", "MaxPool", "AvgPool"] # pragma: no cover
if precision == "uint8":
if tf.version.VERSION in spr_base_verions:
return [key for key in self.cur_config["int8"][self.quant_mode].keys() if "Norm" not in key]
if version1_gte_version2(tf.version.VERSION, "2.1.0") or version1_eq_version2(
tf.version.VERSION, "1.15.0-up3"
):
return ["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool", "DepthwiseConv2dNative"]
return ["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool"]
return ["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool"] # pragma: no cover
if precision == "bf16":
if tf.version.VERSION in spr_base_verions:
return self.cur_config[precision]
if version1_gte_version2(tf.version.VERSION, "2.1.0") or version1_eq_version2(
tf.version.VERSION, "1.15.0-up3"
):
return self.cur_config[precision]
return []
return [] # pragma: no cover

def get_mixed_precision_combination(self):
"""Get the valid mixed precisions.
Expand Down
28 changes: 15 additions & 13 deletions neural_compressor/tensorflow/quantization/utils/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(
self.scale_info.update({"bf16_ops": self.bf16_ops})
self.scale_info.update({"fp32_ops": self.fp32_ops})

if "backend" in self.model.kwargs:
if "backend" in self.model.kwargs: # pragma: no cover
self._sampling_model = Model(self.model._model, **self.model.kwargs)
else:
self._sampling_model = Model(
Expand Down Expand Up @@ -245,12 +245,12 @@ def _inference(self, model):
output_tensor = model.output_tensor
# TF table initialization: https://github.com/tensorflow/tensorflow/issues/8665
node_names = [node.name for node in sess.graph.as_graph_def().node]
if "init_all_tables" in node_names:
if "init_all_tables" in node_names: # pragma: no cover
init_table_op = sess.graph.get_operation_by_name("init_all_tables")
sess.run(init_table_op)

logger.info("Start sampling on calibration dataset.")
if hasattr(self.data_loader, "__len__") and len(self.data_loader) == 0:
if hasattr(self.data_loader, "__len__") and len(self.data_loader) == 0: # pragma: no cover
feed_dict = {}
_ = (
sess.run(output_tensor, feed_dict)
Expand Down Expand Up @@ -333,7 +333,7 @@ def _inference_llm(self, model):
feed_dict = {}
if len(input_tensor_names) == 1:
feed_dict[input_tensor_names[0]] = inputs
else:
else: # pragma: no cover
assert len(input_tensor_names) == len(inputs), "inputs len must equal with input_tensor"
for i, input_tensor_name in enumerate(input_tensor_names):
feed_dict[input_tensor_name] = inputs[i]
Expand Down Expand Up @@ -365,7 +365,7 @@ def _check_tf_version(self): # pragma: no cover
if version1_gte_version2(tf.version.VERSION, "2.9.0"):
is_supported_version = True

if tf.version.VERSION == "1.15.0-up3":
if tf.version.VERSION == "1.15.0-up3": # pragma: no cover
is_supported_version = True

if tf.version.VERSION in SPR_BASE_VERSIONS:
Expand Down Expand Up @@ -405,7 +405,7 @@ def _check_tf_version(self): # pragma: no cover
)
)

def _check_args(self):
def _check_args(self): # pragma: no cover
"""Check model's arguments."""
if (
self.model.workspace_path
Expand All @@ -429,7 +429,7 @@ def _gen_tmp_filenames(self):
self._tmp_model = self._fp32_model
else:
# to keep temp model
if "backend" in self.model.kwargs:
if "backend" in self.model.kwargs: # pragma: no cover
self._tmp_model = Model(self.model._model, **self.model.kwargs)
else:
self._tmp_model = Model(
Expand Down Expand Up @@ -707,7 +707,7 @@ def _generate_calibration_data(self, tmp_path, output_data, enable_kl_algo=False

if "backend" in self._tmp_model.kwargs:
model = Model(tmp_path, **self._tmp_model.kwargs)
else:
else: # pragma: no cover
model = Model(
tmp_path,
**self._tmp_model.kwargs,
Expand Down Expand Up @@ -755,7 +755,9 @@ def _freeze_requantization_ranges(self, additional_data=None):
self.scale_info.update(quantizev2_min)
self.scale_info.update(requant_min_max)

if "scale_propagation_max_pooling" in self.recipes and self.recipes["scale_propagation_max_pooling"]:
if (
"scale_propagation_max_pooling" in self.recipes and self.recipes["scale_propagation_max_pooling"]
): # pragma: no cover
self._tmp_graph_def = ScaleProPagationTransformer(self._tmp_graph_def).do_transformation()

if debug and not self.new_api:
Expand Down Expand Up @@ -817,7 +819,7 @@ def _fuse_requantize_with_fused_quantized_node(self):

self._tmp_model.graph_def = self._tmp_graph_def

def _post_clean(self):
def _post_clean(self): # pragma: no cover
"""Delete the temporarily files generated during the quantization process.
:return: None
Expand All @@ -840,7 +842,7 @@ def quantize_with_qdq_pattern(self):
self._insert_qdq_pairs()
self._convert_qdq()

except ValueError as e:
except ValueError as e: # pragma: no cover
logger.error("Fail to quantize graph due to {}.".format(str(e)))
self._tmp_model = None
raise
Expand Down Expand Up @@ -885,10 +887,10 @@ def _insert_qdq_pairs(self):
self.itex_mode,
).get_quantized_nodes()

if self.itex_mode:
if self.itex_mode: # pragma: no cover
self.quantized_node_info.extend(self._search_y_pattern_for_itex())

if self._enable_kl_op_names:
if self._enable_kl_op_names: # pragma: no cover
self._get_fp32_print_node_names(self._enable_kl_op_names)
self._generate_calibration_data(self._fp32_logged_model_path, self._fp32_print_data, True)

Expand Down
Loading

0 comments on commit a141512

Please sign in to comment.