Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve UT Branch Coverage for TF 3x #1867

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
tensorflow==2.11.0
tensorflow
neural-compressor
44 changes: 22 additions & 22 deletions neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, framework_specific_info):
cfg_yaml_name = "{}.yaml".format(self.__class__.__name__[: -len("Adaptor")].lower())
self.itex_mode = self.backend == "itex" or cfg_yaml_name == "tensorflow_itex.yaml"

if self.itex_mode:
if self.itex_mode: # pragma: no cover
self._check_itex()

self.query_handler = TensorflowQuery(
Expand All @@ -109,7 +109,7 @@ def __init__(self, framework_specific_info):

self._last_dequantize_ops = None

def _check_itex(self):
def _check_itex(self): # pragma: no cover
try:
import intel_extension_for_tensorflow
except:
Expand All @@ -133,7 +133,7 @@ def _tuning_cfg_to_fw(self, tuning_cfg):

invalid_op_names = [i for i in self.quantize_config["op_wise_config"] if i not in dispatched_op_names]

for op_name in invalid_op_names:
for op_name in invalid_op_names: # pragma: no cover
self.quantize_config["op_wise_config"].pop(op_name)

for each_op_info in tuning_cfg["op"]:
Expand All @@ -144,7 +144,7 @@ def _tuning_cfg_to_fw(self, tuning_cfg):
self.quantize_config["op_wise_config"].pop(op_name)
if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "fp32":
fp32_ops.append(op_name)
if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16":
if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16": # pragma: no cover
bf16_ops.append(op_name)
continue

Expand Down Expand Up @@ -342,7 +342,7 @@ def _dump_model_op_stats(self, model_graphdef):
res[origin_op_type]["INT8"] += 1

if i.op in fp32_op_list:
if "T" not in i.attr and i.op != "Cast":
if "T" not in i.attr and i.op != "Cast": # pragma: no cover
continue
if i.op == "Cast":
if i.attr["DstT"].type == dtypes.bfloat16:
Expand Down Expand Up @@ -432,7 +432,7 @@ def _query_quantizable_ops(self, matched_nodes):
) and len(first_conv_or_matmul_node) == 0:
first_conv_or_matmul_node.append((node_name, self.unify_op_type_mapping[node_op]))
self.recipes_ops["first_conv_or_matmul_quantization"] = first_conv_or_matmul_node
if exclude_first_quantizable_op and (
if exclude_first_quantizable_op and ( # pragma: no cover
self.unify_op_type_mapping[node_op].find("conv2d") != -1
or self.unify_op_type_mapping[node_op].find("matmul") != -1
):
Expand Down Expand Up @@ -493,7 +493,7 @@ def _filter_unquantizable_concat(self, matched_nodes):
concat_nodes = g.query_fusion_pattern_nodes([["ConcatV2"]])
for i in concat_nodes:
concat_node_name = i[0]
if concat_node_name not in target_concat_nodes:
if concat_node_name not in target_concat_nodes: # pragma: no cover
continue
input_positive_status = []
for index in range(graph_info[concat_node_name].node.attr["N"].i):
Expand All @@ -507,7 +507,7 @@ def _filter_unquantizable_concat(self, matched_nodes):
else:
positive_input = g.has_positive_input(each_input_node.name)
input_positive_status.append(positive_input)
if not any(input_positive_status):
if not any(input_positive_status): # pragma: no cover
matched_nodes.remove(i)

def _filter_unquantizable_concat_performance_only(self, matched_nodes):
Expand All @@ -522,7 +522,7 @@ def _filter_unquantizable_concat_performance_only(self, matched_nodes):
concat_nodes = g.query_fusion_pattern_nodes([["ConcatV2"]])
for i in concat_nodes:
concat_node_name = i[0]
if concat_node_name not in target_concat_nodes:
if concat_node_name not in target_concat_nodes: # pragma: no cover
continue
input_positive_status = []
control_flow = False
Expand All @@ -531,9 +531,9 @@ def _filter_unquantizable_concat_performance_only(self, matched_nodes):
graph_info[concat_node_name].node.input[index]
)
each_input_node = graph_info[each_input_name].node
if each_input_node.op in ("Switch"):
if each_input_node.op in ("Switch"): # pragma: no cover
control_flow = True
if control_flow:
if control_flow: # pragma: no cover
matched_nodes.remove(i)

def parse_quant_config(self, quant_config, model, calib_iteration):
Expand Down Expand Up @@ -588,7 +588,7 @@ def _query_fw_capability(self, model):

def check_match(patterns, input_pattern):
for i in patterns:
if input_pattern == [i for i in i.replace("+", " ").strip().split(" ") if i]:
if input_pattern == [i for i in i.replace("+", " ").strip().split(" ") if i]: # pragma: no cover
return True
return False

Expand Down Expand Up @@ -641,7 +641,7 @@ def quantize_input(self, model):
"""
scale = None
# quantize input only support tensorflow version > 2.1.0
if version1_lt_version2(tf.version.VERSION, "2.1.0"):
if version1_lt_version2(tf.version.VERSION, "2.1.0"): # pragma: no cover
logger.warning("Quantize input needs tensorflow 2.1.0 and newer.")
return model, scale

Expand Down Expand Up @@ -872,7 +872,7 @@ def precisions(self):
return self._precisions

@precisions.setter
def precisions(self, precisions):
def precisions(self, precisions): # pragma: no cover
"""Set precision."""
if not isinstance(precisions, list):
precisions = [precisions]
Expand All @@ -881,7 +881,7 @@ def precisions(self, precisions):
self._precisions = precisions

@staticmethod
def check_value(name, src, supported_type, supported_value=[]):
def check_value(name, src, supported_type, supported_value=[]): # pragma: no cover
"""Check if the given object is the given supported type and in the given supported value.

Example::
Expand Down Expand Up @@ -946,7 +946,7 @@ def _get_specified_version_cfg(self, data):
config = None

def _compare(version1, version2):
if parse_version(version1) == parse_version(version2):
if parse_version(version1) == parse_version(version2): # pragma: no cover
return 0
elif parse_version(version1) < parse_version(version2):
return -1
Expand Down Expand Up @@ -979,7 +979,7 @@ def _compare(version1, version2):
# convention. Replacing them with dot for version comparison.
sorted_list = [i.replace("-up", ".") for i in sorted_list]
sorted_list = sorted(sorted_list, key=cmp_to_key(_compare), reverse=True)
else:
else: # pragma: no cover
assert isinstance(sorted_list, str)
sorted_list = list(sorted_list.replace("-up", ".").split())
for i in sorted_list:
Expand Down Expand Up @@ -1025,7 +1025,7 @@ def _one_shot_query(self):
def _update_cfg_with_usr_definition(self):
"""Add user defined precision configuration."""
tensorflow_config = TensorFlowConfig()
if tensorflow_config.precisions is not None:
if tensorflow_config.precisions is not None: # pragma: no cover
self.cur_config["precisions"]["names"] = ",".join(tensorflow_config.precisions)

def get_version(self):
Expand Down Expand Up @@ -1288,7 +1288,7 @@ def get_fuse_patterns(self):
elif version1_gte_version2(tf.version.VERSION, "2.1.0"):
patterns["int8"] = tf_int8_pattern_list
patterns["uint8"] = tf_uint8_pattern_list
if self.itex_mode:
if self.itex_mode: # pragma: no cover
patterns["int8"].append("FusedBatchNormV3 + Relu")
patterns["int8"].append("FusedBatchNormV3 + LeakyRelu")
elif version1_eq_version2(tf.version.VERSION, "1.15.0-up3"): # pragma: no cover
Expand Down Expand Up @@ -1340,23 +1340,23 @@ def get_op_types_by_precision(self, precision):
tf.version.VERSION, "1.15.0-up3"
):
return ["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool"]
return ["MatMul", "ConcatV2", "MaxPool", "AvgPool"]
return ["MatMul", "ConcatV2", "MaxPool", "AvgPool"] # pragma: no cover
if precision == "uint8":
if tf.version.VERSION in spr_base_verions:
return [key for key in self.cur_config["int8"][self.quant_mode].keys() if "Norm" not in key]
if version1_gte_version2(tf.version.VERSION, "2.1.0") or version1_eq_version2(
tf.version.VERSION, "1.15.0-up3"
):
return ["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool", "DepthwiseConv2dNative"]
return ["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool"]
return ["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool"] # pragma: no cover
if precision == "bf16":
if tf.version.VERSION in spr_base_verions:
return self.cur_config[precision]
if version1_gte_version2(tf.version.VERSION, "2.1.0") or version1_eq_version2(
tf.version.VERSION, "1.15.0-up3"
):
return self.cur_config[precision]
return []
return [] # pragma: no cover

def get_mixed_precision_combination(self):
"""Get the valid mixed precisions.
Expand Down
28 changes: 15 additions & 13 deletions neural_compressor/tensorflow/quantization/utils/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(
self.scale_info.update({"bf16_ops": self.bf16_ops})
self.scale_info.update({"fp32_ops": self.fp32_ops})

if "backend" in self.model.kwargs:
if "backend" in self.model.kwargs: # pragma: no cover
self._sampling_model = Model(self.model._model, **self.model.kwargs)
else:
self._sampling_model = Model(
Expand Down Expand Up @@ -245,12 +245,12 @@ def _inference(self, model):
output_tensor = model.output_tensor
# TF table initialization: https://github.com/tensorflow/tensorflow/issues/8665
node_names = [node.name for node in sess.graph.as_graph_def().node]
if "init_all_tables" in node_names:
if "init_all_tables" in node_names: # pragma: no cover
init_table_op = sess.graph.get_operation_by_name("init_all_tables")
sess.run(init_table_op)

logger.info("Start sampling on calibration dataset.")
if hasattr(self.data_loader, "__len__") and len(self.data_loader) == 0:
if hasattr(self.data_loader, "__len__") and len(self.data_loader) == 0: # pragma: no cover
feed_dict = {}
_ = (
sess.run(output_tensor, feed_dict)
Expand Down Expand Up @@ -333,7 +333,7 @@ def _inference_llm(self, model):
feed_dict = {}
if len(input_tensor_names) == 1:
feed_dict[input_tensor_names[0]] = inputs
else:
else: # pragma: no cover
assert len(input_tensor_names) == len(inputs), "inputs len must equal with input_tensor"
for i, input_tensor_name in enumerate(input_tensor_names):
feed_dict[input_tensor_name] = inputs[i]
Expand Down Expand Up @@ -365,7 +365,7 @@ def _check_tf_version(self): # pragma: no cover
if version1_gte_version2(tf.version.VERSION, "2.9.0"):
is_supported_version = True

if tf.version.VERSION == "1.15.0-up3":
if tf.version.VERSION == "1.15.0-up3": # pragma: no cover
is_supported_version = True

if tf.version.VERSION in SPR_BASE_VERSIONS:
Expand Down Expand Up @@ -405,7 +405,7 @@ def _check_tf_version(self): # pragma: no cover
)
)

def _check_args(self):
def _check_args(self): # pragma: no cover
"""Check model's arguments."""
if (
self.model.workspace_path
Expand All @@ -429,7 +429,7 @@ def _gen_tmp_filenames(self):
self._tmp_model = self._fp32_model
else:
# to keep temp model
if "backend" in self.model.kwargs:
if "backend" in self.model.kwargs: # pragma: no cover
self._tmp_model = Model(self.model._model, **self.model.kwargs)
else:
self._tmp_model = Model(
Expand Down Expand Up @@ -707,7 +707,7 @@ def _generate_calibration_data(self, tmp_path, output_data, enable_kl_algo=False

if "backend" in self._tmp_model.kwargs:
model = Model(tmp_path, **self._tmp_model.kwargs)
else:
else: # pragma: no cover
model = Model(
tmp_path,
**self._tmp_model.kwargs,
Expand Down Expand Up @@ -755,7 +755,9 @@ def _freeze_requantization_ranges(self, additional_data=None):
self.scale_info.update(quantizev2_min)
self.scale_info.update(requant_min_max)

if "scale_propagation_max_pooling" in self.recipes and self.recipes["scale_propagation_max_pooling"]:
if (
"scale_propagation_max_pooling" in self.recipes and self.recipes["scale_propagation_max_pooling"]
): # pragma: no cover
self._tmp_graph_def = ScaleProPagationTransformer(self._tmp_graph_def).do_transformation()

if debug and not self.new_api:
Expand Down Expand Up @@ -817,7 +819,7 @@ def _fuse_requantize_with_fused_quantized_node(self):

self._tmp_model.graph_def = self._tmp_graph_def

def _post_clean(self):
def _post_clean(self): # pragma: no cover
"""Delete the temporarily files generated during the quantization process.

:return: None
Expand All @@ -840,7 +842,7 @@ def quantize_with_qdq_pattern(self):
self._insert_qdq_pairs()
self._convert_qdq()

except ValueError as e:
except ValueError as e: # pragma: no cover
logger.error("Fail to quantize graph due to {}.".format(str(e)))
self._tmp_model = None
raise
Expand Down Expand Up @@ -885,10 +887,10 @@ def _insert_qdq_pairs(self):
self.itex_mode,
).get_quantized_nodes()

if self.itex_mode:
if self.itex_mode: # pragma: no cover
self.quantized_node_info.extend(self._search_y_pattern_for_itex())

if self._enable_kl_op_names:
if self._enable_kl_op_names: # pragma: no cover
self._get_fp32_print_node_names(self._enable_kl_op_names)
self._generate_calibration_data(self._fp32_logged_model_path, self._fp32_print_data, True)

Expand Down
Loading
Loading