From f67e8613c409563f016c77e05a1acb969790cfc6 Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Tue, 30 Apr 2024 10:38:59 +0800 Subject: [PATCH] Support PTQ on Keras3 (#1759) Signed-off-by: zehao-intel --- .../algorithms/static_quant/keras.py | 397 +++++++-------- .../tensorflow/keras/layers/__init__.py | 1 - .../tensorflow/keras/layers/conv2d.py | 423 ++++++++++++---- .../tensorflow/keras/layers/dense.py | 141 ++++-- .../keras/layers/depthwise_conv2d.py | 466 +++++++++++++----- .../tensorflow/keras/layers/pool2d.py | 191 ++++++- .../tensorflow/keras/layers/quantizer.py | 150 ------ .../keras/layers/separable_conv2d.py | 447 +++++++++++++---- .../quantization/algorithm_entry.py | 3 +- .../tensorflow/utils/__init__.py | 1 + neural_compressor/tensorflow/utils/utility.py | 10 + requirements_tf.txt | 2 +- test/3x/tensorflow/keras/test_config.py | 27 +- .../tensorflow/keras/test_model_wrappers.py | 9 +- .../newapi/test_graph_conv_fusion_newapi.py | 56 ++- .../test_graph_conv_requantize_fusion.py | 14 +- .../newapi/test_graph_depthwiseconv_fusion.py | 14 +- .../newapi/test_graph_fuse_pad_conv_fp32.py | 14 +- .../ptq/newapi/test_graph_qdq_bn_fusion.py | 74 ++- .../newapi/test_graph_qdq_concat_fusion.py | 8 +- .../ptq/newapi/test_graph_qdq_conv_fusion.py | 116 ++++- .../test_graph_qdq_depthwiseconv_fusion.py | 14 +- .../newapi/test_graph_qdq_new_conv_fusion.py | 8 +- .../newapi/test_graph_qdq_pooling_fusion.py | 14 +- .../quantization/ptq/test_bias_correction.py | 18 +- .../quantization/ptq/test_data_pipline.py | 15 +- .../quantization/ptq/test_fold_batch_norm.py | 50 +- .../ptq/test_get_estimator_graph.py | 7 + .../quantization/ptq/test_graph_concat.py | 8 +- .../ptq/test_graph_conv_as_output.py | 1 - .../ptq/test_graph_conv_fusion.py | 67 ++- .../quantization/ptq/test_graph_meta_pass.py | 32 +- .../quantization/ptq/test_graph_pad_conv.py | 32 +- .../ptq/test_graph_post_cse_optimize.py | 26 +- .../ptq/test_graph_switch_optimizer.py | 20 +- .../quantization/ptq/test_query_yaml.py | 8 +- .../quantization/test_smooth_quant.py | 14 +- test/3x/tensorflow/test_autotune.py | 6 +- 38 files changed, 2025 insertions(+), 879 deletions(-) delete mode 100644 neural_compressor/tensorflow/keras/layers/quantizer.py diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index 79ed5464a1f..c4b15d847a3 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -28,18 +28,15 @@ from neural_compressor.common import logger from neural_compressor.common.utils import DEFAULT_WORKSPACE from neural_compressor.tensorflow.keras.layers import ( - DeQuantize, - FakeQuant, QAvgPool2D, QConv2D, QDense, QDepthwiseConv2D, QMaxPool2D, QSeparableConv2D, - Quantize, ) from neural_compressor.tensorflow.quantization.config import StaticQuantConfig -from neural_compressor.tensorflow.utils import deep_get, dump_elapsed_time +from neural_compressor.tensorflow.utils import deep_get, dump_elapsed_time, version1_gte_version2 class KerasAdaptor: @@ -57,9 +54,6 @@ class KerasAdaptor: ] custom_layers = { - "Quantize": Quantize, - "DeQuantize": DeQuantize, - "FakeQuant": FakeQuant, "QConv2D": QConv2D, "QDepthwiseConv2D": QDepthwiseConv2D, "QSeparableConv2D": QSeparableConv2D, @@ -91,9 +85,10 @@ def __init__(self, framework_specific_info): self.conv_format = {} self.fold_conv = [] + self.keras3 = True if version1_gte_version2(tf.__version__, "2.16.1") else False if not os.path.exists(DEFAULT_WORKSPACE): os.mkdir(DEFAULT_WORKSPACE) - self.tmp_dir = DEFAULT_WORKSPACE + "tmp_model" + self.tmp_dir = (DEFAULT_WORKSPACE + "tmp_model.keras") if self.keras3 else (DEFAULT_WORKSPACE + "tmp_model") def _check_itex(self): """Check if the IntelĀ® Extension for TensorFlow has been installed.""" @@ -153,7 +148,7 @@ def _check_quantize_format(self, model): for layer in model.layers: layer_name_mapping[layer.name] = layer for node in layer._outbound_nodes: - layer_name = node.outbound_layer.name + layer_name = node.operation.name if self.keras3 else node.outbound_layer.name if layer_name not in input_layer_dict: input_layer_dict[layer_name] = [layer.name] else: @@ -169,55 +164,54 @@ def _check_quantize_format(self, model): self.conv_format[layer.name] = "u8" break - def _fuse_bn(self, model): - """Fusing Batch Normalization.""" - fuse_bn_model = copy.deepcopy(model) - fp32_layers = fuse_bn_model.layers + def _fuse_bn_keras3(self, fuse_conv_bn, fp32_layers): + fuse_layers = [] + fused_bn_name = "" + for idx, layer in enumerate(fp32_layers): + if hasattr(layer, "_outbound_nodes"): + if layer.name == fused_bn_name: + continue + + if layer.name in self.conv_weights.keys(): + new_outbound_nodes = [] + conv_weight = self.conv_weights[layer.name] + for outbound_node in layer._outbound_nodes: + outbound_layer = outbound_node.operation + if outbound_layer.__class__.__name__ in ("BatchNormalization"): + fused_bn_name = outbound_layer.name + bn_weight = self.bn_weights[fused_bn_name] + self.layer_weights[layer.name] = fuse_conv_bn( + conv_weight, bn_weight, layer.__class__.__name__, outbound_layer.epsilon + ) + self.fold_conv.append(layer.name) + for node in outbound_layer._outbound_nodes: + new_outbound_nodes.append(node) + else: + new_outbound_nodes.append(outbound_node) + layer._outbound_nodes.clear() + for node in new_outbound_nodes: + layer._outbound_nodes.append(node) - def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): - assert conv_type in [ - "Conv2D", - "DepthwiseConv2D", - "SeparableConv2D", - ], "only support Conv2D, DepthwiseConv2D, SeparableConv2D..." - if len(bn_weight) > 3: - if conv_type == "DepthwiseConv2D": - gamma = bn_weight[0].reshape(1, 1, bn_weight[0].shape[0], 1) - var = bn_weight[3].reshape(1, 1, bn_weight[3].shape[0], 1) - else: - gamma = bn_weight[0].reshape(1, 1, 1, bn_weight[0].shape[0]) - var = bn_weight[3].reshape(1, 1, 1, bn_weight[3].shape[0]) - beta = bn_weight[1] - mean = bn_weight[2] + fuse_layers.append(layer) else: - gamma = 1.0 - beta = bn_weight[0] - mean = bn_weight[1] - if conv_type == "DepthwiseConv2D": - var = bn_weight[2].reshape(1, 1, bn_weight[2].shape[0], 1) + if ( + idx > 0 + and layer.__class__.__name__ == "BatchNormalization" + and fp32_layers[idx - 1].__class__.__name__ == "Conv2D" + ): + conv_name = fp32_layers[idx - 1].name + conv_weight = self.conv_weights[conv_name] + bn_weight = self.bn_weights[layer.name] + conv_type = fp32_layers[idx - 1].__class__.__name__ + + self.layer_weights[conv_name] = fuse_conv_bn(conv_weight, bn_weight, conv_type, layer.epsilon) + self.fold_conv.append(conv_name) else: - var = bn_weight[2].reshape(1, 1, 1, bn_weight[2].shape[0]) + fuse_layers.append(layer) - if len(conv_weight) == 1: - weight = conv_weight[0] - bias = np.zeros_like(beta) - elif len(conv_weight) == 2 and conv_type == "SeparableConv2D": - depth_weight = conv_weight[0] - weight = conv_weight[1] - bias = np.zeros_like(beta) - elif len(conv_weight) == 2 and conv_type != "SeparableConv2D": - weight = conv_weight[0] - bias = conv_weight[1] - elif len(conv_weight) == 3: - depth_weight = conv_weight[0] - weight = conv_weight[1] - bias = conv_weight[2] - scale_value = gamma / np.sqrt(var + eps) - weight = weight * scale_value - bias = beta + (bias - mean) * scale_value.reshape(-1) - bias = bias.reshape(-1) - return [depth_weight, weight, bias] if conv_type == "SeparableConv2D" else [weight, bias] + return fuse_layers + def _fuse_bn_keras2(self, fuse_conv_bn, fp32_layers): fuse_layers = [] for idx, layer in enumerate(fp32_layers): if hasattr(layer, "_inbound_nodes"): @@ -243,12 +237,14 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): else: for bound_node in layer._inbound_nodes: inbound_layer = bound_node.inbound_layers - if ( - not isinstance(inbound_layer, list) - and inbound_layer.name in self.bn_weights.keys() - and inbound_layer._inbound_nodes[0].inbound_layers.name in self.conv_weights.keys() - ): - new_bound_nodes.append(bn_inbound_node) + if inbound_layer in self.bn_weights.keys(): + for bn_inbound_node in inbound_layer._inbound_nodes: + bn_inbound_layer = bn_inbound_node.inbound_layers + if bn_inbound_layer.name in self.conv_weights.keys(): + new_bound_nodes.append(bn_inbound_node) + else: + if bound_node not in new_bound_nodes: + new_bound_nodes.append(bound_node) else: new_bound_nodes.append(bound_node) @@ -274,7 +270,62 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): else: fuse_layers.append(layer) - for idx, layer in enumerate(fuse_layers): + return fuse_layers + + def _fuse_bn(self, model): + """Fusing Batch Normalization.""" + model.save(self.tmp_dir) + fuse_bn_model = tf.keras.models.load_model(self.tmp_dir) + fp32_layers = fuse_bn_model.layers + + def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): + assert conv_type in [ + "Conv2D", + "DepthwiseConv2D", + "SeparableConv2D", + ], "only support Conv2D, DepthwiseConv2D, SeparableConv2D..." + if len(bn_weight) > 3: + if conv_type == "DepthwiseConv2D": + gamma = bn_weight[0].reshape(1, 1, bn_weight[0].shape[0], 1) + var = bn_weight[3].reshape(1, 1, bn_weight[3].shape[0], 1) + else: + gamma = bn_weight[0].reshape(1, 1, 1, bn_weight[0].shape[0]) + var = bn_weight[3].reshape(1, 1, 1, bn_weight[3].shape[0]) + beta = bn_weight[1] + mean = bn_weight[2] + else: + gamma = 1.0 + beta = bn_weight[0] + mean = bn_weight[1] + if conv_type == "DepthwiseConv2D": + var = bn_weight[2].reshape(1, 1, bn_weight[2].shape[0], 1) + else: + var = bn_weight[2].reshape(1, 1, 1, bn_weight[2].shape[0]) + + if len(conv_weight) == 1: + weight = conv_weight[0] + bias = np.zeros_like(beta) + elif len(conv_weight) == 2 and conv_type == "SeparableConv2D": + depth_weight = conv_weight[0] + weight = conv_weight[1] + bias = np.zeros_like(beta) + elif len(conv_weight) == 2 and conv_type != "SeparableConv2D": + weight = conv_weight[0] + bias = conv_weight[1] + elif len(conv_weight) == 3: + depth_weight = conv_weight[0] + weight = conv_weight[1] + bias = conv_weight[2] + scale_value = gamma / np.sqrt(var + eps) + weight = weight * scale_value + bias = beta + (bias - mean) * scale_value.reshape(-1) + bias = bias.reshape(-1) + return [depth_weight, weight, bias] if conv_type == "SeparableConv2D" else [weight, bias] + + fuse_bn_function = self._fuse_bn_keras3 if self.keras3 else self._fuse_bn_keras2 + fused_layers = fuse_bn_function(fuse_conv_bn, fp32_layers) + + for idx, layer in enumerate(fused_layers): if ( layer.__class__.__name__ in ("Conv2D", "DepthwiseConv2D", "SeparableConv2D") and layer.name in self.fold_conv @@ -284,10 +335,10 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): conv_layer = type(layer).from_config(conv_config) for node in layer._outbound_nodes: conv_layer._outbound_nodes.append(node) - fuse_layers[idx] = conv_layer + fused_layers[idx] = conv_layer bn_surgery = KerasSurgery(model) - bn_fused_model = bn_surgery.fuse_bn_layers(fuse_layers, self.conv_weights.keys()) + bn_fused_model = bn_surgery.convert(fused_layers, self.conv_weights.keys()) bn_fused_model = self._set_weights(bn_fused_model, self.layer_weights) bn_fused_model.save(self.tmp_dir) @@ -316,8 +367,8 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): converted_model = self.convert_bf16() return converted_model - if self.backend == "itex": - self._check_itex() + # if self.backend == "itex": + # self._check_itex() logger.debug("Dump quantization configurations:") logger.debug(self.quantize_config) @@ -333,32 +384,31 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): ) ) - fq_layers_dict = {} - fq_output_layers = {} - for idx, layer in enumerate(self.pre_optimized_model.layers): + from neural_compressor.tensorflow.keras.layers import layer_initializer_dict + + q_layer_dict = {} + for layer in self.pre_optimized_model.layers: if layer.__class__.__name__ in self.supported_op and layer.name in self.quantize_config["op_wise_config"]: op_config = self.quantize_config["op_wise_config"][layer.name] - mode = "per_channel" if op_config[0] else "per_tensor" - fake_q_name = "fake_quant_" + str(idx) - fake_q_layer = FakeQuant(name=fake_q_name, T=self.conv_format[layer.name], mode="per_tensor") - fq_layers_dict[layer.name] = [fake_q_layer] - fq_output_layers[fake_q_layer.name] = layer.name - self.pre_optimized_model.save(self.tmp_dir) - - fq_surgery = KerasSurgery(self.pre_optimized_model) - calibration_model = fq_surgery.insert_quant_layers(fq_layers_dict) + granularity = "per_channel" if op_config[0] else "per_tensor" + q_layer_class = "Q" + layer.__class__.__name__ + q_config = {"T": self.conv_format[layer.name], "granularity": granularity} + q_layer = layer_initializer_dict[q_layer_class](layer, q_config) + q_layer_dict[layer.name] = q_layer + + calib_surgery = KerasSurgery(self.pre_optimized_model) + calibration_model = calib_surgery.convert(q_layer_dict=q_layer_dict) calibration_model = self._set_weights(calibration_model, self.layer_weights) quantized_model = self._calibrate( calibration_model, dataloader, self.quantize_config["calib_iteration"], - fq_output_layers, ) return quantized_model - def _calibrate(self, model, dataloader, calib_interation, fq_output_layers): + def _calibrate(self, model, dataloader, calib_interation): """Apply calibration. Args: @@ -371,51 +421,32 @@ def _calibrate(self, model, dataloader, calib_interation, fq_output_layers): # run eagerly to fetch the numpy min/max results = {} model.compile(run_eagerly=True) - for idx, (inputs, labels) in enumerate(dataloader): + for idx, (inputs, _) in enumerate(dataloader): _ = model.predict_on_batch(inputs) - json_model = copy.deepcopy(json.loads(model.to_json())) - config = json_model["config"] - layers = config["layers"] - for layer in layers: - if layer["class_name"] == "FakeQuant": - min_value = layer["config"]["min_value"] - max_value = layer["config"]["max_value"] - assert min_value < max_value, "The min value must be lower than the max value in quantization." - - if layer["config"]["name"] not in results: - results[layer["config"]["name"]] = {"min": [min_value], "max": [max_value]} + for layer in model.layers: + if ( + layer.__class__.__name__[1:] in self.supported_op + and layer.name in self.quantize_config["op_wise_config"] + ): + min_value = layer.act_min_value.numpy() + max_value = layer.act_max_value.numpy() + + if layer.name not in results: + results[layer.name] = {"min": [min_value], "max": [max_value]} else: - results[layer["config"]["name"]]["min"].append(min_value) - results[layer["config"]["name"]]["max"].append(max_value) + results[layer.name]["min"].append(min_value) + results[layer.name]["max"].append(max_value) if idx + 1 == calib_interation: break - qdq_layer_nums = 0 - qdq_layers_dict = {} - quantized_layers_dict = {} for idx, layer in enumerate(model.layers): - if layer.__class__.__name__ == "FakeQuant": - min_value = min(results[layer.name]["min"]) - max_value = max(results[layer.name]["max"]) - - quantize_layer = Quantize( - name="quantize_" + str(qdq_layer_nums), - min_range=min_value, - max_range=max_value, - T=layer.T, - ) - dequantize_layer = DeQuantize( - name="dequantize_" + str(qdq_layer_nums), - min_range=min_value, - max_range=max_value, - ) - - qdq_layer_nums += 1 - output_layer_name = fq_output_layers[layer.name] - qdq_layers_dict[output_layer_name] = [quantize_layer, dequantize_layer] - elif layer.__class__.__name__ in self.supported_op and layer.name in self.quantize_config["op_wise_config"]: - # index 0 is weight, index 1 is bias - q_layer_class = "Q" + layer.__class__.__name__ + if ( + layer.__class__.__name__[1:] in self.supported_op + and layer.name in self.quantize_config["op_wise_config"] + ): + layer.act_min_value = np.min(results[layer.name]["min"]) + layer.act_max_value = np.max(results[layer.name]["max"]) + layer.quant_status = "quantize" # for layers that have weights if layer.name in self.layer_weights: kernel = self.layer_weights[layer.name][0] @@ -425,24 +456,15 @@ def _calibrate(self, model, dataloader, calib_interation, fq_output_layers): channel_size = kernel.shape[-1] kernel_channel = kernel.transpose(t_dim).reshape(channel_size, -1) - layer.min_value = np.min(kernel_channel, axis=1).tolist() - layer.max_value = np.max(kernel_channel, axis=1).tolist() + layer.weight_min_value = np.min(kernel_channel, axis=1).tolist() + layer.weight_max_value = np.max(kernel_channel, axis=1).tolist() else: # default value, but never expected to be used # cause no kernel weights for this layer - layer.min_value = [-10000] - layer.max_value = [10000] - - from neural_compressor.tensorflow.keras.layers import layer_initializer_dict - - q_layer = layer_initializer_dict[q_layer_class](layer) - quantized_layers_dict[layer.name] = q_layer + layer.weight_min_value = [-10000] + layer.weight_max_value = [10000] - qdq_surgery = KerasSurgery(self.pre_optimized_model) - quantized_model = qdq_surgery.insert_quant_layers(qdq_layers_dict, quantized_layers_dict) - quantized_model = self._set_weights(quantized_model, self.layer_weights) - - quantized_model.save(self.tmp_dir) + model.save(self.tmp_dir) quantized_model = tf.keras.models.load_model(self.tmp_dir) return quantized_model @@ -456,7 +478,6 @@ def evaluate( metrics=None, measurer=None, iteration=-1, - tensorboard=False, fp32_baseline=False, ): """The function is used to run evaluation on validation dataset. @@ -468,7 +489,6 @@ def evaluate( metric (object, optional): Depends on model category. Defaults to None. measurer (object, optional): for precise benchmark measurement. iteration(int, optional): control steps of mini-batch - tensorboard (boolean, optional): for tensorboard inspect tensor. fp32_baseline (boolean, optional): only for compare_label=False pipeline """ # use keras object @@ -783,85 +803,67 @@ def __init__(self, model): Args: model: the model to be modified. """ + import shutil + self.model_outputs = [] - self.model = copy.deepcopy(model) + self.keras3 = True if version1_gte_version2(tf.__version__, "2.16.1") else False + self.tmp_dir = (DEFAULT_WORKSPACE + "tmp_model.keras") if self.keras3 else (DEFAULT_WORKSPACE + "tmp_model") + model.save(self.tmp_dir) + self.model = tf.keras.models.load_model(self.tmp_dir) + shutil.rmtree(self.tmp_dir, ignore_errors=True) - def _create_input_dict(self, fuse_layers=None, conv_weights_keys=None): + def _parse_inputs(self, BN_fused_layers=None, conv_names=None): """Create a input_layer_dict from model. Args: - fuse_layers: The layers in which fused BNs have been excluded, default to be None. - conv_weights_keys: The names of conv layers where BNs are going to be fused, default to be None. + BN_fused_layers: The layers in which BN layers have been fused. + conv_names: The name list of conv layers where BNs are fused. Returns: input_layer_dict: The dict that mapping for layer names to their input layer names. """ input_layer_dict = {} - layers = fuse_layers if fuse_layers else self.model.layers + layers = BN_fused_layers if BN_fused_layers else self.model.layers for layer in layers: for node in layer._outbound_nodes: - out_layer = node.outbound_layer + out_layer = node.operation if self.keras3 else node.outbound_layer out_layer_names = [out_layer.name] - if ( - conv_weights_keys - and out_layer.__class__.__name__ in ("BatchNormalization") - and layer.name in conv_weights_keys - ): - out_layer_names = [node.outbound_layer.name for node in out_layer._outbound_nodes] + if conv_names and out_layer.__class__.__name__ in ("BatchNormalization") and layer.name in conv_names: + out_layer_names = ( + [node.operation.name for node in out_layer._outbound_nodes] + if self.keras3 + else [node.outbound_layer.name for node in out_layer._outbound_nodes] + ) for out_layer_name in out_layer_names: if out_layer_name not in input_layer_dict: - input_layer_dict[out_layer_name] = [layer.name] + input_layer_dict[out_layer_name] = set([layer.name]) else: - input_layer_dict[out_layer_name].append(layer.name) - - return input_layer_dict + input_layer_dict[out_layer_name].add(layer.name) - def fuse_bn_layers(self, fuse_layers, conv_weights_keys): - """Fuse BN layers and rebuild the model. + for key in input_layer_dict.keys(): + input_layer_dict[key] = list(input_layer_dict[key]) - Args: - fuse_layers: The layers in which fused BNs have been excluded. - conv_weights_keys: The names of conv layers where BNs are going to be fused. - """ - self.input_layer_dict = self._create_input_dict(fuse_layers, conv_weights_keys) - output_tensor_dict = {"keras.Input": self.model.input} - - for idx, layer in enumerate(fuse_layers): - if layer.__class__.__name__ == "InputLayer": - output_tensor_dict[layer.name] = output_tensor_dict["keras.Input"] - continue - - input_tensors = ( - output_tensor_dict["keras.Input"] - if idx == 0 - else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] - ) - - while isinstance(input_tensors, list) and len(input_tensors) == 1: - input_tensors = input_tensors[0] - - x = layer(input_tensors) - - output_tensor_dict[layer.name] = x - if layer.name in self.model.output_names: - self.model_outputs.append(x) + try: + model_input = self.model.input + except ValueError: + model_input = self.model.inputs[0] - return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) + return input_layer_dict, model_input - def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): - """Insert FakeQuant or QDQ layers before the target layers and replace - Keras layers to Quantized layers. + def convert(self, BN_fused_layers=None, conv_names=None, q_layer_dict=None): + """Generate optimized model by fusing BN layers or replacing Keras layers to custom quantized layers. Args: - qdq_layer_dict: The dict mapping from layers to be quantized to the FakeQuant layer or QDQ layers - that are going to be inserted before them. - q_layer_dict: The dict mapping from layers to be replacement to the quantized layers. + BN_fused_layers: The layers in which BN layers have been fused. + conv_names: The name list of conv layers where BNs are fused. + q_layer_dict: The dict mapping from keras layers to custom quantized layers. """ - self.input_layer_dict = self._create_input_dict() - output_tensor_dict = {"keras.Input": self.model.input} + input_layer_dict, model_input = self._parse_inputs(BN_fused_layers, conv_names) + output_tensor_dict = {"keras.Input": model_input} + layers = BN_fused_layers if BN_fused_layers else self.model.layers - for idx, layer in enumerate(self.model.layers): + for idx, layer in enumerate(layers): if layer.__class__.__name__ == "InputLayer": output_tensor_dict[layer.name] = output_tensor_dict["keras.Input"] continue @@ -869,22 +871,23 @@ def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): input_tensors = ( output_tensor_dict["keras.Input"] if idx == 0 - else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + else [output_tensor_dict[input_layer] for input_layer in input_layer_dict[layer.name]] ) + while isinstance(input_tensors, list) and len(input_tensors) == 1: input_tensors = input_tensors[0] - if layer.name in qdq_layer_dict: - x = input_tensors - for inserted_layer in qdq_layer_dict[layer.name]: - x = inserted_layer(x) - cur_layer = layer if not q_layer_dict else q_layer_dict[layer.name] - x = cur_layer(x) - else: - x = layer(input_tensors) + if self.keras3: + layer._inbound_nodes.clear() + cur_layer = q_layer_dict[layer.name] if q_layer_dict and layer.name in q_layer_dict else layer + x = cur_layer(input_tensors) output_tensor_dict[layer.name] = x - if layer.name in self.model.output_names: + + if not isinstance(self.model, tf.keras.models.Sequential) and layer.name in self.model.output_names: self.model_outputs.append(x) + if not self.model_outputs: + self.model_outputs.append(x) + return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) diff --git a/neural_compressor/tensorflow/keras/layers/__init__.py b/neural_compressor/tensorflow/keras/layers/__init__.py index 0b4fe9030ac..13b83950c97 100644 --- a/neural_compressor/tensorflow/keras/layers/__init__.py +++ b/neural_compressor/tensorflow/keras/layers/__init__.py @@ -19,6 +19,5 @@ from neural_compressor.tensorflow.keras.layers.dense import QDense from neural_compressor.tensorflow.keras.layers.depthwise_conv2d import QDepthwiseConv2D from neural_compressor.tensorflow.keras.layers.pool2d import QAvgPool2D, QMaxPool2D -from neural_compressor.tensorflow.keras.layers.quantizer import DeQuantize, FakeQuant, Quantize from neural_compressor.tensorflow.keras.layers.separable_conv2d import QSeparableConv2D from neural_compressor.tensorflow.keras.layers.layer_initializer import layer_initializer_dict diff --git a/neural_compressor/tensorflow/keras/layers/conv2d.py b/neural_compressor/tensorflow/keras/layers/conv2d.py index 0a4852d2027..426b1777b42 100644 --- a/neural_compressor/tensorflow/keras/layers/conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/conv2d.py @@ -18,107 +18,340 @@ import json import tensorflow as tf -from tensorflow import quantization from tensorflow.keras import activations, constraints, initializers, regularizers from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras import ops + from keras.src.layers.convolutional.base_conv import BaseConv # pylint: disable=E0401 +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_conv import Conv # pylint: disable=E0401 else: from keras.layers.convolutional.base_conv import Conv # pylint: disable=E0401 +if version1_gte_version2(tf.__version__, "2.16.1"): -class QConv2D(Conv): - def __init__( - self, - name, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format=None, - dilation_rate=(1, 1), - groups=1, - activation=None, - use_bias=True, - kernel_initializer="glorot_uniform", - bias_initializer="zeros", - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - min_value=None, - max_value=None, - **kwargs - ): - super(QConv2D, self).__init__( - name=name, - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - groups=groups, - activation=activations.get(activation), - use_bias=use_bias, - kernel_initializer=initializers.get(kernel_initializer), - bias_initializer=initializers.get(bias_initializer), - kernel_regularizer=regularizers.get(kernel_regularizer), - bias_regularizer=regularizers.get(bias_regularizer), - activity_regularizer=regularizers.get(activity_regularizer), - kernel_constraint=constraints.get(kernel_constraint), - bias_constraint=constraints.get(bias_constraint), + class QConv2D(BaseConv): + def __init__( + self, + name, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + groups=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, + **kwargs + ): + super(QConv2D, self).__init__( + name=name, + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + activation=activations.get(activation), + use_bias=use_bias, + kernel_initializer=initializers.get(kernel_initializer), + bias_initializer=initializers.get(bias_initializer), + kernel_regularizer=regularizers.get(kernel_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + kernel_constraint=constraints.get(kernel_constraint), + bias_constraint=constraints.get(bias_constraint), + **kwargs + ) + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis + + def call(self, inputs): + if self.quant_status == "calib" and not isinstance(inputs, tf.keras.KerasTensor): + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + kernel = self.kernel + elif self.quant_status == "quantize": + assert ( + self.act_min_value is not None + ), "Invalid activation min-max values, please check calibration process" + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + kernel_size = self.kernel.shape[-1] + + if not self.weight_min_value: + self.weight_min_value = [-10000] * kernel_size + if not self.weight_max_value: + self.weight_max_value = [10000] * kernel_size + + # add the Q/DQ here + kernel, _, _ = tf.quantization.quantize( + self.kernel, self.weight_min_value, self.weight_max_value, tf.qint8, axis=3, mode="SCALED" + ) + kernel = tf.quantization.dequantize( + kernel, + self.weight_min_value, + self.weight_max_value, + axis=3, + mode="SCALED", + ) + + outputs = self.convolution_op( + inputs, + kernel, + ) + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.filters,) + else: + bias_shape = (1, self.filters) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs += bias + + if self.activation is not None: + return self.activation(outputs) + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QConv2D, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config + +else: + + class QConv2D(Conv): + def __init__( + self, + name, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + groups=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, **kwargs - ) - self.min_value = min_value - self.max_value = max_value - - def call(self, inputs): - kernel_size = self.kernel.shape[-1] - - if not self.min_value: - self.min_value = [-10000] * kernel_size - if not self.max_value: - self.max_value = [10000] * kernel_size - - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - outputs = tf.keras.backend.conv2d( - inputs, - kernel, - strides=self.strides, - padding=self.padding, - data_format=self.data_format, - dilation_rate=self.dilation_rate, - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) - - -def initialize_int8_conv2d(fp32_layer): + ): + super(QConv2D, self).__init__( + name=name, + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + activation=activations.get(activation), + use_bias=use_bias, + kernel_initializer=initializers.get(kernel_initializer), + bias_initializer=initializers.get(bias_initializer), + kernel_regularizer=regularizers.get(kernel_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + kernel_constraint=constraints.get(kernel_constraint), + bias_constraint=constraints.get(bias_constraint), + **kwargs + ) + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis + + def call(self, inputs): + if self.quant_status == "calib": + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + kernel = self.kernel + elif self.quant_status == "quantize": + assert ( + self.act_min_value is not None + ), "Invalid activation min-max values, please check calibration process" + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + kernel_size = self.kernel.shape[-1] + + if not self.weight_min_value: + self.weight_min_value = [-10000] * kernel_size + if not self.weight_max_value: + self.weight_max_value = [10000] * kernel_size + + # add the Q/DQ here + kernel, _, _ = tf.quantization.quantize( + self.kernel, self.weight_min_value, self.weight_max_value, tf.qint8, axis=3, mode="SCALED" + ) + kernel = tf.quantization.dequantize( + kernel, + self.weight_min_value, + self.weight_max_value, + axis=3, + mode="SCALED", + ) + outputs = tf.keras.backend.conv2d( + inputs, + kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QConv2D, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config + + +def initialize_int8_conv2d(fp32_layer, q_config): kwargs = fp32_layer.get_config() if "name" in kwargs: @@ -155,10 +388,6 @@ def initialize_int8_conv2d(fp32_layer): del kwargs["kernel_constraint"] if "bias_constraint" in kwargs: del kwargs["bias_constraint"] - if "min_value" in kwargs: - del kwargs["min_value"] - if "max_value" in kwargs: - del kwargs["max_value"] return QConv2D( name=fp32_layer.name, @@ -178,7 +407,7 @@ def initialize_int8_conv2d(fp32_layer): activity_regularizer=fp32_layer.activity_regularizer, kernel_constraint=fp32_layer.kernel_constraint, bias_constraint=fp32_layer.bias_constraint, - min_value=fp32_layer.min_value, - max_value=fp32_layer.max_value, + quant_T=q_config["T"], + granularity=q_config["granularity"], **kwargs ) diff --git a/neural_compressor/tensorflow/keras/layers/dense.py b/neural_compressor/tensorflow/keras/layers/dense.py index 61dfda2a2b8..4e97cbfb7a7 100644 --- a/neural_compressor/tensorflow/keras/layers/dense.py +++ b/neural_compressor/tensorflow/keras/layers/dense.py @@ -18,10 +18,11 @@ import json import tensorflow as tf -from tensorflow import quantization from tensorflow.keras import activations, backend, constraints, initializers, regularizers from tensorflow.keras.layers import Dense +from neural_compressor.tensorflow.utils import version1_gte_version2 + class QDense(Dense): def __init__( @@ -37,8 +38,17 @@ def __init__( activity_regularizer=None, kernel_constraint=None, bias_constraint=None, - min_value=None, - max_value=None, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, **kwargs ): super(QDense, self).__init__( @@ -55,34 +65,75 @@ def __init__( bias_constraint=bias_constraint, **kwargs ) - self.min_value = min_value - self.max_value = max_value + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis def call(self, inputs): - kernel_size = self.kernel.shape[-1] - - if not self.min_value: - self.min_value = [-10000] * kernel_size - if not self.max_value: - self.max_value = [10000] * kernel_size - - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.kernel, - self.min_value, - self.max_value, - tf.qint8, - axis=1, - mode="SCALED", - ) + if self.quant_status == "calib" and not ( + version1_gte_version2(tf.__version__, "2.16.1") and isinstance(inputs, tf.keras.KerasTensor) + ): + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + kernel = self.kernel + elif self.quant_status == "quantize": + assert self.act_min_value is not None, "Invalid activation min-max values, please check calibration process" + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + kernel_size = self.kernel.shape[-1] + + if not self.weight_min_value: + self.weight_min_value = [-10000] * kernel_size + if not self.weight_max_value: + self.weight_max_value = [10000] * kernel_size + + # add the Q/DQ here + kernel, _, _ = tf.quantization.quantize( + self.kernel, + self.weight_min_value, + self.weight_max_value, + tf.qint8, + axis=1, + mode="SCALED", + ) + kernel = tf.quantization.dequantize( + kernel, + self.weight_min_value, + self.weight_max_value, + axis=1, + mode="SCALED", + ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=1, - mode="SCALED", - ) outputs = tf.keras.backend.dot(inputs, kernel) if self.use_bias: @@ -91,8 +142,32 @@ def call(self, inputs): outputs = self.activation(outputs) return outputs + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QDense, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config + -def initialize_int8_dense(fp32_layer): +def initialize_int8_dense(fp32_layer, q_config): kwargs = fp32_layer.get_config() if "name" in kwargs: @@ -117,10 +192,6 @@ def initialize_int8_dense(fp32_layer): del kwargs["kernel_constraint"] if "bias_constraint" in kwargs: del kwargs["bias_constraint"] - if "min_value" in kwargs: - del kwargs["min_value"] - if "max_value" in kwargs: - del kwargs["max_value"] q_layer = QDense( name=fp32_layer.name, @@ -134,8 +205,8 @@ def initialize_int8_dense(fp32_layer): activity_regularizer=fp32_layer.activity_regularizer, kernel_constraint=fp32_layer.kernel_constraint, bias_constraint=fp32_layer.bias_constraint, - min_value=fp32_layer.min_value, - max_value=fp32_layer.max_value, + quant_T=q_config["T"], + granularity=q_config["granularity"], **kwargs ) diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index a3e6dd9b2f4..683c774b2fe 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2022 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,135 +18,361 @@ import json import tensorflow as tf -from tensorflow import quantization from tensorflow.keras import activations, constraints, initializers, regularizers from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras.src import ops + from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv # pylint: disable=E0401 +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.src.utils import conv_utils, tf_utils # pylint: disable=E0401 else: from keras.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.utils import conv_utils, tf_utils # pylint: disable=E0401 +if version1_gte_version2(tf.__version__, "2.16.1"): -class QDepthwiseConv2D(DepthwiseConv): - def __init__( - self, - kernel_size, - strides=(1, 1), - padding="valid", - depth_multiplier=1, - data_format=None, - dilation_rate=(1, 1), - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - min_value=None, - max_value=None, - **kwargs - ): - super().__init__( - 2, - kernel_size=kernel_size, - strides=strides, - padding=padding, - depth_multiplier=depth_multiplier, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - bias_constraint=bias_constraint, + class QDepthwiseConv2D(BaseDepthwiseConv): + def __init__( + self, + kernel_size, + min_value, + max_value, + strides=(1, 1), + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, + **kwargs + ): + super().__init__( + 2, + kernel_size=kernel_size, + strides=strides, + padding=padding, + depth_multiplier=depth_multiplier, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, + **kwargs + ) + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis + + def call(self, inputs): + if self.quant_status == "calib" and not isinstance(inputs, tf.keras.KerasTensor): + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + kernel = self.kernel + elif self.quant_status == "quantize": + assert ( + self.act_min_value is not None + ), "Invalid activation min-max values, please check calibration process" + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + # add the Q/DQ here + kernel, _, _ = tf.quantization.quantize( + self.kernel, self.weight_min_value, self.weight_max_value, tf.qint8, axis=3, mode="SCALED" + ) + kernel = tf.quantization.dequantize( + kernel, + self.weight_min_value, + self.weight_max_value, + axis=3, + mode="SCALED", + ) + + input_channel = self._get_input_channel(inputs.shape) + outputs = ops.depthwise_conv( + inputs, + kernel, + strides=self.strides, + padding=self.padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.depth_multiplier * input_channel,) + else: + bias_shape = (1, self.depth_multiplier * input_channel) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs += bias + + if self.activation is not None: + return self.activation(outputs) + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QDepthwiseConv2D, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config + +else: + + class QDepthwiseConv2D(DepthwiseConv): + def __init__( + self, + kernel_size, + min_value, + max_value, + strides=(1, 1), + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, **kwargs - ) - self.min_value = min_value - self.max_value = max_value - - def call(self, inputs): - depthwise_kernel_size = self.depthwise_kernel.shape[-1] - - if not self.min_value: - self.min_value = [-10000] * depthwise_kernel_size - if not self.max_value: - self.max_value = [10000] * depthwise_kernel_size - - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - outputs = tf.keras.backend.depthwise_conv2d( - inputs, - kernel, - strides=self.strides, - padding=self.padding, - data_format=self.data_format, - dilation_rate=self.dilation_rate, - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) - - @tf_utils.shape_type_conversion - def compute_output_shape(self, input_shape): - if self.data_format == "channels_first": - rows = input_shape[2] - cols = input_shape[3] - out_filters = input_shape[1] * self.depth_multiplier - elif self.data_format == "channels_last": - rows = input_shape[1] - cols = input_shape[2] - out_filters = input_shape[3] * self.depth_multiplier - - rows = conv_utils.conv_output_length( - rows, - self.kernel_size[0], - self.padding, - self.strides[0], - self.dilation_rate[0], - ) - cols = conv_utils.conv_output_length( - cols, - self.kernel_size[1], - self.padding, - self.strides[1], - self.dilation_rate[1], - ) - if self.data_format == "channels_first": - return (input_shape[0], out_filters, rows, cols) - elif self.data_format == "channels_last": - return (input_shape[0], rows, cols, out_filters) - - -def initialize_int8_depthwise_conv2d(fp32_layer): + ): + super().__init__( + 2, + kernel_size=kernel_size, + strides=strides, + padding=padding, + depth_multiplier=depth_multiplier, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, + **kwargs + ) + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis + + def call(self, inputs): + if self.quant_status == "calib": + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + depthwise_kernel = self.depthwise_kernel + elif self.quant_status == "quantize": + assert ( + self.act_min_value is not None + ), "Invalid activation min-max values, please check calibration process" + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + # add the Q/DQ here + depthwise_kernel, _, _ = tf.quantization.quantize( + self.depthwise_kernel, self.weight_min_value, self.weight_max_value, tf.qint8, axis=3, mode="SCALED" + ) + depthwise_kernel = tf.quantization.dequantize( + depthwise_kernel, + self.weight_min_value, + self.weight_max_value, + axis=3, + mode="SCALED", + ) + + outputs = tf.keras.backend.depthwise_conv2d( + inputs, + depthwise_kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QDepthwiseConv2D, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + rows = input_shape[2] + cols = input_shape[3] + out_filters = input_shape[1] * self.depth_multiplier + elif self.data_format == "channels_last": + rows = input_shape[1] + cols = input_shape[2] + out_filters = input_shape[3] * self.depth_multiplier + + rows = conv_utils.conv_output_length( + rows, + self.kernel_size[0], + self.padding, + self.strides[0], + self.dilation_rate[0], + ) + cols = conv_utils.conv_output_length( + cols, + self.kernel_size[1], + self.padding, + self.strides[1], + self.dilation_rate[1], + ) + if self.data_format == "channels_first": + return (input_shape[0], out_filters, rows, cols) + elif self.data_format == "channels_last": + return (input_shape[0], rows, cols, out_filters) + + +def initialize_int8_depthwise_conv2d(fp32_layer, q_config): kwargs = fp32_layer.get_config() q_name = fp32_layer.name @@ -204,7 +430,7 @@ def initialize_int8_depthwise_conv2d(fp32_layer): activity_regularizer=fp32_layer.activity_regularizer, depthwise_constraint=fp32_layer.depthwise_constraint, bias_constraint=fp32_layer.bias_constraint, - min_value=fp32_layer.min_value, - max_value=fp32_layer.max_value, + quant_T=q_config["T"], + granularity=q_config["granularity"], **kwargs ) diff --git a/neural_compressor/tensorflow/keras/layers/pool2d.py b/neural_compressor/tensorflow/keras/layers/pool2d.py index 05a028ecc83..ce81fc2377b 100644 --- a/neural_compressor/tensorflow/keras/layers/pool2d.py +++ b/neural_compressor/tensorflow/keras/layers/pool2d.py @@ -18,10 +18,10 @@ import json import tensorflow as tf -from tensorflow import quantization -from tensorflow.keras import activations, backend, constraints, initializers, regularizers from tensorflow.keras.layers import AveragePooling2D, MaxPooling2D +from neural_compressor.tensorflow.utils import version1_gte_version2 + class QAvgPool2D(AveragePooling2D): def __init__( @@ -31,15 +31,90 @@ def __init__( strides=None, padding="valid", data_format=None, - min_value=-10000, - max_value=10000, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, **kwargs ): super(QAvgPool2D, self).__init__( name=name, pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs ) - self.min_value = min_value - self.max_value = max_value + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis + + def __call__(self, inputs): + if self.quant_status == "calib" and not ( + version1_gte_version2(tf.__version__, "2.16.1") and isinstance(inputs, tf.keras.KerasTensor) + ): + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + elif self.quant_status == "quantize": + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + return super(QAvgPool2D, self).__call__(inputs) + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QAvgPool2D, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config class QMaxPool2D(MaxPooling2D): @@ -50,18 +125,94 @@ def __init__( strides=None, padding="valid", data_format=None, - min_value=-10000, - max_value=10000, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, **kwargs ): super(QMaxPool2D, self).__init__( name=name, pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs ) - self.min_value = min_value - self.max_value = max_value + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis + + def __call__(self, inputs): + if self.quant_status == "calib" and not ( + version1_gte_version2(tf.__version__, "2.16.1") and isinstance(inputs, tf.keras.KerasTensor) + ): + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + elif self.quant_status == "quantize": + assert self.act_min_value is not None, "Invalid activation min-max values, please check calibration process" + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + return super(QMaxPool2D, self).__call__(inputs) + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QMaxPool2D, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config -def initialize_int8_avgpool(fp32_layer): +def initialize_int8_avgpool(fp32_layer, q_config): kwargs = fp32_layer.get_config() if "name" in kwargs: @@ -74,10 +225,6 @@ def initialize_int8_avgpool(fp32_layer): del kwargs["padding"] if "data_format" in kwargs: del kwargs["data_format"] - if "min_value" in kwargs: - del kwargs["min_value"] - if "max_value" in kwargs: - del kwargs["max_value"] q_layer = QAvgPool2D( name=fp32_layer.name, @@ -85,15 +232,15 @@ def initialize_int8_avgpool(fp32_layer): strides=fp32_layer.strides, padding=fp32_layer.padding, data_format=fp32_layer.data_format, - min_value=fp32_layer.min_value, - max_value=fp32_layer.max_value, + quant_T=q_config["T"], + granularity=q_config["granularity"], **kwargs ) return q_layer -def initialize_int8_maxpool(fp32_layer): +def initialize_int8_maxpool(fp32_layer, q_config): kwargs = fp32_layer.get_config() if "name" in kwargs: @@ -106,10 +253,6 @@ def initialize_int8_maxpool(fp32_layer): del kwargs["padding"] if "data_format" in kwargs: del kwargs["data_format"] - if "min_value" in kwargs: - del kwargs["min_value"] - if "max_value" in kwargs: - del kwargs["max_value"] q_layer = QMaxPool2D( name=fp32_layer.name, @@ -117,8 +260,8 @@ def initialize_int8_maxpool(fp32_layer): strides=fp32_layer.strides, padding=fp32_layer.padding, data_format=fp32_layer.data_format, - min_value=fp32_layer.min_value, - max_value=fp32_layer.max_value, + quant_T=q_config["T"], + granularity=q_config["granularity"], **kwargs ) diff --git a/neural_compressor/tensorflow/keras/layers/quantizer.py b/neural_compressor/tensorflow/keras/layers/quantizer.py deleted file mode 100644 index a6e31fc6a5c..00000000000 --- a/neural_compressor/tensorflow/keras/layers/quantizer.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import tensorflow as tf -from tensorflow.keras.layers import Layer - - -class FakeQuant(Layer): - def __init__(self, mode="per_tensor", T="s8", **kwargs): - super(FakeQuant, self).__init__(**kwargs) - self.mode = mode - self.T = T - self.axis = 1 if mode == "per_channel" else 0 - self.min_value = tf.constant(np.finfo(np.float32).min, dtype=tf.float32) - self.max_value = tf.constant(np.finfo(np.float32).max, dtype=tf.float32) - - def call(self, inputs): - if self.mode == "per_tensor": - self.min_value = tf.math.reduce_min(inputs) - self.max_value = tf.math.reduce_max(inputs) - else: - self.min_value = tf.math.reduce_min(inputs, axis=self.axis) - self.max_value = tf.math.reduce_max(inputs, axis=self.axis) - - return inputs - - def compute_output_shape(self, input_shape): - input_shape = tf.TensorShape(input_shape).as_list() - return input_shape - - @classmethod - def from_config(cls, config): - return cls(**config) - - def get_config(self): - return { - "mode": self.mode, - "min_value": self.min_value.numpy(), - "max_value": self.max_value.numpy(), - "T": self.T, - "name": self.name, - } - - -class Quantize(Layer): - def __init__( - self, - min_range, - max_range, - T="s8", - mode="SCALED", - round_mode="HALF_AWAY_FROM_ZERO", - narrow_range=False, - axis=None, - **kwargs - ): - super(Quantize, self).__init__(**kwargs) - T_map = {"s8": tf.qint8, "u8": tf.quint8} - self.min_range = float(min_range) - self.max_range = float(max_range) - self.T = T_map[T] - self.mode = mode - self.round_mode = round_mode - self.narrow_range = narrow_range - self.axis = axis - - def call(self, inputs): - outputs, _, _ = tf.quantization.quantize( - inputs, - self.min_range, - self.max_range, - self.T, - mode=self.mode, - round_mode=self.round_mode, - narrow_range=self.narrow_range, - axis=self.axis, - ) - return outputs - - def compute_output_shape(self, input_shape): - input_shape = tf.TensorShape(input_shape).as_list() - return input_shape - - def get_config(self): - return { - "min_range": self.min_range, - "max_range": self.max_range, - "T": self.T, - "mode": self.mode, - "round_mode": self.round_mode, - "narrow": self.narrow_range, - "axis": self.axis, - } - - @classmethod - def from_config(cls, config): - return cls(**config) - - -class DeQuantize(Layer): - def __init__(self, min_range, max_range, mode="SCALED", narrow_range=False, axis=None, **kwargs): - super(DeQuantize, self).__init__(**kwargs) - self.min_range = min_range - self.max_range = max_range - self.mode = mode - self.narrow_range = narrow_range - self.axis = axis - - def call(self, inputs): - return tf.quantization.dequantize( - inputs, - float(self.min_range), - float(self.max_range), - mode=self.mode, - narrow_range=self.narrow_range, - axis=self.axis, - ) - - def compute_output_shape(self, input_shape): - input_shape = tf.TensorShape(input_shape).as_list() - return input_shape - - def get_config(self): - return { - "min_range": self.min_range, - "max_range": self.max_range, - "mode": self.mode, - "narrow": self.narrow_range, - "axis": self.axis, - "dtype": self.dtype, - } - - @classmethod - def from_config(cls, config): - return cls(**config) diff --git a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py index 7df66d9db49..05ee3a62c72 100644 --- a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2022 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,123 +18,354 @@ import json import tensorflow as tf -from tensorflow import quantization from tensorflow.keras import activations, constraints, initializers, regularizers from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras.src import ops + from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv # pylint: disable=E0401 +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.src.utils import conv_utils # pylint: disable=E0401 else: from keras.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.utils import conv_utils # pylint: disable=E0401 +if version1_gte_version2(tf.__version__, "2.16.1"): -class QSeparableConv2D(SeparableConv): - def __init__( - self, - name, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format=None, - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - pointwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - min_value=None, - max_value=None, - **kwargs - ): - super().__init__( - name=name, - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activations.get(activation), - use_bias=use_bias, - depthwise_initializer=initializers.get(depthwise_initializer), - pointwise_initializer=initializers.get(pointwise_initializer), - bias_initializer=initializers.get(bias_initializer), - depthwise_regularizer=regularizers.get(depthwise_regularizer), - pointwise_regularizer=regularizers.get(pointwise_regularizer), - bias_regularizer=regularizers.get(bias_regularizer), - activity_regularizer=regularizers.get(activity_regularizer), - depthwise_constraint=constraints.get(depthwise_constraint), - pointwise_constraint=constraints.get(pointwise_constraint), - bias_constraint=constraints.get(bias_constraint), + class QSeparableConv2D(BaseSeparableConv): + def __init__( + self, + filters, + kernel_size, + min_value, + max_value, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, + **kwargs + ): + super().__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activations.get(activation), + use_bias=use_bias, + depthwise_initializer=initializers.get(depthwise_initializer), + pointwise_initializer=initializers.get(pointwise_initializer), + bias_initializer=initializers.get(bias_initializer), + depthwise_regularizer=regularizers.get(depthwise_regularizer), + pointwise_regularizer=regularizers.get(pointwise_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + depthwise_constraint=constraints.get(depthwise_constraint), + pointwise_constraint=constraints.get(pointwise_constraint), + bias_constraint=constraints.get(bias_constraint), + **kwargs + ) + + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis + + def call(self, inputs): + if self.quant_status == "calib" and not isinstance(inputs, tf.keras.KerasTensor): + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + depthwise_kernel = self.depthwise_kernel + elif self.quant_status == "quantize": + assert ( + self.act_min_value is not None + ), "Invalid activation min-max values, please check calibration process" + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + # (TODO) it's ugly that we can't get the point_wise min/max here + depthwise_kernel, _, _ = tf.quantization.quantize( + self.depthwise_kernel, self.weight_min_value, self.weight_max_value, tf.qint8, axis=3, mode="SCALED" + ) + depthwise_kernel = tf.quantization.dequantize( + depthwise_kernel, + self.weight_min_value, + self.weight_max_value, + axis=3, + mode="SCALED", + ) + + outputs = ops.separable_conv( + inputs, + depthwise_kernel, + self.pointwise_kernel, + strides=self.strides, + padding=self.padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.filters,) + else: + bias_shape = (1, self.filters) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs += bias + + if self.activation is not None: + return self.activation(outputs) + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QSeparableConv2D, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config + +else: + + class QSeparableConv2D(SeparableConv): + def __init__( + self, + filters, + kernel_size, + min_value, + max_value, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + act_min_value=None, + act_max_value=None, + weight_min_value=None, + weight_max_value=None, + granularity="per_tensor", + quant_status="calib", + quant_mode="SCALED", + quant_T="s8", + quant_round_mode="HALF_AWAY_FROM_ZERO", + quant_narrow_range=False, + quant_axis=None, **kwargs - ) - - self.min_value = min_value - self.max_value = max_value - - def call(self, inputs): - depthwise_kernel_size = self.depthwise_kernel.shape[-1] - - if not self.min_value: - self.min_value = [-10000] * depthwise_kernel_size - if not self.max_value: - self.max_value = [10000] * depthwise_kernel_size - - # TODO it's ugly that we can't get the point_wise min/max here - depthwise_kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - depthwise_kernel = quantization.dequantize( - depthwise_kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - if self.data_format == "channels_last": - strides = (1,) + self.strides + (1,) - else: - strides = (1, 1) + self.strides - - outputs = tf.compat.v1.nn.separable_conv2d( - inputs, - depthwise_kernel, - self.pointwise_kernel, - strides=strides, - padding=self.padding.upper(), - rate=self.dilation_rate, - data_format=conv_utils.convert_data_format(self.data_format, ndim=4), - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) - - -def initialize_int8_separable_conv2d(fp32_layer): + ): + super().__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activations.get(activation), + use_bias=use_bias, + depthwise_initializer=initializers.get(depthwise_initializer), + pointwise_initializer=initializers.get(pointwise_initializer), + bias_initializer=initializers.get(bias_initializer), + depthwise_regularizer=regularizers.get(depthwise_regularizer), + pointwise_regularizer=regularizers.get(pointwise_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + depthwise_constraint=constraints.get(depthwise_constraint), + pointwise_constraint=constraints.get(pointwise_constraint), + bias_constraint=constraints.get(bias_constraint), + **kwargs + ) + T_map = {"s8": tf.qint8, "u8": tf.quint8} + self.weight_min_value = weight_min_value + self.weight_max_value = weight_max_value + self.act_min_value = act_min_value + self.act_max_value = act_max_value + self.granularity = granularity + self.quant_status = quant_status + self.quant_mode = quant_mode + self.quant_T = T_map[quant_T] + self.quant_round_mode = quant_round_mode + self.quant_narrow_range = quant_narrow_range + self.quant_axis = quant_axis + + def call(self, inputs): + if self.quant_status == "calib": + if self.granularity == "per_tensor": + self.act_min_value = tf.math.reduce_min(inputs) + self.act_max_value = tf.math.reduce_max(inputs) + else: + self.act_min_value = tf.math.reduce_min(inputs, axis=1) + self.act_max_value = tf.math.reduce_max(inputs, axis=1) + depthwise_kernel = self.depthwise_kernel + elif self.quant_status == "quantize": + assert ( + self.act_min_value is not None + ), "Invalid activation min-max values, please check calibration process" + inputs, _, _ = tf.quantization.quantize( + inputs, + self.act_min_value, + self.act_max_value, + self.quant_T, + mode=self.quant_mode, + round_mode=self.quant_round_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + inputs = tf.quantization.dequantize( + inputs, + self.act_min_value, + self.act_max_value, + mode=self.quant_mode, + narrow_range=self.quant_narrow_range, + axis=self.quant_axis, + ) + + # (TODO) it's ugly that we can't get the point_wise min/max here + depthwise_kernel, _, _ = tf.quantization.quantize( + self.depthwise_kernel, self.weight_min_value, self.weight_max_value, tf.qint8, axis=3, mode="SCALED" + ) + depthwise_kernel = tf.quantization.dequantize( + depthwise_kernel, + self.weight_min_value, + self.weight_max_value, + axis=3, + mode="SCALED", + ) + + if self.data_format == "channels_last": + strides = (1,) + self.strides + (1,) + else: + strides = (1, 1) + self.strides + + outputs = tf.compat.v1.nn.separable_conv2d( + inputs, + depthwise_kernel, + self.pointwise_kernel, + strides=strides, + padding=self.padding.upper(), + rate=self.dilation_rate, + data_format=conv_utils.convert_data_format(self.data_format, ndim=4), + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + config = super(QSeparableConv2D, self).get_config() + config.update( + { + "act_min_value": self.act_min_value, + "act_max_value": self.act_max_value, + "weight_min_value": self.weight_min_value, + "weight_max_value": self.weight_max_value, + "granularity": self.granularity, + "quant_status": self.quant_status, + "quant_mode": self.quant_mode, + "quant_T": "s8" if self.quant_T == tf.qint8 else "u8", + "quant_round_mode": self.quant_round_mode, + "quant_narrow_range": self.quant_narrow_range, + "quant_axis": self.quant_axis, + } + ) + + return config + + +def initialize_int8_separable_conv2d(fp32_layer, q_config): kwargs = fp32_layer.get_config() if "name" in kwargs: @@ -203,7 +434,7 @@ def initialize_int8_separable_conv2d(fp32_layer): depthwise_constraint=fp32_layer.depthwise_constraint, pointwise_constraint=fp32_layer.pointwise_constraint, bias_constraint=fp32_layer.bias_constraint, - min_value=fp32_layer.min_value, - max_value=fp32_layer.max_value, + quant_T=q_config["T"], + granularity=q_config["granularity"], **kwargs ) diff --git a/neural_compressor/tensorflow/quantization/algorithm_entry.py b/neural_compressor/tensorflow/quantization/algorithm_entry.py index a0baf490390..4b40a2f39a1 100644 --- a/neural_compressor/tensorflow/quantization/algorithm_entry.py +++ b/neural_compressor/tensorflow/quantization/algorithm_entry.py @@ -19,7 +19,7 @@ from neural_compressor.common.utils import SMOOTH_QUANT, STATIC_QUANT from neural_compressor.tensorflow.algorithms import KerasAdaptor, Tensorflow_ITEXAdaptor, TensorFlowAdaptor from neural_compressor.tensorflow.quantization.config import SmoothQuantConfig -from neural_compressor.tensorflow.utils import BaseModel, KerasModel, TFConfig, register_algo +from neural_compressor.tensorflow.utils import BaseModel, KerasModel, TFConfig, register_algo, valid_keras_format @register_algo(name=STATIC_QUANT) @@ -41,6 +41,7 @@ def static_quant_entry( q_model: the quantized model. """ if isinstance(model, KerasModel): + assert valid_keras_format(model.model), "Only Sequential or Functional models are supported now." framework = KerasAdaptor elif TFConfig.global_config["backend"] == "itex": framework = Tensorflow_ITEXAdaptor diff --git a/neural_compressor/tensorflow/utils/__init__.py b/neural_compressor/tensorflow/utils/__init__.py index deb15140c92..65dbabd2270 100644 --- a/neural_compressor/tensorflow/utils/__init__.py +++ b/neural_compressor/tensorflow/utils/__init__.py @@ -55,4 +55,5 @@ Statistics, CaptureOutputToFile, LazyImport, + valid_keras_format, ) diff --git a/neural_compressor/tensorflow/utils/utility.py b/neural_compressor/tensorflow/utils/utility.py index ed1fc88aee8..886dcffc234 100644 --- a/neural_compressor/tensorflow/utils/utility.py +++ b/neural_compressor/tensorflow/utils/utility.py @@ -254,6 +254,16 @@ def wrapper(*args, **kw): return decorator +def valid_keras_format(model): + """Check if the input model is Sequential or Functional model.""" + import keras + + if isinstance(model, keras.src.models.Sequential) or isinstance(model, keras.src.models.Functional): + return True + + return False + + @singleton class CpuInfo(object): """Get CPU Info.""" diff --git a/requirements_tf.txt b/requirements_tf.txt index da1544d2939..f8075c2a068 100644 --- a/requirements_tf.txt +++ b/requirements_tf.txt @@ -3,4 +3,4 @@ psutil py-cpuinfo pydantic pyyaml -tensorflow<=2.15.1 +tensorflow diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index 3a2b1f289ec..6e617cf0584 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -18,7 +18,6 @@ import math import os import shutil -import time import unittest import keras @@ -26,6 +25,7 @@ import tensorflow as tf from neural_compressor.common import Logger +from neural_compressor.tensorflow.utils import version1_gte_version2 logger = Logger().get_logger() @@ -69,7 +69,10 @@ def build_model(): _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) print("Baseline test accuracy:", baseline_model_accuracy) - model.save("baseline_model") + if version1_gte_version2(tf.__version__, "2.16.1"): + model.save("baseline_model.keras") + else: + model.save("baseline_model") class Dataset(object): @@ -113,10 +116,16 @@ class TestTF3xNewApi(unittest.TestCase): def setUpClass(self): build_model() os.environ["ITEX_ONEDNN_GRAPH"] = "1" + self.fp32_model_path = ( + "baseline_model.keras" if version1_gte_version2(tf.__version__, "2.16.1") else "baseline_model" + ) @classmethod def tearDownClass(self): - shutil.rmtree("baseline_model", ignore_errors=True) + if self.fp32_model_path.endswith(".keras"): + os.remove(self.fp32_model_path) + else: + shutil.rmtree(self.fp32_model_path, ignore_errors=True) os.environ["ITEX_ONEDNN_GRAPH"] = "0" def test_static_quant_from_dict_default(self): @@ -125,7 +134,7 @@ def test_static_quant_from_dict_default(self): from neural_compressor.tensorflow.keras import get_default_static_quant_config calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model") + fp32_model = keras.models.load_model(self.fp32_model_path) qmodel = quantize_model(fp32_model, get_default_static_quant_config(), calib_dataloader) self.assertIsNotNone(qmodel) @@ -152,7 +161,7 @@ def test_static_quant_from_dict_beginner(self): } } calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model") + fp32_model = keras.models.load_model(self.fp32_model_path) qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) @@ -168,7 +177,7 @@ def test_static_quant_from_class_default(self): from neural_compressor.tensorflow.keras import StaticQuantConfig calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model") + fp32_model = keras.models.load_model(self.fp32_model_path) quant_config = StaticQuantConfig() qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) @@ -185,7 +194,7 @@ def test_static_quant_from_class_beginner(self): from neural_compressor.tensorflow.keras import StaticQuantConfig calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model") + fp32_model = keras.models.load_model(self.fp32_model_path) quant_config = StaticQuantConfig( weight_dtype="int8", weight_sym=True, @@ -208,7 +217,7 @@ def test_static_quant_from_dict_advance(self): from neural_compressor.tensorflow import quantize_model calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model") + fp32_model = keras.models.load_model(self.fp32_model_path) quant_config = { "static_quant": { "global": { @@ -254,7 +263,7 @@ def test_static_quant_from_class_advance(self): ) quant_config.set_local("dense", dense_config) # get model and quantize - fp32_model = keras.models.load_model("baseline_model") + fp32_model = keras.models.load_model(self.fp32_model_path) qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) diff --git a/test/3x/tensorflow/keras/test_model_wrappers.py b/test/3x/tensorflow/keras/test_model_wrappers.py index 37a260f86ef..b9cb3eecfd0 100644 --- a/test/3x/tensorflow/keras/test_model_wrappers.py +++ b/test/3x/tensorflow/keras/test_model_wrappers.py @@ -2,7 +2,6 @@ import os import platform -import shutil import unittest import numpy as np @@ -51,10 +50,10 @@ def setUpClass(self): @classmethod def tearDownClass(self): - shutil.rmtree("simple_model.h5", ignore_errors=True) - shutil.rmtree("keras_model.h5", ignore_errors=True) - shutil.rmtree("simple_model.keras", ignore_errors=True) - shutil.rmtree("keras_model.keras", ignore_errors=True) + os.remove("simple_model.h5") + os.remove("keras_model.h5") + os.remove("simple_model.keras") + os.remove("keras_model.keras") def test_keras_h5_model(self): if parse_version(tf.version.VERSION) < parse_version("2.3.0"): diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_conv_fusion_newapi.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_conv_fusion_newapi.py index 5464ee97fee..ec09ce9981a 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_conv_fusion_newapi.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_conv_fusion_newapi.py @@ -19,7 +19,7 @@ StripUnusedNodesOptimizer, ) from neural_compressor.tensorflow.quantization.utils.quantize_graph.qdq.optimize_qdq import OptimizeQDQGraph -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestConvBiasAddAddReluFusion(unittest.TestCase): @@ -173,7 +173,11 @@ def test_conv_biasadd_relu6_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu6 = tf.nn.relu6(normed, name="op_to_store") @@ -217,7 +221,11 @@ def test_conv_biasadd_swishf32_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) @function.Defun(tf.float32, func_name="swish_f32") def swish_f32(x): @@ -356,14 +364,22 @@ def test_conv_biasadd_addv2_relu_fallback_fusion_1(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) # relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) # relu2 = tf.nn.relu(normed2) add = tf.raw_ops.AddV2(x=normed, y=normed2, name="addv2") relu = tf.nn.relu(add) @@ -416,14 +432,22 @@ def test_conv_biasadd_addv2_relu_fallback_fusion_2(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) # relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) # relu2 = tf.nn.relu(normed2) add = tf.raw_ops.AddV2(x=normed, y=normed2, name="addv2") relu = tf.nn.relu(add) @@ -470,7 +494,11 @@ def test_conv_fusion_with_last_matmul(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) pooling = tf.nn.max_pool(relu, ksize=1, strides=[1, 2, 2, 1], padding="SAME") @@ -527,7 +555,11 @@ def test_conv_fusion_with_last_conv(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) pooling = tf.nn.max_pool(relu, ksize=1, strides=[1, 2, 2, 1], padding="SAME") @@ -586,7 +618,11 @@ def test_conv_fusion_with_max_pooling(self): "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(pooling, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - biasadd = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + biasadd = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = biasadd.name.split(":")[0] with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_conv_requantize_fusion.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_conv_requantize_fusion.py index a8bfcd7d31d..db29ab7ed50 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_conv_requantize_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_conv_requantize_fusion.py @@ -11,7 +11,7 @@ from tensorflow.compat.v1 import graph_util from neural_compressor.tensorflow.algorithms.static_quant.tensorflow import TensorflowQuery -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestConvRequantizedFusionNewAPI(unittest.TestCase): @@ -25,7 +25,11 @@ def test_conv_biasadd_relu6_fusion(self): "weight0", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu6 = tf.nn.relu6(normed, name="op_to_store") @@ -367,7 +371,11 @@ def test_conv_add_add_fusion(self): "weight12", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) add = normed + tf.constant(np.random.randn(16), dtype=tf.float32) relu6 = tf.nn.relu6(add, name="op_to_store") diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_depthwiseconv_fusion.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_depthwiseconv_fusion.py index be449305f3e..dd1e2fc1212 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_depthwiseconv_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_depthwiseconv_fusion.py @@ -11,7 +11,7 @@ from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2 from tensorflow.python.framework import dtypes, tensor_util -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 def build_Conv2dBiasAddAddRelu6MulMul(): @@ -140,7 +140,11 @@ def test_depthwiseconv_biasadd_fusion(self): ) conv = tf.nn.depthwise_conv2d(x, conv_weights, strides=[1, 1, 1, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = normed.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -211,7 +215,11 @@ def test_depthwiseconv_biasadd_leakyrelu_fusion(self): ) conv = tf.nn.depthwise_conv2d(x, conv_weights, strides=[1, 1, 1, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) leakyrelu = tf.nn.leaky_relu(normed) out_name = leakyrelu.name.split(":")[0] diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_fuse_pad_conv_fp32.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_fuse_pad_conv_fp32.py index bbc23762796..a1afebc2753 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_fuse_pad_conv_fp32.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_fuse_pad_conv_fp32.py @@ -5,7 +5,7 @@ import yaml from tensorflow.compat.v1 import graph_util -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestFoldPadConv(unittest.TestCase): @@ -18,7 +18,11 @@ def test_fold_pad_conv(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed, name="op_to_store") out_name = relu.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -62,7 +66,11 @@ def test_fold_non_const_pad_conv(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed, name="op_to_store") out_name = relu.name.split(":")[0] with tf.compat.v1.Session() as sess: diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_bn_fusion.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_bn_fusion.py index 169fcc1785e..4c1d18ca4b0 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_bn_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_bn_fusion.py @@ -13,7 +13,7 @@ from tensorflow.python.framework import dtypes from neural_compressor.common import logger -from neural_compressor.tensorflow.utils import CpuInfo, disable_random +from neural_compressor.tensorflow.utils import CpuInfo, disable_random, version1_gte_version2 class TestTensorflowQdqConvFusion(unittest.TestCase): @@ -24,10 +24,18 @@ def test_bn_relu_depthwiseconv_biasadd_relu6_fusion(self): conv_weights = tf.compat.v1.get_variable( "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) - normed_0 = tf.compat.v1.layers.batch_normalization(x) + normed_0 = ( + tf.keras.layers.BatchNormalization()(x) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(x) + ) relu = tf.nn.relu(normed_0, name="op_to_store_0") conv = tf.compat.v1.nn.depthwise_conv2d_native(relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed_1 = tf.compat.v1.layers.batch_normalization(conv) + normed_1 = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu6 = tf.nn.relu6(normed_1, name="op_to_store_1") out_name = relu6.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -92,10 +100,18 @@ def test_training_bn_relu_depthwiseconv_biasadd_relu6_fusion(self): conv_weights = tf.compat.v1.get_variable( "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) - normed_0 = tf.compat.v1.layers.batch_normalization(x, training=True) + normed_0 = ( + tf.keras.layers.BatchNormalization()(x) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(x, training=True) + ) relu = tf.nn.relu(normed_0, name="op_to_store_0") conv = tf.compat.v1.nn.depthwise_conv2d_native(relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed_1 = tf.compat.v1.layers.batch_normalization(conv) + normed_1 = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu6 = tf.nn.relu6(normed_1, name="op_to_store_1") out_name = relu6.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -147,10 +163,18 @@ def test_bn_leakyrelu_conv_biasadd_relu(self): conv_weights = tf.compat.v1.get_variable( "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) - normed_0 = tf.compat.v1.layers.batch_normalization(x) + normed_0 = ( + tf.keras.layers.BatchNormalization()(x) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(x) + ) leaky_relu = tf.nn.leaky_relu(normed_0, alpha=0.3, name="op_to_store_0") conv = tf.nn.conv2d(leaky_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed_1 = tf.compat.v1.layers.batch_normalization(conv) + normed_1 = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed_1, name="op_to_store_1") out_name = relu.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -217,10 +241,18 @@ def test_bn_relu_conv_biasadd_relu(self): conv_weights = tf.compat.v1.get_variable( "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) - normed_0 = tf.compat.v1.layers.batch_normalization(x) + normed_0 = ( + tf.keras.layers.BatchNormalization()(x) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(x) + ) relu_0 = tf.nn.relu(normed_0, name="op_to_store_0") conv = tf.nn.conv2d(relu_0, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed_1 = tf.compat.v1.layers.batch_normalization(conv) + normed_1 = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu_1 = tf.nn.relu(normed_1, name="op_to_store_1") out_name = relu_1.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -287,10 +319,18 @@ def test_bn_performance_only_false(self): conv_weights = tf.compat.v1.get_variable( "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) - normed_0 = tf.compat.v1.layers.batch_normalization(x) + normed_0 = ( + tf.keras.layers.BatchNormalization()(x) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(x) + ) relu_0 = tf.nn.relu(normed_0, name="op_to_store_0") conv = tf.nn.conv2d(relu_0, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed_1 = tf.compat.v1.layers.batch_normalization(conv) + normed_1 = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu_1 = tf.nn.relu6(normed_1, name="op_to_store_1") out_name = relu_1.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -343,10 +383,18 @@ def test_bnex_performance_only_false(self): conv_weights_0 = tf.compat.v1.get_variable( "weight_0", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) - normed_0 = tf.compat.v1.layers.batch_normalization(x) + normed_0 = ( + tf.keras.layers.BatchNormalization()(x) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(x) + ) relu_0 = tf.nn.relu(normed_0, name="op_to_store_0") conv_0 = tf.nn.conv2d(relu_0, conv_weights_0, strides=[1, 2, 2, 1], padding="VALID") - normed_1 = tf.compat.v1.layers.batch_normalization(conv_0) + normed_1 = ( + tf.keras.layers.BatchNormalization()(conv_0) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv_0) + ) conv_weights_1 = tf.compat.v1.get_variable( "weight_1", [5, 5, 16, 2], initializer=tf.compat.v1.random_normal_initializer() ) diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_concat_fusion.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_concat_fusion.py index aae9da02cb0..5f089532759 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_concat_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_concat_fusion.py @@ -12,7 +12,7 @@ QuantizeGraphForIntel, ) from neural_compressor.tensorflow.quantization.utils.utility import read_graph -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestTensorflowQdqConcatFusion(unittest.TestCase): @@ -80,7 +80,11 @@ def test_concat_with_different_input_type(self): sqrt = tf.math.sqrt(x) relu_sqrt = tf.nn.relu(sqrt) conv = tf.nn.conv2d(relu_sqrt, conv_weights, strides=[1, 2, 2, 1], padding="SAME", name="last") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) conv1 = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="SAME", name="last") diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_conv_fusion.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_conv_fusion.py index bf437c78940..895ee8471fe 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_conv_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_conv_fusion.py @@ -11,7 +11,7 @@ from tensorflow.compat.v1 import graph_util from tensorflow.python.framework import function -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestTensorflowQdqConvFusion(unittest.TestCase): @@ -25,7 +25,11 @@ def test_fold_pad_conv(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed, name="op_to_store") out_name = relu.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -116,7 +120,11 @@ def test_conv_biasadd_relu6_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu6 = tf.nn.relu6(normed, name="op_to_store") @@ -162,7 +170,11 @@ def test_conv_biasadd_swishf32_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) @function.Defun(tf.float32, func_name="swish_f32") def swish_f32(x): @@ -304,14 +316,22 @@ def test_conv_biasadd_addv2_relu_fallback_fusion_1(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) # relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) # relu2 = tf.nn.relu(normed2) add = tf.raw_ops.AddV2(x=normed, y=normed2, name="addv2") relu = tf.nn.relu(add) @@ -363,7 +383,11 @@ def test_conv_fusion_with_last_conv(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) pooling = tf.nn.max_pool(relu, ksize=1, strides=[1, 2, 2, 1], padding="SAME") @@ -424,7 +448,11 @@ def test_conv_fusion_with_max_pooling(self): "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(pooling, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - biasadd = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + biasadd = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = biasadd.name.split(":")[0] with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) @@ -471,7 +499,11 @@ def test_conv_biasadd_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = normed.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -515,7 +547,11 @@ def test_depthwiseconv_biasadd_fusion(self): ) conv = tf.nn.depthwise_conv2d(top_relu, conv_weights, strides=[1, 1, 1, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = normed.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -557,7 +593,11 @@ def test_conv_biasadd_relu_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed, name="op_to_store") @@ -601,7 +641,11 @@ def test_conv_biasadd_leakyrelu_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) leaky_relu = tf.nn.leaky_relu(normed, name="op_to_store") @@ -645,7 +689,11 @@ def test_depthwiseconv_biasadd_relu6_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.compat.v1.nn.depthwise_conv2d_native(x, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu6 = tf.nn.relu6(normed, name="op_to_store") @@ -689,7 +737,11 @@ def test_depthwiseconv_biasadd_relu_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.compat.v1.nn.depthwise_conv2d_native(x, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu6 = tf.nn.relu6(normed, name="op_to_store") @@ -790,7 +842,11 @@ def test_conv_fusion_with_last_matmul(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) pooling = tf.nn.max_pool(relu, ksize=1, strides=[1, 2, 2, 1], padding="SAME") @@ -897,7 +953,11 @@ def test_depthwiseconv_biasadd_leakyrelu_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.compat.v1.nn.depthwise_conv2d_native(x, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) leaky_relu = tf.nn.leaky_relu(normed, name="op_to_store") @@ -944,14 +1004,22 @@ def test_conv_biasadd_addv2_relu_fallback_fusion_2(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) # relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) # relu2 = tf.nn.relu(normed2) add = tf.raw_ops.AddV2(x=normed, y=normed2, name="addv2") relu = tf.nn.relu(add) @@ -997,7 +1065,11 @@ def test_conv_biasadd_elu_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) elu = tf.nn.elu(normed, name="op_to_store") @@ -1041,7 +1113,11 @@ def test_conv_biasadd_sigmoid_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) sigmoid = tf.math.sigmoid(normed, name="op_to_store") diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_depthwiseconv_fusion.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_depthwiseconv_fusion.py index 7c3d0e8e0a1..99b71cf1d56 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_depthwiseconv_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_depthwiseconv_fusion.py @@ -11,7 +11,7 @@ from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2 from tensorflow.python.framework import dtypes, tensor_util -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 def build_conv2d_biasadd_add_relu6_mul_mul(): @@ -184,7 +184,11 @@ def test_depthwiseconv2d_biasadd_fusion(self): ) conv = tf.nn.depthwise_conv2d(x, conv_weights, strides=[1, 1, 1, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = normed.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -260,7 +264,11 @@ def test_depthwiseconv2d_biasadd_leakyrelu_fusion(self): ) conv = tf.nn.depthwise_conv2d(x, conv_weights, strides=[1, 1, 1, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) leakyrelu = tf.nn.leaky_relu(normed) out_name = leakyrelu.name.split(":")[0] diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_new_conv_fusion.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_new_conv_fusion.py index cec4608f775..ece8a5ddfab 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_new_conv_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_new_conv_fusion.py @@ -10,7 +10,7 @@ from tensorflow.compat.v1 import graph_util from tensorflow.python.framework import function -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestTensorflowNewQdqConvFusion(unittest.TestCase): @@ -22,7 +22,11 @@ def test_conv_biasadd_add_leakyrelu_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="SAME") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) conv2_weights = tf.compat.v1.get_variable( "weight_conv2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) diff --git a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_pooling_fusion.py b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_pooling_fusion.py index ecc7bfd3fc1..65e577b9e93 100644 --- a/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_pooling_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/newapi/test_graph_qdq_pooling_fusion.py @@ -10,7 +10,7 @@ from tensorflow.compat.v1 import graph_util from tensorflow.python.framework import dtypes -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestGraphQDQPoolingFusion(unittest.TestCase): @@ -23,7 +23,11 @@ def test_qdq_maxpool_fusion(self): conv_bias = tf.compat.v1.get_variable("bias", [1], initializer=tf.compat.v1.random_normal_initializer()) x = tf.nn.relu(x) conv = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="SAME", name="last") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) relu2 = tf.nn.relu(relu) @@ -65,7 +69,11 @@ def test_qdq_avgpool_fusion(self): conv_bias = tf.compat.v1.get_variable("bias", [1], initializer=tf.compat.v1.random_normal_initializer()) x = tf.nn.relu(x) conv = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="SAME", name="last") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) relu2 = tf.nn.relu(relu) diff --git a/test/3x/tensorflow/quantization/ptq/test_bias_correction.py b/test/3x/tensorflow/quantization/ptq/test_bias_correction.py index a093eb02cc9..ebb2fac9d1f 100644 --- a/test/3x/tensorflow/quantization/ptq/test_bias_correction.py +++ b/test/3x/tensorflow/quantization/ptq/test_bias_correction.py @@ -1,6 +1,6 @@ -import os import unittest +import numpy as np import tensorflow as tf from tensorflow.compat.v1 import graph_util @@ -11,6 +11,7 @@ ) from neural_compressor.tensorflow.quantization.utils.quantize_graph_common import QuantizeGraphHelper from neural_compressor.tensorflow.quantization.utils.transform_graph.bias_correction import BiasCorrection +from neural_compressor.tensorflow.utils import version1_gte_version2 class TestBiasCorrection(unittest.TestCase): @@ -20,10 +21,19 @@ def test_bias_correction(self): if tf.version.VERSION <= "2.1.0": x = tf.nn.relu(x) - conv_weights = tf.compat.v1.get_variable( - "weight", [3, 3, 3, 32], initializer=tf.compat.v1.random_normal_initializer() + + conv_weights = ( + tf.Variable(np.random.rand(3, 3, 3, 32).tolist(), name="weight") + if version1_gte_version2(tf.version.VERSION, "2.16.1") + else tf.compat.v1.get_variable( + "weight", [3, 3, 3, 32], initializer=tf.compat.v1.random_normal_initializer() + ) + ) + conv_bias = ( + tf.Variable(np.random.rand(32).tolist(), name="bias") + if version1_gte_version2(tf.version.VERSION, "2.16.1") + else tf.compat.v1.get_variable("bias", [32], initializer=tf.compat.v1.random_normal_initializer()) ) - conv_bias = tf.compat.v1.get_variable("bias", [32], initializer=tf.compat.v1.random_normal_initializer()) conv1 = tf.nn.conv2d(x, conv_weights, strides=[1, 1, 1, 1], padding="SAME") conv_bias = tf.nn.bias_add(conv1, conv_bias) relu = tf.nn.relu(conv_bias, name="Relu_1") diff --git a/test/3x/tensorflow/quantization/ptq/test_data_pipline.py b/test/3x/tensorflow/quantization/ptq/test_data_pipline.py index 3faf3089e22..f7e81580816 100644 --- a/test/3x/tensorflow/quantization/ptq/test_data_pipline.py +++ b/test/3x/tensorflow/quantization/ptq/test_data_pipline.py @@ -9,6 +9,7 @@ from neural_compressor.tensorflow.quantization.utils.quantize_graph_common import QuantizeGraphHelper from neural_compressor.tensorflow.quantization.utils.utility import get_tensor_by_name, iterator_sess_run +from neural_compressor.tensorflow.utils import version1_gte_version2 class TestDataPipelineConvert(unittest.TestCase): @@ -20,10 +21,18 @@ def test_data_pipeline(self): ds_iterator = tf_dataset.make_initializable_iterator() iter_tensors = ds_iterator.get_next() - conv_weights = tf.compat.v1.get_variable( - "weight", [3, 3, 3, 32], initializer=tf.compat.v1.random_normal_initializer() + conv_weights = ( + tf.Variable(np.random.rand(3, 3, 3, 32).tolist(), name="weight") + if version1_gte_version2(tf.version.VERSION, "2.16.1") + else tf.compat.v1.get_variable( + "weight", [3, 3, 3, 32], initializer=tf.compat.v1.random_normal_initializer() + ) + ) + conv_bias = ( + tf.Variable(np.random.rand(32).tolist(), name="bias") + if version1_gte_version2(tf.version.VERSION, "2.16.1") + else tf.compat.v1.get_variable("bias", [32], initializer=tf.compat.v1.random_normal_initializer()) ) - conv_bias = tf.compat.v1.get_variable("bias", [32], initializer=tf.compat.v1.random_normal_initializer()) conv1 = tf.nn.conv2d(iter_tensors, conv_weights, strides=[1, 1, 1, 1], padding="SAME") conv_bias = tf.math.add(conv1, conv_bias) relu = tf.nn.relu(conv_bias, name="Relu_1") diff --git a/test/3x/tensorflow/quantization/ptq/test_fold_batch_norm.py b/test/3x/tensorflow/quantization/ptq/test_fold_batch_norm.py index 19d60c1e950..09008209f0c 100644 --- a/test/3x/tensorflow/quantization/ptq/test_fold_batch_norm.py +++ b/test/3x/tensorflow/quantization/ptq/test_fold_batch_norm.py @@ -9,31 +9,39 @@ FoldBatchNormNodesOptimizer, ) from neural_compressor.tensorflow.quantization.utils.quantize_graph_common import QuantizeGraphHelper +from neural_compressor.tensorflow.utils import version1_gte_version2 class TestFoldBatchnorm(unittest.TestCase): - tf.compat.v1.disable_eager_execution() - x = tf.compat.v1.placeholder(tf.float32, [1, 224, 224, 3], name="input") - conv_weights = tf.compat.v1.get_variable( - "weight", [3, 3, 3, 32], initializer=tf.compat.v1.random_normal_initializer() - ) - conv_bias = tf.compat.v1.get_variable("bias", [32], initializer=tf.compat.v1.random_normal_initializer()) - beta = tf.compat.v1.get_variable(name="beta", shape=[32], initializer=tf.compat.v1.random_normal_initializer()) - gamma = tf.compat.v1.get_variable(name="gamma", shape=[32], initializer=tf.compat.v1.random_normal_initializer()) - conv1 = tf.nn.conv2d(x, conv_weights, strides=[1, 1, 1, 1], padding="SAME") - conv_bias = tf.nn.bias_add(conv1, conv_bias) - normed = tf.compat.v1.layers.batch_normalization(conv_bias) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - output_graph_def = graph_util.convert_variables_to_constants( - sess=sess, input_graph_def=sess.graph_def, output_node_names=[normed.name.split(":")[0]] + if not version1_gte_version2(tf.__version__, "2.16.1"): + tf.compat.v1.disable_eager_execution() + x = tf.compat.v1.placeholder(tf.float32, [1, 224, 224, 3], name="input") + conv_weights = tf.compat.v1.get_variable( + "weight", [3, 3, 3, 32], initializer=tf.compat.v1.random_normal_initializer() ) - output_graph_def = QuantizeGraphHelper.remove_training_nodes( - output_graph_def, protected_nodes=[normed.name.split(":")[0]] + conv_bias = tf.compat.v1.get_variable("bias", [32], initializer=tf.compat.v1.random_normal_initializer()) + beta = tf.compat.v1.get_variable(name="beta", shape=[32], initializer=tf.compat.v1.random_normal_initializer()) + gamma = tf.compat.v1.get_variable( + name="gamma", shape=[32], initializer=tf.compat.v1.random_normal_initializer() ) - graph_def = copy.deepcopy(output_graph_def) - fold_graph_def = FoldBatchNormNodesOptimizer(output_graph_def).do_transformation() + conv1 = tf.nn.conv2d(x, conv_weights, strides=[1, 1, 1, 1], padding="SAME") + conv_bias = tf.nn.bias_add(conv1, conv_bias) + normed = tf.compat.v1.layers.batch_normalization(conv_bias) + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.global_variables_initializer()) + output_graph_def = graph_util.convert_variables_to_constants( + sess=sess, input_graph_def=sess.graph_def, output_node_names=[normed.name.split(":")[0]] + ) + output_graph_def = QuantizeGraphHelper.remove_training_nodes( + output_graph_def, protected_nodes=[normed.name.split(":")[0]] + ) + graph_def = copy.deepcopy(output_graph_def) + fold_graph_def = FoldBatchNormNodesOptimizer(output_graph_def).do_transformation() + @unittest.skipIf( + version1_gte_version2(tf.version.VERSION, "2.16.1"), + "The TF BN is deleted after 2.16.1 while the fusion of Keras BN is not supported now", + ) def test_fold_output_values(self): input_data = np.random.randn(1, 224, 224, 3) graph = tf.compat.v1.Graph() @@ -57,6 +65,10 @@ def test_fold_output_values(self): assert np.allclose(y, y_fold, rtol=1e-05, atol=1e-05) + @unittest.skipIf( + version1_gte_version2(tf.version.VERSION, "2.16.1"), + "The TF BN is deleted after 2.16.1 while the fusion of Keras BN is not supported now", + ) def test_do_transform(self): for node in self.fold_graph_def.node: assert node.op not in ["FusedBatchNormV3"] diff --git a/test/3x/tensorflow/quantization/ptq/test_get_estimator_graph.py b/test/3x/tensorflow/quantization/ptq/test_get_estimator_graph.py index 62a57f001ea..b538c34a43d 100644 --- a/test/3x/tensorflow/quantization/ptq/test_get_estimator_graph.py +++ b/test/3x/tensorflow/quantization/ptq/test_get_estimator_graph.py @@ -8,16 +8,23 @@ import tensorflow as tf from neural_compressor.tensorflow.quantization.utils.utility import get_estimator_graph +from neural_compressor.tensorflow.utils import version1_gte_version2 class TestEstimatorGraphConvert(unittest.TestCase): @classmethod def setUpClass(self): + if version1_gte_version2(tf.version.VERSION, "2.16.1"): + return + self.dst_path = "/tmp/.neural_compressor/train.csv" self.titanic_file = tf.keras.utils.get_file( self.dst_path, "https://storage.googleapis.com/tf-datasets/titanic/train.csv" ) + @unittest.skipIf( + version1_gte_version2(tf.version.VERSION, "2.16.1"), "The estimator APIs are deleted after TF2.16.1" + ) def test_get_estimator_graph(self): def train_input_fn(): titanic = tf.data.experimental.make_csv_dataset(self.titanic_file, batch_size=32, label_name="survived") diff --git a/test/3x/tensorflow/quantization/ptq/test_graph_concat.py b/test/3x/tensorflow/quantization/ptq/test_graph_concat.py index f6254bb5139..e5ca34bd2a9 100644 --- a/test/3x/tensorflow/quantization/ptq/test_graph_concat.py +++ b/test/3x/tensorflow/quantization/ptq/test_graph_concat.py @@ -69,7 +69,7 @@ def test_tensorflow_concat_quantization(self): @disable_random() def test_concat_with_different_input_type(self): from neural_compressor.tensorflow import quantize_model - from neural_compressor.tensorflow.utils import BaseDataLoader, DummyDataset + from neural_compressor.tensorflow.utils import BaseDataLoader, DummyDataset, version1_gte_version2 x = tf.compat.v1.placeholder(tf.float32, [1, 128, 128, 16], name="input") conv_weights = tf.compat.v1.get_variable( @@ -81,7 +81,11 @@ def test_concat_with_different_input_type(self): sqrt = tf.math.sqrt(x) relu_sqrt = tf.nn.relu(sqrt) conv = tf.nn.conv2d(relu_sqrt, conv_weights, strides=[1, 2, 2, 1], padding="SAME", name="last") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) conv1 = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="SAME", name="last") diff --git a/test/3x/tensorflow/quantization/ptq/test_graph_conv_as_output.py b/test/3x/tensorflow/quantization/ptq/test_graph_conv_as_output.py index 624dc27e531..3de83e3f13c 100644 --- a/test/3x/tensorflow/quantization/ptq/test_graph_conv_as_output.py +++ b/test/3x/tensorflow/quantization/ptq/test_graph_conv_as_output.py @@ -1,5 +1,4 @@ import os -import shutil import unittest import numpy as np diff --git a/test/3x/tensorflow/quantization/ptq/test_graph_conv_fusion.py b/test/3x/tensorflow/quantization/ptq/test_graph_conv_fusion.py index 2a40cec53f0..ca7e1e67c5c 100644 --- a/test/3x/tensorflow/quantization/ptq/test_graph_conv_fusion.py +++ b/test/3x/tensorflow/quantization/ptq/test_graph_conv_fusion.py @@ -21,7 +21,7 @@ from neural_compressor.tensorflow.quantization.utils.quantize_graph.quantize_graph_for_intel_cpu import ( QuantizeGraphForIntel, ) -from neural_compressor.tensorflow.utils import BaseDataLoader, DummyDataset, disable_random +from neural_compressor.tensorflow.utils import BaseDataLoader, DummyDataset, disable_random, version1_gte_version2 class TestConvBiasAddAddReluFusion(unittest.TestCase): @@ -80,8 +80,11 @@ def test_depthwiseconv_biasadd_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.depthwise_conv2d(x_pad, conv_weights, strides=[1, 1, 1, 1], padding="VALID") - - normed = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = normed.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -109,6 +112,9 @@ def test_depthwiseconv_biasadd_fusion(self): if i.op == "QuantizedDepthwiseConv2DWithBias": found_conv_fusion = True break + if i.op == "QuantizedDepthwiseConv2D" and version1_gte_version2(tf.__version__, "2.16.1"): + found_conv_fusion = True + break self.assertEqual(found_conv_fusion, True) @@ -121,8 +127,11 @@ def test_depthwiseconv_biasadd_fusion_with_negative_input(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.depthwise_conv2d(x_pad, conv_weights, strides=[1, 1, 1, 1], padding="VALID") - - normed = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = normed.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -166,7 +175,11 @@ def test_conv_biasadd_relu6_fusion(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) relu6 = tf.nn.relu6(normed, name="op_to_store") @@ -294,14 +307,22 @@ def test_conv_biasadd_addv2_relu_fallback_fusion_1(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) # relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2, name="op_to_store") + ) # relu2 = tf.nn.relu(normed2) add = tf.raw_ops.AddV2(x=normed, y=normed2, name="addv2") relu = tf.nn.relu(add) @@ -346,14 +367,22 @@ def test_conv_biasadd_addv2_relu_fallback_fusion_2(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) # relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2, name="op_to_store") + ) # relu2 = tf.nn.relu(normed2) add = tf.raw_ops.AddV2(x=normed, y=normed2, name="addv2") relu = tf.nn.relu(add) @@ -398,7 +427,11 @@ def test_conv_fusion_with_last_matmul(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) relu = tf.nn.relu(normed) pooling = tf.nn.max_pool(relu, ksize=1, strides=[1, 2, 2, 1], padding="SAME") @@ -452,7 +485,11 @@ def test_conv_fusion_with_last_conv(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) relu = tf.nn.relu(normed) pooling = tf.nn.max_pool(relu, ksize=1, strides=[1, 2, 2, 1], padding="SAME") @@ -508,7 +545,11 @@ def test_conv_fusion_with_max_pooling(self): "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(pooling, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - biasadd = tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + biasadd = ( + tf.keras.layers.BatchNormalization(name="op_to_store")(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv, name="op_to_store") + ) out_name = biasadd.name.split(":")[0] with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) diff --git a/test/3x/tensorflow/quantization/ptq/test_graph_meta_pass.py b/test/3x/tensorflow/quantization/ptq/test_graph_meta_pass.py index 3432d956342..9df7305098c 100644 --- a/test/3x/tensorflow/quantization/ptq/test_graph_meta_pass.py +++ b/test/3x/tensorflow/quantization/ptq/test_graph_meta_pass.py @@ -8,7 +8,7 @@ import yaml from tensorflow.compat.v1 import graph_util -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestMetaPass(unittest.TestCase): @@ -20,7 +20,11 @@ def test_tensorflow_graph_meta_pass_with_different_mode(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) sq = tf.squeeze(relu, [0]) @@ -78,7 +82,11 @@ def test_tensorflow_graph_meta_pass_with_same_mode(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) sq = tf.squeeze(relu, [0]) @@ -87,7 +95,11 @@ def test_tensorflow_graph_meta_pass_with_same_mode(self): "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(reshape, conv_weights2, strides=[1, 2, 2, 1], padding="VALID") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) relu6 = tf.nn.relu6(normed2, name="op_to_store") @@ -134,7 +146,11 @@ def test_tensorflow_graph_meta_with_reshape_only(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(top_relu, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) reshape = tf.reshape(relu, [1, 27, 27, 16]) @@ -142,7 +158,11 @@ def test_tensorflow_graph_meta_with_reshape_only(self): "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(reshape, conv_weights2, strides=[1, 2, 2, 1], padding="VALID") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) relu6 = tf.nn.relu6(normed2, name="op_to_store") diff --git a/test/3x/tensorflow/quantization/ptq/test_graph_pad_conv.py b/test/3x/tensorflow/quantization/ptq/test_graph_pad_conv.py index 278e851de1c..5a4fe4a702e 100644 --- a/test/3x/tensorflow/quantization/ptq/test_graph_pad_conv.py +++ b/test/3x/tensorflow/quantization/ptq/test_graph_pad_conv.py @@ -5,7 +5,7 @@ import yaml from tensorflow.compat.v1 import graph_util -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestFoldPadConv(unittest.TestCase): @@ -18,7 +18,11 @@ def test_fold_pad_conv(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed, name="op_to_store") out_name = relu.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -61,7 +65,11 @@ def test_fold_pad_conv2(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) paddings2 = tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]]) @@ -70,7 +78,11 @@ def test_fold_pad_conv2(self): "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(x_pad2, conv_weights2, strides=[1, 2, 2, 1], padding="VALID") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) relu2 = tf.nn.relu(normed2) add = tf.math.add(relu, relu2, name="op_to_store") out_name = add.name.split(":")[0] @@ -115,14 +127,22 @@ def test_fold_pad_conv3(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(x, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) relu2 = tf.nn.relu(normed2) add = tf.math.add(relu, relu2, name="op_to_store") out_name = add.name.split(":")[0] diff --git a/test/3x/tensorflow/quantization/ptq/test_graph_post_cse_optimize.py b/test/3x/tensorflow/quantization/ptq/test_graph_post_cse_optimize.py index faa1def2c0f..b0f55d30434 100644 --- a/test/3x/tensorflow/quantization/ptq/test_graph_post_cse_optimize.py +++ b/test/3x/tensorflow/quantization/ptq/test_graph_post_cse_optimize.py @@ -6,7 +6,7 @@ import yaml from tensorflow.compat.v1 import graph_util -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestPostCSEOptimizer(unittest.TestCase): @@ -29,14 +29,22 @@ def test_post_cse(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(z, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(z, conv_weights2, strides=[1, 2, 2, 1], padding="VALID") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) relu2 = tf.nn.relu(normed2) add = tf.math.add(relu, relu2, name="op_to_store") out_name = add.name.split(":")[0] @@ -87,14 +95,22 @@ def test_post_cse2(self): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(z, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(z, conv_weights2, strides=[1, 2, 2, 1], padding="VALID") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) relu2 = tf.nn.relu(normed2) add = tf.math.add(relu, relu2) ones_const = tf.constant(1, dtype=tf.float32) diff --git a/test/3x/tensorflow/quantization/ptq/test_graph_switch_optimizer.py b/test/3x/tensorflow/quantization/ptq/test_graph_switch_optimizer.py index 403aabcf527..022b2bd0ba2 100644 --- a/test/3x/tensorflow/quantization/ptq/test_graph_switch_optimizer.py +++ b/test/3x/tensorflow/quantization/ptq/test_graph_switch_optimizer.py @@ -7,7 +7,7 @@ from tensorflow.compat.v1 import graph_util from tensorflow.python.ops import control_flow_ops -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 class TestSwitchOptimizer(unittest.TestCase): @@ -21,7 +21,11 @@ def test_switch_optimizer(self): conv_weights = tf.constant(np.random.random((3, 3, 16, 16)).astype(np.float32), name="y") _, switch_true = control_flow_ops.switch(conv_weights, y) conv = tf.nn.conv2d(x_pad, switch_true, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed, name="op_to_store") out_name = relu.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -51,7 +55,11 @@ def test_switch_optimizer_with_const_boolean(self): conv_weights = tf.constant(np.random.random((3, 3, 16, 16)).astype(np.float32), name="y") _, switch_true = control_flow_ops.switch(conv_weights, y) conv = tf.nn.conv2d(x_pad, switch_true, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed, name="op_to_store") out_name = relu.name.split(":")[0] with tf.compat.v1.Session() as sess: @@ -87,7 +95,11 @@ def test_switch_optimizer_invalid(self): conv_weights = tf.constant(np.random.random((3, 3, 16, 16)).astype(np.float32), name="y") switch_false, _ = control_flow_ops.switch(conv_weights, y) conv = tf.nn.conv2d(x_pad, switch_false, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed, name="op_to_store") out_name = relu.name.split(":")[0] with tf.compat.v1.Session() as sess: diff --git a/test/3x/tensorflow/quantization/ptq/test_query_yaml.py b/test/3x/tensorflow/quantization/ptq/test_query_yaml.py index 68960c05b87..192386a0b31 100644 --- a/test/3x/tensorflow/quantization/ptq/test_query_yaml.py +++ b/test/3x/tensorflow/quantization/ptq/test_query_yaml.py @@ -10,7 +10,7 @@ import neural_compressor from neural_compressor.tensorflow.algorithms.static_quant.tensorflow import TensorflowQuery -from neural_compressor.tensorflow.utils import disable_random +from neural_compressor.tensorflow.utils import disable_random, version1_gte_version2 def build_fake_framework_yaml(): @@ -168,7 +168,11 @@ def test_grappler_cfg(self): x = tf.nn.relu(x) conv = tf.nn.conv2d(x, conv_weights, strides=[1, 2, 2, 1], padding="SAME", name="last") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) relu = tf.nn.relu(normed) relu2 = tf.nn.relu(relu) diff --git a/test/3x/tensorflow/quantization/test_smooth_quant.py b/test/3x/tensorflow/quantization/test_smooth_quant.py index 5c76eadb9cd..9766a709bb7 100644 --- a/test/3x/tensorflow/quantization/test_smooth_quant.py +++ b/test/3x/tensorflow/quantization/test_smooth_quant.py @@ -7,7 +7,7 @@ from neural_compressor.common import set_random_seed from neural_compressor.tensorflow import SmoothQuantConfig, StaticQuantConfig, get_default_sq_config, quantize_model -from neural_compressor.tensorflow.utils import DummyDataset, disable_random +from neural_compressor.tensorflow.utils import DummyDataset, disable_random, version1_gte_version2 def build_conv_graph(): @@ -20,13 +20,21 @@ def build_conv_graph(): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) + normed = ( + tf.keras.layers.BatchNormalization()(conv) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv) + ) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) + normed2 = ( + tf.keras.layers.BatchNormalization()(conv2) + if version1_gte_version2(tf.__version__, "2.16.1") + else tf.compat.v1.layers.batch_normalization(conv2) + ) add = tf.raw_ops.Add(x=normed, y=normed2, name="addv2") relu = tf.nn.relu(add) diff --git a/test/3x/tensorflow/test_autotune.py b/test/3x/tensorflow/test_autotune.py index 9c89f8cd5fc..646a86a1549 100644 --- a/test/3x/tensorflow/test_autotune.py +++ b/test/3x/tensorflow/test_autotune.py @@ -12,6 +12,7 @@ from neural_compressor.common import logger from neural_compressor.common.base_tuning import Evaluator, TuningConfig from neural_compressor.tensorflow.quantization import SmoothQuantConfig, StaticQuantConfig, autotune +from neural_compressor.tensorflow.utils import version1_gte_version2 def _create_evaluator_for_eval_fns(eval_fns: Optional[Union[Callable, Dict, List[Dict]]] = None) -> Evaluator: @@ -59,7 +60,10 @@ def build_model(): _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) print("Baseline test accuracy:", baseline_model_accuracy) - model.save("baseline_model") + if version1_gte_version2(tf.__version__, "2.16.1"): + tf.saved_model.save(model, "baseline_model") + else: + model.save("baseline_model") class Dataset(object):