From f21afbbdd18cd61627fc02e5b22ca242402bcfbf Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Fri, 29 Mar 2024 17:31:15 +0800 Subject: [PATCH] Refactor Keras PTQ Implementation (#1698) Signed-off-by: zehao-intel --- .../algorithms/static_quant/keras.py | 735 ++++++++++-------- .../tensorflow/keras/layers/__init__.py | 3 +- .../tensorflow/keras/layers/conv2d.py | 91 ++- .../tensorflow/keras/layers/dense.py | 70 +- .../keras/layers/depthwise_conv2d.py | 370 +++++---- .../keras/layers/layer_initializer.py | 33 + .../tensorflow/keras/layers/pool2d.py | 80 +- .../tensorflow/keras/layers/quantizer.py | 7 +- .../keras/layers/separable_conv2d.py | 363 +++++---- .../tensorflow/keras/quantization/config.py | 2 +- requirements_tf.txt | 2 +- test/3x/tensorflow/keras/test_config.py | 38 +- .../quantization/test_smooth_quant.py | 4 +- test/3x/tensorflow/test_autotune.py | 2 +- 14 files changed, 1049 insertions(+), 751 deletions(-) create mode 100644 neural_compressor/tensorflow/keras/layers/layer_initializer.py diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index 442ba1b3d48..79ed5464a1f 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2023 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,17 +17,16 @@ import copy import json -import math import os from collections import OrderedDict, UserDict from typing import Callable, Dict -import keras import numpy as np import tensorflow as tf import yaml from neural_compressor.common import logger +from neural_compressor.common.utils import DEFAULT_WORKSPACE from neural_compressor.tensorflow.keras.layers import ( DeQuantize, FakeQuant, @@ -43,44 +42,44 @@ from neural_compressor.tensorflow.utils import deep_get, dump_elapsed_time -def _add_supported_quantized_objects(custom_objects): - """Map all the quantized objects.""" - custom_objects["Quantize"] = Quantize - custom_objects["DeQuantize"] = DeQuantize - custom_objects["FakeQuant"] = FakeQuant - custom_objects["QConv2D"] = QConv2D - custom_objects["QDepthwiseConv2D"] = QDepthwiseConv2D - custom_objects["QSeparableConv2D"] = QSeparableConv2D - custom_objects["QDense"] = QDense - custom_objects["QMaxPool2D"] = QMaxPool2D - custom_objects["QAvgPool2D"] = QAvgPool2D - custom_objects["QMaxPooling2D"] = QMaxPool2D - custom_objects["QAveragePooling2D"] = QAvgPool2D - return custom_objects - - class KerasAdaptor: """The keras class of framework adaptor layer.""" + supported_op = [ + "Conv2D", + "Dense", + "SeparableConv2D", + "DepthwiseConv2D", + "AveragePooling2D", + "MaxPooling2D", + "AvgPool2D", + "MaxPool2D", + ] + + custom_layers = { + "Quantize": Quantize, + "DeQuantize": DeQuantize, + "FakeQuant": FakeQuant, + "QConv2D": QConv2D, + "QDepthwiseConv2D": QDepthwiseConv2D, + "QSeparableConv2D": QSeparableConv2D, + "QDense": QDense, + "QMaxPool2D": QMaxPool2D, + "QAvgPool2D": QAvgPool2D, + "QMaxPooling2D": QMaxPool2D, + "QAveragePooling2D": QAvgPool2D, + } + def __init__(self, framework_specific_info): + """Initialize the KerasAdaptor class with framework specific information.""" self.framework_specific_info = framework_specific_info self.approach = deep_get(self.framework_specific_info, "approach", False) 
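# The class-level custom_layers table above replaces the removed
# _add_supported_quantized_objects() helper: it maps each quantized layer name
# to its class so Keras can resolve QConv2D, FakeQuant, etc. at
# deserialization time. A minimal sketch of how such a registry is typically
# used (the saved-model path here is hypothetical):
#
#     import tensorflow as tf
#     from neural_compressor.tensorflow.algorithms.static_quant.keras import KerasAdaptor
#
#     qmodel = tf.keras.models.load_model(
#         "saved_qmodel", custom_objects=KerasAdaptor.custom_layers
#     )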
self.quantize_config = {"op_wise_config": {}} self.device = self.framework_specific_info["device"] self.backend = self.framework_specific_info["backend"] self.recipes = deep_get(self.framework_specific_info, "recipes", {}) - self.supported_op = [ - "Conv2D", - "Dense", - "SeparableConv2D", - "DepthwiseConv2D", - "AveragePooling2D", - "MaxPooling2D", - "AvgPool2D", - "MaxPool2D", - ] - - self.pre_optimized_object = None + + self.pre_optimized_model = None self.pre_optimizer_handle = None self.bf16_ops = [] self.fp32_ops = [] @@ -91,6 +90,10 @@ def __init__(self, framework_specific_info): self.callbacks = [] self.conv_format = {} + self.fold_conv = [] + if not os.path.exists(DEFAULT_WORKSPACE): + os.mkdir(DEFAULT_WORKSPACE) + self.tmp_dir = DEFAULT_WORKSPACE + "tmp_model" def _check_itex(self): """Check if the IntelĀ® Extension for TensorFlow has been installed.""" @@ -102,84 +105,74 @@ def _check_itex(self): "Please install it to run models on ITEX backend" ) - def tuning_cfg_to_fw(self, tuning_cfg): - """Parse tune_config and set framework variables.""" - self.quantize_config["calib_iteration"] = tuning_cfg["calib_iteration"] - self.quantize_config["device"] = self.device - self.quantize_config["advance"] = deep_get(tuning_cfg, "advance") - fp32_ops = [] - bf16_ops = [] - bf16_type = set(self.query_handler.get_op_types_by_precision(precision="bf16")) - dispatched_op_names = [j[0] for j in tuning_cfg["op"]] - invalid_op_names = [i for i in self.quantize_config["op_wise_config"] if i not in dispatched_op_names] - - for op_name in invalid_op_names: - self.quantize_config["op_wise_config"].pop(op_name) + def convert_bf16(self): + """Execute the BF16 conversion.""" + tf.keras.mixed_precision.set_global_policy("mixed_bfloat16") + model = self.pre_optimized_model - for each_op_info in tuning_cfg["op"]: - op_name = each_op_info[0] + for layer in model.layers: + if layer.name in self.bf16_ops: + layer.dtype = "mixed_bfloat16" - if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16": - if each_op_info[1] in bf16_type: - bf16_ops.append(op_name) - continue + model.save(self.tmp_dir) + converted_model = tf.keras.models.load_model(self.tmp_dir) + tf.keras.mixed_precision.set_global_policy("float32") - if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "fp32": - if op_name in self.quantize_config["op_wise_config"]: - self.quantize_config["op_wise_config"].pop(op_name) - fp32_ops.append(op_name) - continue + return converted_model - is_perchannel = False - bit = None - if "weight" in tuning_cfg["op"][each_op_info]: - is_perchannel = tuning_cfg["op"][each_op_info]["weight"]["granularity"] == "per_channel" - # bit = tuning_cfg['op'][each_op_info]['weight']['bit'] - weight_bit = bit if bit else 7.0 - algorithm = tuning_cfg["op"][each_op_info]["activation"]["algorithm"] - is_asymmetric = False - if "activation" in tuning_cfg["op"][each_op_info]: - is_asymmetric = tuning_cfg["op"][each_op_info]["activation"]["scheme"] == "asym" - self.quantize_config["op_wise_config"][op_name] = (is_perchannel, algorithm, is_asymmetric, weight_bit) - self.bf16_ops = bf16_ops - if self.bf16_ops: - self.bf16_ops.pop(-1) - self.fp32_ops = fp32_ops + # (TODO) choose the properly quantize mode + def _check_quantize_mode(self, model): + """Check what quantize mode to use.""" + for layer in model.layers: + if "ReLU" in layer.__class__.__name__: + return "MIN_FIRST" + return "SCALED" - def _pre_optimize(self, model): - """Apply pre-optimization.""" - model = self._check_quantize_format(model) - model = 
self._fuse_bn(model) - return model + def _set_weights(self, qmodel, layer_weights): + """Set fp32 weights to qmodel.""" + for qlayer in qmodel.layers: + if qlayer.get_weights(): + if qlayer.name in layer_weights: + qlayer.set_weights(layer_weights[qlayer.name]) + else: + hit_layer = False + for sub_layer in qlayer.submodules: + if sub_layer.name in layer_weights: + qlayer.set_weights(layer_weights[sub_layer.name]) + hit_layer = True + break + if not hit_layer: + raise ValueError("Can not match the module weights....") + return qmodel def _check_quantize_format(self, model): """The function that checks format for conv ops.""" - json_model = copy.deepcopy(json.loads(model.to_json())) - config = json_model["config"] - fp32_layers = config["layers"] - name_op_map = {} - - for idx, layer in enumerate(copy.deepcopy(fp32_layers)): - name_op_map[layer["config"]["name"]] = layer - - for idx, layer in enumerate(copy.deepcopy(fp32_layers)): - layer_config = layer["config"] - if layer["class_name"] in self.supported_op: - if "inbound_nodes" in layer: - check_layer = name_op_map[layer["inbound_nodes"][0][0][0]] + input_layer_dict = {} + layer_name_mapping = {} + + for layer in model.layers: + layer_name_mapping[layer.name] = layer + for node in layer._outbound_nodes: + layer_name = node.outbound_layer.name + if layer_name not in input_layer_dict: + input_layer_dict[layer_name] = [layer.name] else: - check_layer = fp32_layers[idx - 1] - if check_layer["class_name"] in ["Activation"] and check_layer["config"]["activation"] in ["relu"]: - self.conv_format[layer["config"]["name"]] = "u8" - else: - self.conv_format[layer["config"]["name"]] = "s8" - return model + input_layer_dict[layer_name].append(layer.name) + + for layer in model.layers: + if layer.__class__.__name__ in self.supported_op: + self.conv_format[layer.name] = "s8" + input_layer_names = input_layer_dict[layer.name] + for input_layer_name in input_layer_names: + check_layer = layer_name_mapping[input_layer_name] + if check_layer.__class__.__name__ == "Activation" and check_layer.activation.__name__ in ["relu"]: + self.conv_format[layer.name] = "u8" + break def _fuse_bn(self, model): """Fusing Batch Normalization.""" - json_model = copy.deepcopy(json.loads(model.to_json())) - config = json_model["config"] - fp32_layers = config["layers"] + fuse_bn_model = copy.deepcopy(model) + fp32_layers = fuse_bn_model.layers def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): assert conv_type in [ @@ -225,77 +218,82 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): bias = bias.reshape(-1) return [depth_weight, weight, bias] if conv_type == "SeparableConv2D" else [weight, bias] - node_map = {} - for idx, layer in enumerate(copy.deepcopy(fp32_layers)): - layer_config = layer["config"] - if "inbound_nodes" in layer: - node_map[layer["name"]] = layer - fuse_layers = [] - fold_conv = [] - for idx, layer in enumerate(copy.deepcopy(fp32_layers)): - layer_config = layer["config"] - if "inbound_nodes" in layer: - if layer["class_name"] in ["BatchNormalization"]: - bn_inbound_node = node_map[layer_config["name"]]["inbound_nodes"][0][0] - if bn_inbound_node[0] in self.conv_weights.keys(): - conv_weight = self.conv_weights[bn_inbound_node[0]] - conv_layer = node_map[bn_inbound_node[0]] - bn_weight = self.bn_weights[layer_config["name"]] - self.layer_weights[bn_inbound_node[0]] = fuse_conv_bn( - conv_weight, bn_weight, conv_layer["class_name"], layer["config"]["epsilon"] - ) - fold_conv.append(bn_inbound_node[0]) - else: 
- fuse_layers.append(layer) - elif len(layer["inbound_nodes"]): + for idx, layer in enumerate(fp32_layers): + if hasattr(layer, "_inbound_nodes"): + if layer.__class__.__name__ in ("BatchNormalization"): + for bn_inbound_node in layer._inbound_nodes: + inbound_layer = bn_inbound_node.inbound_layers + if inbound_layer.name in self.conv_weights.keys(): + conv_layer = inbound_layer + conv_weight = self.conv_weights[conv_layer.name] + bn_weight = self.bn_weights[layer.name] + + self.layer_weights[conv_layer.name] = fuse_conv_bn( + conv_weight, bn_weight, conv_layer.__class__.__name__, layer.epsilon + ) + self.fold_conv.append(conv_layer.name) + else: + fuse_layers.append(layer) + elif len(layer._inbound_nodes): new_bound_nodes = [] # OpLambda node will have different bound node - if layer["class_name"] in ["TFOpLambda", "SlicingOpLambda"]: + if layer.__class__.__name__ in ("TFOpLambda", "SlicingOpLambda"): fuse_layers.append(layer) else: - for bound_node in layer["inbound_nodes"][0]: - if bound_node[0] in self.bn_weights.keys(): - bn_inbound_node = node_map[bound_node[0]]["inbound_nodes"][0][0] - if bn_inbound_node[0] in self.conv_weights.keys(): - new_bound_nodes.append(bn_inbound_node) - else: - new_bound_nodes.append(bound_node) + for bound_node in layer._inbound_nodes: + inbound_layer = bound_node.inbound_layers + if ( + not isinstance(inbound_layer, list) + and inbound_layer.name in self.bn_weights.keys() + and inbound_layer._inbound_nodes[0].inbound_layers.name in self.conv_weights.keys() + ): + new_bound_nodes.append(bn_inbound_node) else: new_bound_nodes.append(bound_node) - layer["inbound_nodes"] = [new_bound_nodes] + + layer._inbound_nodes.clear() + for bound_node in new_bound_nodes: + layer._inbound_nodes.append(bound_node) fuse_layers.append(layer) else: fuse_layers.append(layer) else: if ( idx > 0 - and layer["class_name"] in ["BatchNormalization"] - and fp32_layers[idx - 1]["class_name"] in ["Conv2D"] + and layer.__class__.__name__ == "BatchNormalization" + and fp32_layers[idx - 1].__class__.__name__ == "Conv2D" ): - conv_name = fp32_layers[idx - 1]["config"]["name"] + conv_name = fp32_layers[idx - 1].name conv_weight = self.conv_weights[conv_name] - bn_weight = self.bn_weights[layer_config["name"]] - conv_type = fp32_layers[idx - 1]["class_name"] - self.layer_weights[conv_name] = fuse_conv_bn( - conv_weight, bn_weight, conv_type, layer["config"]["epsilon"] - ) - fold_conv.append(conv_name) + bn_weight = self.bn_weights[layer.name] + conv_type = fp32_layers[idx - 1].__class__.__name__ + + self.layer_weights[conv_name] = fuse_conv_bn(conv_weight, bn_weight, conv_type, layer.epsilon) + self.fold_conv.append(conv_name) else: fuse_layers.append(layer) - # bn folding will have a shift bias for idx, layer in enumerate(fuse_layers): - layer_config = layer["config"] if ( - layer["class_name"] in ["Conv2D", "DepthwiseConv2D", "SeparableConv2D"] - and layer_config["name"] in fold_conv + layer.__class__.__name__ in ("Conv2D", "DepthwiseConv2D", "SeparableConv2D") + and layer.name in self.fold_conv ): - layer_config["use_bias"] = True + conv_config = layer.get_config() + conv_config["use_bias"] = True + conv_layer = type(layer).from_config(conv_config) + for node in layer._outbound_nodes: + conv_layer._outbound_nodes.append(node) + fuse_layers[idx] = conv_layer + + bn_surgery = KerasSurgery(model) + bn_fused_model = bn_surgery.fuse_bn_layers(fuse_layers, self.conv_weights.keys()) + bn_fused_model = self._set_weights(bn_fused_model, self.layer_weights) - json_model["config"]["layers"] = 
fuse_layers - fused_model = self._restore_model_from_json(json_model) - return fused_model + bn_fused_model.save(self.tmp_dir) + bn_fused_model = tf.keras.models.load_model(self.tmp_dir) + + return bn_fused_model @dump_elapsed_time("Pass quantize model") def quantize(self, quant_config, model, dataloader, iteration, q_func=None): @@ -318,8 +316,9 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): converted_model = self.convert_bf16() return converted_model - # if self.backend == "itex": - # self._check_itex() + if self.backend == "itex": + self._check_itex() + logger.debug("Dump quantization configurations:") logger.debug(self.quantize_config) calib_sampling_size = tune_cfg.get("calib_sampling_size", 1) @@ -334,48 +333,46 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): ) ) - q_layers = [] - self.inbound_nodes_map = {} - for idx, layer in enumerate(copy.deepcopy(self.fp32_layers)): - layer_config = layer["config"] - if ( - layer["class_name"] in self.supported_op - and layer["config"]["name"] in self.quantize_config["op_wise_config"] - ): - op_config = self.quantize_config["op_wise_config"][layer["config"]["name"]] + fq_layers_dict = {} + fq_output_layers = {} + for idx, layer in enumerate(self.pre_optimized_model.layers): + if layer.__class__.__name__ in self.supported_op and layer.name in self.quantize_config["op_wise_config"]: + op_config = self.quantize_config["op_wise_config"][layer.name] mode = "per_channel" if op_config[0] else "per_tensor" fake_q_name = "fake_quant_" + str(idx) - fake_q_layer = { - "class_name": "FakeQuant", - "name": fake_q_name, - "T": self.conv_format[layer["config"]["name"]], - "config": {"mode": "per_tensor", "name": fake_q_name}, - } - if "inbound_nodes" in layer: - fake_q_layer["inbound_nodes"] = layer["inbound_nodes"] - layer["inbound_nodes"] = [[[fake_q_name, 0, 0, {}]]] - self.inbound_nodes_map[fake_q_name] = layer - - q_layers.append(fake_q_layer) - q_layers.append(layer) - else: - q_layers.append(layer) - - json_model = copy.deepcopy(json.loads(self.pre_optimized_object.to_json())) - json_model["config"]["layers"] = q_layers - quantized_model = self._restore_model_from_json(json_model) + fake_q_layer = FakeQuant(name=fake_q_name, T=self.conv_format[layer.name], mode="per_tensor") + fq_layers_dict[layer.name] = [fake_q_layer] + fq_output_layers[fake_q_layer.name] = layer.name + self.pre_optimized_model.save(self.tmp_dir) + + fq_surgery = KerasSurgery(self.pre_optimized_model) + calibration_model = fq_surgery.insert_quant_layers(fq_layers_dict) + calibration_model = self._set_weights(calibration_model, self.layer_weights) + + quantized_model = self._calibrate( + calibration_model, + dataloader, + self.quantize_config["calib_iteration"], + fq_output_layers, + ) - converted_model = self._calibrate(quantized_model, dataloader, self.quantize_config["calib_iteration"]) + return quantized_model - return converted_model + def _calibrate(self, model, dataloader, calib_interation, fq_output_layers): + """Apply calibration. - def _calibrate(self, model, dataloader, calib_interation): - """Apply calibration.""" + Args: + model (tf.keras.Model): The model inserted with FakeQuant layers for calibration. + dataloader(object): The calibration dataloader used to load quantization dataset. + iteration(int): The iteration of calibration. + fq_output_layers (dict): A dict mapping from names of FakeQuant layers to + names of their output layers. 
+ """ # run eagerly to fetch the numpy min/max - model.compile(run_eagerly=True) results = {} + model.compile(run_eagerly=True) for idx, (inputs, labels) in enumerate(dataloader): - outputs = model.predict_on_batch(inputs) + _ = model.predict_on_batch(inputs) json_model = copy.deepcopy(json.loads(model.to_json())) config = json_model["config"] layers = config["layers"] @@ -383,6 +380,8 @@ def _calibrate(self, model, dataloader, calib_interation): if layer["class_name"] == "FakeQuant": min_value = layer["config"]["min_value"] max_value = layer["config"]["max_value"] + assert min_value < max_value, "The min value must be lower than the max value in quantization." + if layer["config"]["name"] not in results: results[layer["config"]["name"]] = {"min": [min_value], "max": [max_value]} else: @@ -391,143 +390,62 @@ def _calibrate(self, model, dataloader, calib_interation): if idx + 1 == calib_interation: break - # insert the calibrated min/max to Q/DQ - json_model = copy.deepcopy(json.loads(model.to_json())) - config = json_model["config"] - layers = config["layers"] - q_layers = [] - # quantize_mode = self._check_quantize_mode(json_model) - inbound_reverse_map = {} - for idx, layer in enumerate(layers): - layer_config = copy.deepcopy(layer["config"]) - if layer["class_name"] == "FakeQuant": - min_value = min(results[layer["config"]["name"]]["min"]) - max_value = max(results[layer["config"]["name"]]["max"]) - quantize_layer = { - "class_name": "Quantize", - "name": "quantize_" + str(idx), - "config": { - "min_range": min_value, - "max_range": max_value, - "T": layer_config["T"], - "name": "quantize_" + str(idx), - }, - } - dequantize_layer = { - "class_name": "DeQuantize", - "name": "dequantize_" + str(idx), - "config": { - "min_range": min_value, - "max_range": max_value, - # 'mode': quantize_mode, - "name": "dequantize_" + str(idx), - }, - } - if "inbound_nodes" in layer: - quantize_layer["inbound_nodes"] = layer["inbound_nodes"] - dequantize_layer["inbound_nodes"] = [[["quantize_" + str(idx), 0, 0, {}]]] - # find the conv/dense layer from fake quant map and - # change the conv/dense node inbound to dequantize - layer_name = self.inbound_nodes_map[layer["name"]]["name"] - inbound_reverse_map[layer_name] = [[["dequantize_" + str(idx), 0, 0, {}]]] - - q_layers.append(quantize_layer) - q_layers.append(dequantize_layer) - elif ( - layer["class_name"] in self.supported_op - and layer["config"]["name"] in self.quantize_config["op_wise_config"] - ): + qdq_layer_nums = 0 + qdq_layers_dict = {} + quantized_layers_dict = {} + for idx, layer in enumerate(model.layers): + if layer.__class__.__name__ == "FakeQuant": + min_value = min(results[layer.name]["min"]) + max_value = max(results[layer.name]["max"]) + + quantize_layer = Quantize( + name="quantize_" + str(qdq_layer_nums), + min_range=min_value, + max_range=max_value, + T=layer.T, + ) + dequantize_layer = DeQuantize( + name="dequantize_" + str(qdq_layer_nums), + min_range=min_value, + max_range=max_value, + ) + + qdq_layer_nums += 1 + output_layer_name = fq_output_layers[layer.name] + qdq_layers_dict[output_layer_name] = [quantize_layer, dequantize_layer] + elif layer.__class__.__name__ in self.supported_op and layer.name in self.quantize_config["op_wise_config"]: # index 0 is weight, index 1 is bias - q_layer_name = "Q" + layer["class_name"] - # this is for inbounds search - q_name = layer["config"]["name"] + q_layer_class = "Q" + layer.__class__.__name__ # for layers that have weights - if layer["config"]["name"] in self.layer_weights: - kernel = 
self.layer_weights[layer["config"]["name"]][0] + if layer.name in self.layer_weights: + kernel = self.layer_weights[layer.name][0] dim = list(range(0, kernel.ndim)) t_dim = [dim.pop(-1)] t_dim.extend(dim) channel_size = kernel.shape[-1] kernel_channel = kernel.transpose(t_dim).reshape(channel_size, -1) - layer_config["min_value"] = json.dumps(np.min(kernel_channel, axis=1).tolist()) - layer_config["max_value"] = json.dumps(np.max(kernel_channel, axis=1).tolist()) + + layer.min_value = np.min(kernel_channel, axis=1).tolist() + layer.max_value = np.max(kernel_channel, axis=1).tolist() else: # default value, but never expected to be used # cause no kernel weights for this layer - layer_config["min_value"] = json.dumps([-10000]) - layer_config["max_value"] = json.dumps([10000]) - layer_config["name"] = q_name - q_layer = {"class_name": q_layer_name, "name": q_name, "config": layer_config} - if "inbound_nodes" in layer: - q_layer["inbound_nodes"] = inbound_reverse_map[layer["name"]] - q_layers.append(q_layer) - else: - q_layers.append(layer) + layer.min_value = [-10000] + layer.max_value = [10000] - json_model["config"]["layers"] = q_layers - quantized_model = self._restore_model_from_json(json_model) - return quantized_model + from neural_compressor.tensorflow.keras.layers import layer_initializer_dict - def convert_bf16(self): - """Execute the BF16 conversion.""" - tf.keras.mixed_precision.set_global_policy("mixed_bfloat16") - json_model = copy.deepcopy(json.loads(self.pre_optimized_object.to_json())) - - for layer in json_model["config"]["layers"]: - if layer["config"]["name"] in self.bf16_ops: - layer["config"]["dtype"] = "mixed_bfloat16" - - converted_model = self._restore_model_from_json(json_model) - tf.keras.mixed_precision.set_global_policy("float32") - - return converted_model - - # (TODO) choose the properly quantize mode - def _check_quantize_mode(self, json_model): - """Check what quantize mode to use.""" - config = json_model["config"] - layers = config["layers"] - for idx, layer in enumerate(layers): - if "ReLU" in layer["class_name"]: - return "MIN_FIRST" - return "SCALED" + q_layer = layer_initializer_dict[q_layer_class](layer) + quantized_layers_dict[layer.name] = q_layer - def _restore_model_from_json(self, json_model): - """Generate a keras model from json files.""" - from tensorflow.keras.models import model_from_json + qdq_surgery = KerasSurgery(self.pre_optimized_model) + quantized_model = qdq_surgery.insert_quant_layers(qdq_layers_dict, quantized_layers_dict) + quantized_model = self._set_weights(quantized_model, self.layer_weights) - from neural_compressor.tensorflow.utils import version1_gte_version2 + quantized_model.save(self.tmp_dir) + quantized_model = tf.keras.models.load_model(self.tmp_dir) - if version1_gte_version2(keras.__version__, "2.13.1"): - from keras.src.saving import serialization_lib - - serialization_lib.enable_unsafe_deserialization() - - custom_objects = {} - # We need to keep a dictionary of custom objects as our quantized library - # is not recognized by keras. 
- custom_objects = _add_supported_quantized_objects(custom_objects) - json_model_file = json.dumps(json_model) - qmodel = model_from_json(json_model_file, custom_objects=custom_objects) - qmodel = self._set_weights(qmodel, self.layer_weights) - return qmodel - - # set fp32 weights to qmodel - def _set_weights(self, qmodel, layer_weights): - for qlayer in qmodel.layers: - if qlayer.get_weights(): - if qlayer.name in layer_weights: - qlayer.set_weights(layer_weights[qlayer.name]) - else: - hit_layer = False - for sub_layer in qlayer.submodules: - if sub_layer.name in layer_weights: - qlayer.set_weights(layer_weights[sub_layer.name]) - hit_layer = True - break - if not hit_layer: - raise ValueError("Can not match the module weights....") - return qmodel + return quantized_model @dump_elapsed_time(customized_msg="Model inference") def evaluate( @@ -605,11 +523,11 @@ def query_fw_capability(self, model): other_config = copy.deepcopy(op_capability["int8"]["default"]) # # get fp32 layer weights - keras_object = model + self.fp32_model = model self.conv_weights = {} self.bn_weights = {} self.layer_weights = {} - for layer in keras_object.layers: + for layer in self.fp32_model.layers: if layer.get_weights(): if ( isinstance(layer, tf.keras.layers.Conv2D) @@ -620,30 +538,27 @@ def query_fw_capability(self, model): elif isinstance(layer, tf.keras.layers.BatchNormalization): self.bn_weights[layer.name] = copy.deepcopy(layer.get_weights()) self.layer_weights[layer.name] = copy.deepcopy(layer.get_weights()) - self.pre_optimized_object = self._pre_optimize(keras_object) - - json_model = copy.deepcopy(json.loads(self.pre_optimized_object.to_json())) - config = json_model["config"] - self.fp32_layers = config["layers"] - - quantizable_op_details = OrderedDict() - for details in self.fp32_layers: - node_op = details["class_name"] - node_name = details["config"]["name"] - if node_op == "Conv2D": - quantizable_op_details[(node_name, node_op)] = [conv_config, bf16_config, fp32_config] - elif node_op == "Dense": - quantizable_op_details[(node_name, node_op)] = [dense_config, bf16_config, fp32_config] - elif node_op in {"AveragePooling2D", "AvgPool2D"}: - quantizable_op_details[(node_name, node_op)] = [avgpool_config, bf16_config, fp32_config] - elif node_op in {"MaxPooling2D", "MaxPool2D"}: - quantizable_op_details[(node_name, node_op)] = [maxpool_config, bf16_config, fp32_config] + + self._check_quantize_format(self.fp32_model) + self.pre_optimized_model = self._fuse_bn(self.fp32_model) + + quantizable_layer_details = OrderedDict() + for layer in self.fp32_model.layers: + layer_class = layer.__class__.__name__ + if layer_class == "Conv2D": + quantizable_layer_details[(layer.name, layer_class)] = [conv_config, bf16_config, fp32_config] + elif layer_class == "Dense": + quantizable_layer_details[(layer.name, layer_class)] = [dense_config, bf16_config, fp32_config] + elif layer_class in {"AveragePooling2D", "AvgPool2D"}: + quantizable_layer_details[(layer.name, layer_class)] = [avgpool_config, bf16_config, fp32_config] + elif layer_class in {"MaxPooling2D", "MaxPool2D"}: + quantizable_layer_details[(layer.name, layer_class)] = [maxpool_config, bf16_config, fp32_config] else: - quantizable_op_details[(node_name, node_op)] = [bf16_config, fp32_config] + quantizable_layer_details[(layer.name, layer_class)] = [bf16_config, fp32_config] capability = { - "opwise": copy.deepcopy(quantizable_op_details), - "optypewise": self.get_optype_wise_ability(quantizable_op_details), + "opwise": 
copy.deepcopy(quantizable_layer_details), + "optypewise": self.get_optype_wise_ability(quantizable_layer_details), } logger.debug("Dump framework quantization capability:") logger.debug(capability) @@ -665,6 +580,54 @@ def get_optype_wise_ability(self, quantizable_op_details): res[op[1]]["weight"] = quantizable_op_details[op][0]["weight"] return res + def tuning_cfg_to_fw(self, tuning_cfg): + """Parse tune_config and set framework variables. + + Args: + tuning_cfg (dict): The dict of tuning config. + """ + self.quantize_config["calib_iteration"] = tuning_cfg["calib_iteration"] + self.quantize_config["device"] = self.device + self.quantize_config["advance"] = deep_get(tuning_cfg, "advance") + fp32_ops = [] + bf16_ops = [] + bf16_type = set(self.query_handler.get_op_types_by_precision(precision="bf16")) + dispatched_op_names = [j[0] for j in tuning_cfg["op"]] + invalid_op_names = [i for i in self.quantize_config["op_wise_config"] if i not in dispatched_op_names] + + for op_name in invalid_op_names: + self.quantize_config["op_wise_config"].pop(op_name) + + for each_op_info in tuning_cfg["op"]: + op_name = each_op_info[0] + + if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16": + if each_op_info[1] in bf16_type: + bf16_ops.append(op_name) + continue + + if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "fp32": + if op_name in self.quantize_config["op_wise_config"]: + self.quantize_config["op_wise_config"].pop(op_name) + fp32_ops.append(op_name) + continue + + is_perchannel = False + bit = None + if "weight" in tuning_cfg["op"][each_op_info]: + is_perchannel = tuning_cfg["op"][each_op_info]["weight"]["granularity"] == "per_channel" + # bit = tuning_cfg['op'][each_op_info]['weight']['bit'] + weight_bit = bit if bit else 7.0 + algorithm = tuning_cfg["op"][each_op_info]["activation"]["algorithm"] + is_asymmetric = False + if "activation" in tuning_cfg["op"][each_op_info]: + is_asymmetric = tuning_cfg["op"][each_op_info]["activation"]["scheme"] == "asym" + self.quantize_config["op_wise_config"][op_name] = (is_perchannel, algorithm, is_asymmetric, weight_bit) + self.bf16_ops = bf16_ops + if self.bf16_ops: + self.bf16_ops.pop(-1) + self.fp32_ops = fp32_ops + class KerasQuery: """Class that queries configs from yaml settings.""" @@ -760,7 +723,7 @@ def get_op_types_by_precision(self, precision): class KerasConfigConverter: """Convert `StaticQuantConfig` to the format used by static quant algo.""" - support_int8_weight = {"Dense", "Conv2d", "DepthwiseConv2D", "SeparableConv2D"} + support_int8_weight = {"Dense", "Conv2D", "DepthwiseConv2D", "SeparableConv2D"} def __init__(self, quant_config: StaticQuantConfig, calib_iteration: int): """Init parser for keras static quant config. @@ -809,3 +772,119 @@ def parse_to_tune_cfg(self) -> Dict: tune_cfg["calib_iteration"] = self.calib_iteration return tune_cfg + + +class KerasSurgery: + """The class that inserts FakeQuant or QDQ layers before the target layers.""" + + def __init__(self, model): + """Init the KerasSurgery class. + + Args: + model: the model to be modified. + """ + self.model_outputs = [] + self.model = copy.deepcopy(model) + + def _create_input_dict(self, fuse_layers=None, conv_weights_keys=None): + """Create a input_layer_dict from model. + + Args: + fuse_layers: The layers in which fused BNs have been excluded, default to be None. + conv_weights_keys: The names of conv layers where BNs are going to be fused, default to be None. 
+ + Returns: + input_layer_dict: The dict that mapping for layer names to their input layer names. + """ + input_layer_dict = {} + layers = fuse_layers if fuse_layers else self.model.layers + for layer in layers: + for node in layer._outbound_nodes: + out_layer = node.outbound_layer + out_layer_names = [out_layer.name] + if ( + conv_weights_keys + and out_layer.__class__.__name__ in ("BatchNormalization") + and layer.name in conv_weights_keys + ): + out_layer_names = [node.outbound_layer.name for node in out_layer._outbound_nodes] + + for out_layer_name in out_layer_names: + if out_layer_name not in input_layer_dict: + input_layer_dict[out_layer_name] = [layer.name] + else: + input_layer_dict[out_layer_name].append(layer.name) + + return input_layer_dict + + def fuse_bn_layers(self, fuse_layers, conv_weights_keys): + """Fuse BN layers and rebuild the model. + + Args: + fuse_layers: The layers in which fused BNs have been excluded. + conv_weights_keys: The names of conv layers where BNs are going to be fused. + """ + self.input_layer_dict = self._create_input_dict(fuse_layers, conv_weights_keys) + output_tensor_dict = {"keras.Input": self.model.input} + + for idx, layer in enumerate(fuse_layers): + if layer.__class__.__name__ == "InputLayer": + output_tensor_dict[layer.name] = output_tensor_dict["keras.Input"] + continue + + input_tensors = ( + output_tensor_dict["keras.Input"] + if idx == 0 + else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + ) + + while isinstance(input_tensors, list) and len(input_tensors) == 1: + input_tensors = input_tensors[0] + + x = layer(input_tensors) + + output_tensor_dict[layer.name] = x + if layer.name in self.model.output_names: + self.model_outputs.append(x) + + return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) + + def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): + """Insert FakeQuant or QDQ layers before the target layers and replace + Keras layers to Quantized layers. + + Args: + qdq_layer_dict: The dict mapping from layers to be quantized to the FakeQuant layer or QDQ layers + that are going to be inserted before them. + q_layer_dict: The dict mapping from layers to be replacement to the quantized layers. 
+ """ + self.input_layer_dict = self._create_input_dict() + output_tensor_dict = {"keras.Input": self.model.input} + + for idx, layer in enumerate(self.model.layers): + if layer.__class__.__name__ == "InputLayer": + output_tensor_dict[layer.name] = output_tensor_dict["keras.Input"] + continue + + input_tensors = ( + output_tensor_dict["keras.Input"] + if idx == 0 + else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + ) + while isinstance(input_tensors, list) and len(input_tensors) == 1: + input_tensors = input_tensors[0] + + if layer.name in qdq_layer_dict: + x = input_tensors + for inserted_layer in qdq_layer_dict[layer.name]: + x = inserted_layer(x) + cur_layer = layer if not q_layer_dict else q_layer_dict[layer.name] + x = cur_layer(x) + else: + x = layer(input_tensors) + + output_tensor_dict[layer.name] = x + if layer.name in self.model.output_names: + self.model_outputs.append(x) + + return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) diff --git a/neural_compressor/tensorflow/keras/layers/__init__.py b/neural_compressor/tensorflow/keras/layers/__init__.py index 2be4fd9417e..0b4fe9030ac 100644 --- a/neural_compressor/tensorflow/keras/layers/__init__.py +++ b/neural_compressor/tensorflow/keras/layers/__init__.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,3 +21,4 @@ from neural_compressor.tensorflow.keras.layers.pool2d import QAvgPool2D, QMaxPool2D from neural_compressor.tensorflow.keras.layers.quantizer import DeQuantize, FakeQuant, Quantize from neural_compressor.tensorflow.keras.layers.separable_conv2d import QSeparableConv2D +from neural_compressor.tensorflow.keras.layers.layer_initializer import layer_initializer_dict diff --git a/neural_compressor/tensorflow/keras/layers/conv2d.py b/neural_compressor/tensorflow/keras/layers/conv2d.py index 812b6caaa33..0a4852d2027 100644 --- a/neural_compressor/tensorflow/keras/layers/conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/conv2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -23,11 +23,7 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.16.1"): - from keras.src.layers.convolutional.base_conv import BaseConv # pylint: disable=E0401 - - Conv = BaseConv -elif version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_conv import Conv # pylint: disable=E0401 else: from keras.layers.convolutional.base_conv import Conv # pylint: disable=E0401 @@ -36,6 +32,7 @@ class QConv2D(Conv): def __init__( self, + name, filters, kernel_size, strides=(1, 1), @@ -52,11 +49,12 @@ def __init__( activity_regularizer=None, kernel_constraint=None, bias_constraint=None, - min_value=-10000, - max_value=10000, + min_value=None, + max_value=None, **kwargs ): super(QConv2D, self).__init__( + name=name, rank=2, filters=filters, kernel_size=kernel_size, @@ -76,10 +74,17 @@ def __init__( bias_constraint=constraints.get(bias_constraint), **kwargs ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) + self.min_value = min_value + self.max_value = max_value def call(self, inputs): + kernel_size = self.kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000] * kernel_size + if not self.max_value: + self.max_value = [10000] * kernel_size + # add the Q/DQ here kernel, _, _ = quantization.quantize( self.kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" @@ -111,3 +116,69 @@ def call(self, inputs): @classmethod def from_config(cls, config): return cls(**config) + + +def initialize_int8_conv2d(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "filters" in kwargs: + del kwargs["filters"] + if "kernel_size" in kwargs: + del kwargs["kernel_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "dilation_rate" in kwargs: + del kwargs["dilation_rate"] + if "groups" in kwargs: + del kwargs["groups"] + if "activation" in kwargs: + del kwargs["activation"] + if "use_bias" in kwargs: + del kwargs["use_bias"] + if "kernel_initializer" in kwargs: + del kwargs["kernel_initializer"] + if "bias_initializer" in kwargs: + del kwargs["bias_initializer"] + if "kernel_regularizer" in kwargs: + del kwargs["kernel_regularizer"] + if "activity_regularizer" in kwargs: + del kwargs["activity_regularizer"] + if "bias_regularizer" in kwargs: + del kwargs["bias_regularizer"] + if "kernel_constraint" in kwargs: + del kwargs["kernel_constraint"] + if "bias_constraint" in kwargs: + del kwargs["bias_constraint"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + return QConv2D( + name=fp32_layer.name, + filters=fp32_layer.filters, + kernel_size=fp32_layer.kernel_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + data_format=fp32_layer.data_format, + dilation_rate=fp32_layer.dilation_rate, + groups=fp32_layer.groups, + activation=fp32_layer.activation, + use_bias=fp32_layer.use_bias, + kernel_initializer=fp32_layer.kernel_initializer, + bias_initializer=fp32_layer.bias_initializer, + kernel_regularizer=fp32_layer.kernel_regularizer, + bias_regularizer=fp32_layer.bias_regularizer, + activity_regularizer=fp32_layer.activity_regularizer, + kernel_constraint=fp32_layer.kernel_constraint, + bias_constraint=fp32_layer.bias_constraint, + min_value=fp32_layer.min_value, + 
max_value=fp32_layer.max_value, + **kwargs + ) diff --git a/neural_compressor/tensorflow/keras/layers/dense.py b/neural_compressor/tensorflow/keras/layers/dense.py index b97e9759b70..61dfda2a2b8 100644 --- a/neural_compressor/tensorflow/keras/layers/dense.py +++ b/neural_compressor/tensorflow/keras/layers/dense.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ class QDense(Dense): def __init__( self, + name, units, activation=None, use_bias=True, @@ -36,11 +37,12 @@ def __init__( activity_regularizer=None, kernel_constraint=None, bias_constraint=None, - min_value=-10000, - max_value=10000, + min_value=None, + max_value=None, **kwargs ): super(QDense, self).__init__( + name=name, units=units, activation=activation, use_bias=use_bias, @@ -53,10 +55,17 @@ def __init__( bias_constraint=bias_constraint, **kwargs ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) + self.min_value = min_value + self.max_value = max_value def call(self, inputs): + kernel_size = self.kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000] * kernel_size + if not self.max_value: + self.max_value = [10000] * kernel_size + # add the Q/DQ here kernel, _, _ = quantization.quantize( self.kernel, @@ -66,6 +75,7 @@ def call(self, inputs): axis=1, mode="SCALED", ) + kernel = quantization.dequantize( kernel, self.min_value, @@ -80,3 +90,53 @@ def call(self, inputs): if self.activation is not None: outputs = self.activation(outputs) return outputs + + +def initialize_int8_dense(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "units" in kwargs: + del kwargs["units"] + if "activation" in kwargs: + del kwargs["activation"] + if "use_bias" in kwargs: + del kwargs["use_bias"] + if "kernel_initializer" in kwargs: + del kwargs["kernel_initializer"] + if "bias_initializer" in kwargs: + del kwargs["bias_initializer"] + if "kernel_regularizer" in kwargs: + del kwargs["kernel_regularizer"] + if "activity_regularizer" in kwargs: + del kwargs["activity_regularizer"] + if "bias_regularizer" in kwargs: + del kwargs["bias_regularizer"] + if "kernel_constraint" in kwargs: + del kwargs["kernel_constraint"] + if "bias_constraint" in kwargs: + del kwargs["bias_constraint"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + q_layer = QDense( + name=fp32_layer.name, + units=fp32_layer.units, + activation=fp32_layer.activation, + use_bias=fp32_layer.use_bias, + kernel_initializer=fp32_layer.kernel_initializer, + bias_initializer=fp32_layer.bias_initializer, + kernel_regularizer=fp32_layer.kernel_regularizer, + bias_regularizer=fp32_layer.bias_regularizer, + activity_regularizer=fp32_layer.activity_regularizer, + kernel_constraint=fp32_layer.kernel_constraint, + bias_constraint=fp32_layer.bias_constraint, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) + + return q_layer diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index eb0e9249c15..a3e6dd9b2f4 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env 
python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,202 +23,188 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.16.1"): - from keras.src import ops - from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv # pylint: disable=E0401 -elif version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.src.utils import conv_utils, tf_utils # pylint: disable=E0401 else: from keras.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.utils import conv_utils, tf_utils # pylint: disable=E0401 -if version1_gte_version2(tf.__version__, "2.16.1"): - - class QDepthwiseConv2D(BaseDepthwiseConv): - def __init__( - self, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - depth_multiplier=1, - data_format=None, - dilation_rate=(1, 1), - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super().__init__( - 2, - kernel_size=kernel_size, - strides=strides, - padding=padding, - depth_multiplier=depth_multiplier, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - bias_constraint=bias_constraint, - **kwargs - ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - input_channel = self._get_input_channel(inputs.shape) - outputs = ops.depthwise_conv( - inputs, - self.kernel, - strides=self.strides, - padding=self.padding, - dilation_rate=self.dilation_rate, - data_format=self.data_format, - ) - - if self.use_bias: - if self.data_format == "channels_last": - bias_shape = (1,) * (self.rank + 1) + (self.depth_multiplier * input_channel,) - else: - bias_shape = (1, self.depth_multiplier * input_channel) + (1,) * self.rank - bias = ops.reshape(self.bias, bias_shape) - outputs += bias - - if self.activation is not None: - return self.activation(outputs) - return outputs - -else: - class QDepthwiseConv2D(DepthwiseConv): - def __init__( - self, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - depth_multiplier=1, - data_format=None, - dilation_rate=(1, 1), - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, +class QDepthwiseConv2D(DepthwiseConv): + def 
__init__( + self, + kernel_size, + strides=(1, 1), + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + min_value=None, + max_value=None, + **kwargs + ): + super().__init__( + 2, + kernel_size=kernel_size, + strides=strides, + padding=padding, + depth_multiplier=depth_multiplier, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, **kwargs - ): - super().__init__( - 2, - kernel_size=kernel_size, - strides=strides, - padding=padding, - depth_multiplier=depth_multiplier, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - bias_constraint=bias_constraint, - **kwargs - ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - outputs = tf.keras.backend.depthwise_conv2d( - inputs, - kernel, - strides=self.strides, - padding=self.padding, - data_format=self.data_format, - dilation_rate=self.dilation_rate, - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) - - @tf_utils.shape_type_conversion - def compute_output_shape(self, input_shape): - if self.data_format == "channels_first": - rows = input_shape[2] - cols = input_shape[3] - out_filters = input_shape[1] * self.depth_multiplier - elif self.data_format == "channels_last": - rows = input_shape[1] - cols = input_shape[2] - out_filters = input_shape[3] * self.depth_multiplier - - rows = conv_utils.conv_output_length( - rows, - self.kernel_size[0], - self.padding, - self.strides[0], - self.dilation_rate[0], - ) - cols = conv_utils.conv_output_length( - cols, - self.kernel_size[1], - self.padding, - self.strides[1], - self.dilation_rate[1], - ) - if self.data_format == "channels_first": - return (input_shape[0], out_filters, rows, cols) - elif self.data_format == "channels_last": - return (input_shape[0], rows, cols, out_filters) + ) + self.min_value = min_value + self.max_value = max_value + + def call(self, inputs): + depthwise_kernel_size = self.depthwise_kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000] * depthwise_kernel_size + if not self.max_value: + self.max_value = [10000] * depthwise_kernel_size + + # add the Q/DQ here + kernel, _, _ = quantization.quantize( + 
self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + kernel = quantization.dequantize( + kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + outputs = tf.keras.backend.depthwise_conv2d( + inputs, + kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + rows = input_shape[2] + cols = input_shape[3] + out_filters = input_shape[1] * self.depth_multiplier + elif self.data_format == "channels_last": + rows = input_shape[1] + cols = input_shape[2] + out_filters = input_shape[3] * self.depth_multiplier + + rows = conv_utils.conv_output_length( + rows, + self.kernel_size[0], + self.padding, + self.strides[0], + self.dilation_rate[0], + ) + cols = conv_utils.conv_output_length( + cols, + self.kernel_size[1], + self.padding, + self.strides[1], + self.dilation_rate[1], + ) + if self.data_format == "channels_first": + return (input_shape[0], out_filters, rows, cols) + elif self.data_format == "channels_last": + return (input_shape[0], rows, cols, out_filters) + + +def initialize_int8_depthwise_conv2d(fp32_layer): + kwargs = fp32_layer.get_config() + q_name = fp32_layer.name + + if "name" in kwargs: + del kwargs["name"] + if "kernel_size" in kwargs: + del kwargs["kernel_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "depth_multiplier" in kwargs: + del kwargs["depth_multiplier"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "dilation_rate" in kwargs: + del kwargs["dilation_rate"] + if "activation" in kwargs: + del kwargs["activation"] + if "use_bias" in kwargs: + del kwargs["use_bias"] + if "depthwise_initializer" in kwargs: + del kwargs["depthwise_initializer"] + if "bias_initializer" in kwargs: + del kwargs["bias_initializer"] + if "depthwise_regularizer" in kwargs: + del kwargs["depthwise_regularizer"] + if "activity_regularizer" in kwargs: + del kwargs["activity_regularizer"] + if "bias_regularizer" in kwargs: + del kwargs["bias_regularizer"] + if "depthwise_constraint" in kwargs: + del kwargs["depthwise_constraint"] + if "bias_constraint" in kwargs: + del kwargs["bias_constraint"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + return QDepthwiseConv2D( + name=q_name, + kernel_size=fp32_layer.kernel_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + depth_multiplier=fp32_layer.depth_multiplier, + data_format=fp32_layer.data_format, + dilation_rate=fp32_layer.dilation_rate, + activation=fp32_layer.activation, + use_bias=fp32_layer.use_bias, + depthwise_initializer=fp32_layer.depthwise_initializer, + bias_initializer=fp32_layer.bias_initializer, + depthwise_regularizer=fp32_layer.depthwise_regularizer, + bias_regularizer=fp32_layer.bias_regularizer, + activity_regularizer=fp32_layer.activity_regularizer, + depthwise_constraint=fp32_layer.depthwise_constraint, + bias_constraint=fp32_layer.bias_constraint, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) diff --git 
a/neural_compressor/tensorflow/keras/layers/layer_initializer.py b/neural_compressor/tensorflow/keras/layers/layer_initializer.py new file mode 100644 index 00000000000..d1db0eb3504 --- /dev/null +++ b/neural_compressor/tensorflow/keras/layers/layer_initializer.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from neural_compressor.tensorflow.keras.layers.conv2d import initialize_int8_conv2d +from neural_compressor.tensorflow.keras.layers.dense import initialize_int8_dense +from neural_compressor.tensorflow.keras.layers.depthwise_conv2d import initialize_int8_depthwise_conv2d +from neural_compressor.tensorflow.keras.layers.pool2d import initialize_int8_avgpool, initialize_int8_maxpool +from neural_compressor.tensorflow.keras.layers.separable_conv2d import initialize_int8_separable_conv2d + +layer_initializer_dict = { + "QAvgPool2D": initialize_int8_avgpool, + "QAveragePooling2D": initialize_int8_avgpool, + "QMaxPool2D": initialize_int8_maxpool, + "QMaxPooling2D": initialize_int8_maxpool, + "QSeparableConv2D": initialize_int8_separable_conv2d, + "QDepthwiseConv2D": initialize_int8_depthwise_conv2d, + "QConv2D": initialize_int8_conv2d, + "QDense": initialize_int8_dense, +} diff --git a/neural_compressor/tensorflow/keras/layers/pool2d.py b/neural_compressor/tensorflow/keras/layers/pool2d.py index 409c16b9305..05a028ecc83 100644 --- a/neural_compressor/tensorflow/keras/layers/pool2d.py +++ b/neural_compressor/tensorflow/keras/layers/pool2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
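The new layer_initializer_dict above gives the adaptor a single lookup point
for converting an fp32 layer into its Q-counterpart, keyed by "Q" plus the
layer's class name (as used in KerasAdaptor._calibrate). A minimal sketch with
a hypothetical Dense layer; the min_value/max_value attributes would normally
be attached from calibration results before conversion:

    import tensorflow as tf
    from neural_compressor.tensorflow.keras.layers import layer_initializer_dict

    fp32_layer = tf.keras.layers.Dense(10, name="fc")
    fp32_layer.build((None, 4))
    fp32_layer.min_value = [-1.0] * 10  # illustrative calibration output
    fp32_layer.max_value = [1.0] * 10
    q_layer = layer_initializer_dict["Q" + fp32_layer.__class__.__name__](fp32_layer)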
@@ -26,6 +26,7 @@ class QAvgPool2D(AveragePooling2D): def __init__( self, + name, pool_size=(2, 2), strides=None, padding="valid", @@ -35,15 +36,16 @@ def __init__( **kwargs ): super(QAvgPool2D, self).__init__( - pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs + name=name, pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) + self.min_value = min_value + self.max_value = max_value class QMaxPool2D(MaxPooling2D): def __init__( self, + name, pool_size=(2, 2), strides=None, padding="valid", @@ -53,7 +55,71 @@ def __init__( **kwargs ): super(QMaxPool2D, self).__init__( - pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs + name=name, pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) + self.min_value = min_value + self.max_value = max_value + + +def initialize_int8_avgpool(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "pool_size" in kwargs: + del kwargs["pool_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + q_layer = QAvgPool2D( + name=fp32_layer.name, + pool_size=fp32_layer.pool_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + data_format=fp32_layer.data_format, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) + + return q_layer + + +def initialize_int8_maxpool(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "pool_size" in kwargs: + del kwargs["pool_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + q_layer = QMaxPool2D( + name=fp32_layer.name, + pool_size=fp32_layer.pool_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + data_format=fp32_layer.data_format, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) + + return q_layer diff --git a/neural_compressor/tensorflow/keras/layers/quantizer.py b/neural_compressor/tensorflow/keras/layers/quantizer.py index bf17933756e..a6e31fc6a5c 100644 --- a/neural_compressor/tensorflow/keras/layers/quantizer.py +++ b/neural_compressor/tensorflow/keras/layers/quantizer.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
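The FakeQuant change below pairs with the calibration loop in keras.py: on
each batch, FakeQuant.call() records reduce_min/reduce_max of its inputs, and
_calibrate() then aggregates those observations across batches before building
the Quantize/DeQuantize layers. A small sketch of that aggregation, with
illustrative values:

    # results as collected by KerasAdaptor._calibrate from FakeQuant configs
    results = {"fake_quant_0": {"min": [-0.9, -1.2], "max": [0.8, 1.1]}}

    min_value = min(results["fake_quant_0"]["min"])  # -1.2
    max_value = max(results["fake_quant_0"]["max"])  # 1.1
    assert min_value < max_value, "The min value must be lower than the max value in quantization."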
@@ -26,8 +26,8 @@ def __init__(self, mode="per_tensor", T="s8", **kwargs): self.mode = mode self.T = T self.axis = 1 if mode == "per_channel" else 0 - self.min_value = tf.constant(np.finfo(np.float32).max, dtype=tf.float32) - self.max_value = tf.constant(np.finfo(np.float32).min, dtype=tf.float32) + self.min_value = tf.constant(np.finfo(np.float32).min, dtype=tf.float32) + self.max_value = tf.constant(np.finfo(np.float32).max, dtype=tf.float32) def call(self, inputs): if self.mode == "per_tensor": @@ -36,6 +36,7 @@ def call(self, inputs): else: self.min_value = tf.math.reduce_min(inputs, axis=self.axis) self.max_value = tf.math.reduce_max(inputs, axis=self.axis) + return inputs def compute_output_shape(self, input_shape): diff --git a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py index 07ebc691373..7df66d9db49 100644 --- a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,196 +23,187 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.16.1"): - from keras.src import ops - from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv # pylint: disable=E0401 -elif version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.src.utils import conv_utils # pylint: disable=E0401 else: from keras.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.utils import conv_utils # pylint: disable=E0401 -if version1_gte_version2(tf.__version__, "2.16.1"): - - class QSeparableConv2D(BaseSeparableConv): - def __init__( - self, - filters, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - data_format=None, - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - pointwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super().__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activations.get(activation), - use_bias=use_bias, - depthwise_initializer=initializers.get(depthwise_initializer), - pointwise_initializer=initializers.get(pointwise_initializer), - bias_initializer=initializers.get(bias_initializer), - depthwise_regularizer=regularizers.get(depthwise_regularizer), - pointwise_regularizer=regularizers.get(pointwise_regularizer), - bias_regularizer=regularizers.get(bias_regularizer), - activity_regularizer=regularizers.get(activity_regularizer), - depthwise_constraint=constraints.get(depthwise_constraint), - pointwise_constraint=constraints.get(pointwise_constraint), - 
bias_constraint=constraints.get(bias_constraint), - **kwargs - ) - - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - # (TODO) it's ugly that we can't get the point_wise min/max here - depthwise_kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - depthwise_kernel = quantization.dequantize( - depthwise_kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - outputs = ops.separable_conv( - inputs, - self.depthwise_kernel, - self.pointwise_kernel, - strides=self.strides, - padding=self.padding, - dilation_rate=self.dilation_rate, - data_format=self.data_format, - ) - - if self.use_bias: - if self.data_format == "channels_last": - bias_shape = (1,) * (self.rank + 1) + (self.filters,) - else: - bias_shape = (1, self.filters) + (1,) * self.rank - bias = ops.reshape(self.bias, bias_shape) - outputs += bias - - if self.activation is not None: - return self.activation(outputs) - return outputs - -else: - class QSeparableConv2D(SeparableConv): - def __init__( - self, - filters, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - data_format=None, - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - pointwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, +class QSeparableConv2D(SeparableConv): + def __init__( + self, + name, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + min_value=None, + max_value=None, + **kwargs + ): + super().__init__( + name=name, + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activations.get(activation), + use_bias=use_bias, + depthwise_initializer=initializers.get(depthwise_initializer), + pointwise_initializer=initializers.get(pointwise_initializer), + bias_initializer=initializers.get(bias_initializer), + depthwise_regularizer=regularizers.get(depthwise_regularizer), + pointwise_regularizer=regularizers.get(pointwise_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + depthwise_constraint=constraints.get(depthwise_constraint), + pointwise_constraint=constraints.get(pointwise_constraint), + bias_constraint=constraints.get(bias_constraint), **kwargs - ): - super().__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activations.get(activation), - use_bias=use_bias, - depthwise_initializer=initializers.get(depthwise_initializer), - pointwise_initializer=initializers.get(pointwise_initializer), - 
bias_initializer=initializers.get(bias_initializer), - depthwise_regularizer=regularizers.get(depthwise_regularizer), - pointwise_regularizer=regularizers.get(pointwise_regularizer), - bias_regularizer=regularizers.get(bias_regularizer), - activity_regularizer=regularizers.get(activity_regularizer), - depthwise_constraint=constraints.get(depthwise_constraint), - pointwise_constraint=constraints.get(pointwise_constraint), - bias_constraint=constraints.get(bias_constraint), - **kwargs - ) - - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - if self.data_format == "channels_last": - strides = (1,) + self.strides + (1,) - else: - strides = (1, 1) + self.strides - # (TODO) it's ugly that we can't get the point_wise min/max here - depthwise_kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - depthwise_kernel = quantization.dequantize( - depthwise_kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - outputs = tf.compat.v1.nn.separable_conv2d( - inputs, - depthwise_kernel, - self.pointwise_kernel, - strides=strides, - padding=self.padding.upper(), - rate=self.dilation_rate, - data_format=conv_utils.convert_data_format(self.data_format, ndim=4), - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) + ) + + self.min_value = min_value + self.max_value = max_value + + def call(self, inputs): + depthwise_kernel_size = self.depthwise_kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000] * depthwise_kernel_size + if not self.max_value: + self.max_value = [10000] * depthwise_kernel_size + + # TODO it's ugly that we can't get the point_wise min/max here + depthwise_kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + depthwise_kernel = quantization.dequantize( + depthwise_kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + + if self.data_format == "channels_last": + strides = (1,) + self.strides + (1,) + else: + strides = (1, 1) + self.strides + + outputs = tf.compat.v1.nn.separable_conv2d( + inputs, + depthwise_kernel, + self.pointwise_kernel, + strides=strides, + padding=self.padding.upper(), + rate=self.dilation_rate, + data_format=conv_utils.convert_data_format(self.data_format, ndim=4), + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def initialize_int8_separable_conv2d(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "filters" in kwargs: + del kwargs["filters"] + if "kernel_size" in kwargs: + del kwargs["kernel_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "dilation_rate" in kwargs: + del kwargs["dilation_rate"] + if "depth_multiplier" in kwargs: + del kwargs["depth_multiplier"] + if "activation" in kwargs: + del kwargs["activation"] + if "use_bias" in kwargs: + del kwargs["use_bias"] + if 
"depthwise_initializer" in kwargs: + del kwargs["depthwise_initializer"] + if "pointwise_initializer" in kwargs: + del kwargs["pointwise_initializer"] + if "bias_initializer" in kwargs: + del kwargs["bias_initializer"] + if "depthwise_regularizer" in kwargs: + del kwargs["depthwise_regularizer"] + if "pointwise_regularizer" in kwargs: + del kwargs["pointwise_regularizer"] + if "activity_regularizer" in kwargs: + del kwargs["activity_regularizer"] + if "bias_regularizer" in kwargs: + del kwargs["bias_regularizer"] + if "depthwise_constraint" in kwargs: + del kwargs["depthwise_constraint"] + if "pointwise_constraint" in kwargs: + del kwargs["pointwise_constraint"] + if "bias_constraint" in kwargs: + del kwargs["bias_constraint"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + return QSeparableConv2D( + name=fp32_layer.name, + filters=fp32_layer.filters, + kernel_size=fp32_layer.kernel_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + data_format=fp32_layer.data_format, + dilation_rate=fp32_layer.dilation_rate, + depth_multiplier=fp32_layer.depth_multiplier, + activation=fp32_layer.activation, + use_bias=fp32_layer.use_bias, + depthwise_initializer=fp32_layer.depthwise_initializer, + pointwise_initializer=fp32_layer.pointwise_initializer, + bias_initializer=fp32_layer.bias_initializer, + depthwise_regularizer=fp32_layer.depthwise_regularizer, + pointwise_regularizer=fp32_layer.pointwise_regularizer, + bias_regularizer=fp32_layer.bias_regularizer, + activity_regularizer=fp32_layer.activity_regularizer, + depthwise_constraint=fp32_layer.depthwise_constraint, + pointwise_constraint=fp32_layer.pointwise_constraint, + bias_constraint=fp32_layer.bias_constraint, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) diff --git a/neural_compressor/tensorflow/keras/quantization/config.py b/neural_compressor/tensorflow/keras/quantization/config.py index a46a7375ca9..ae532dc63c4 100644 --- a/neural_compressor/tensorflow/keras/quantization/config.py +++ b/neural_compressor/tensorflow/keras/quantization/config.py @@ -114,7 +114,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]: def get_model_info(model) -> List[Tuple[str, Callable]]: white_list = [ "Dense", - "Conv2d", + "Conv2D", "DepthwiseConv2D", "SeparableConv2D", "AvgPool2D", diff --git a/requirements_tf.txt b/requirements_tf.txt index f8075c2a068..da1544d2939 100644 --- a/requirements_tf.txt +++ b/requirements_tf.txt @@ -3,4 +3,4 @@ psutil py-cpuinfo pydantic pyyaml -tensorflow +tensorflow<=2.15.1 diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index 0e6d70f75f1..c204d52b330 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -21,9 +21,9 @@ import time import unittest +import keras import numpy as np import tensorflow as tf -from tensorflow import keras from neural_compressor.common import Logger @@ -48,15 +48,15 @@ def build_model(): [ keras.layers.InputLayer(input_shape=(28, 28)), keras.layers.Reshape(target_shape=(28, 28, 1)), - keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu"), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu", name="conv2d"), keras.layers.MaxPooling2D(pool_size=(2, 2)), keras.layers.Flatten(), - keras.layers.Dense(10), + keras.layers.Dense(10, name="dense"), ] ) # Train the digit classification model model.compile( - optimizer="adam", 
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] + optimizer="adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] ) model.fit( @@ -69,7 +69,7 @@ def build_model(): _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) print("Baseline test accuracy:", baseline_model_accuracy) - model.save("baseline_model.keras") + model.save("baseline_model") class Dataset(object): @@ -124,13 +124,15 @@ def test_static_quant_from_dict_default(self): from neural_compressor.tensorflow.keras import get_default_static_quant_config calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") qmodel = quantize_model(fp32_model, get_default_static_quant_config(), calib_dataloader) self.assertIsNotNone(qmodel) for layer in qmodel.layers: if layer.name == "dense": self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + self.assertEqual(layer.__class__.__name__, "QConv2D") def test_static_quant_from_dict_beginner(self): logger.info("test_static_quant_from_dict_beginner") @@ -149,13 +151,15 @@ def test_static_quant_from_dict_beginner(self): } } calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) for layer in qmodel.layers: if layer.name == "dense": self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + self.assertEqual(layer.__class__.__name__, "QConv2D") def test_static_quant_from_class_default(self): logger.info("test_static_quant_from_class_default") @@ -163,7 +167,7 @@ def test_static_quant_from_class_default(self): from neural_compressor.tensorflow.keras import StaticQuantConfig calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") quant_config = StaticQuantConfig() qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) @@ -171,6 +175,8 @@ def test_static_quant_from_class_default(self): for layer in qmodel.layers: if layer.name == "dense": self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + self.assertEqual(layer.__class__.__name__, "QConv2D") def test_static_quant_from_class_beginner(self): logger.info("test_static_quant_from_class_beginner") @@ -178,7 +184,7 @@ def test_static_quant_from_class_beginner(self): from neural_compressor.tensorflow.keras import StaticQuantConfig calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") quant_config = StaticQuantConfig( weight_dtype="int8", weight_sym=True, @@ -193,13 +199,15 @@ def test_static_quant_from_class_beginner(self): for layer in qmodel.layers: if layer.name == "dense": self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + self.assertEqual(layer.__class__.__name__, "QConv2D") def test_static_quant_from_dict_advance(self): logger.info("test_static_quant_from_dict_advance") from neural_compressor.tensorflow import quantize_model calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + 
fp32_model = keras.models.load_model("baseline_model") quant_config = { "static_quant": { "global": { @@ -222,8 +230,8 @@ def test_static_quant_from_dict_advance(self): self.assertIsNotNone(qmodel) for layer in qmodel.layers: - if layer.name == "dense": - self.assertNotEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + self.assertEqual(layer.__class__.__name__, "QConv2D") def test_static_quant_from_class_advance(self): logger.info("test_static_quant_from_class_advance") @@ -245,13 +253,13 @@ def test_static_quant_from_class_advance(self): ) quant_config.set_local("dense", dense_config) # get model and quantize - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) for layer in qmodel.layers: - if layer.name == "dense": - self.assertNotEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + self.assertEqual(layer.__class__.__name__, "QConv2D") def test_config_from_dict(self): logger.info("test_config_from_dict") diff --git a/test/3x/tensorflow/quantization/test_smooth_quant.py b/test/3x/tensorflow/quantization/test_smooth_quant.py index ee8f5407d3a..5c76eadb9cd 100644 --- a/test/3x/tensorflow/quantization/test_smooth_quant.py +++ b/test/3x/tensorflow/quantization/test_smooth_quant.py @@ -20,13 +20,15 @@ def build_conv_graph(): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") + normed = tf.compat.v1.layers.batch_normalization(conv) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") + normed2 = tf.compat.v1.layers.batch_normalization(conv2) + add = tf.raw_ops.Add(x=normed, y=normed2, name="addv2") - add = tf.raw_ops.Add(x=conv, y=conv2, name="addv2") relu = tf.nn.relu(add) relu6 = tf.nn.relu6(relu, name="op_to_store") diff --git a/test/3x/tensorflow/test_autotune.py b/test/3x/tensorflow/test_autotune.py index d5f83e85c7d..9c89f8cd5fc 100644 --- a/test/3x/tensorflow/test_autotune.py +++ b/test/3x/tensorflow/test_autotune.py @@ -59,7 +59,7 @@ def build_model(): _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) print("Baseline test accuracy:", baseline_model_accuracy) - tf.saved_model.save(model, "baseline_model") + model.save("baseline_model") class Dataset(object):
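[Editor's note] The tensorflow<=2.15.1 pin in requirements_tf.txt matches the rest of this patch: the Keras 3 (TF 2.16+) branch is removed from QSeparableConv2D, the tests import keras directly, and models are saved back in the SavedModel directory format instead of the .keras archive. A fail-fast guard in the same spirit — illustrative only, since the patch relies on the pip pin rather than a runtime check; version1_gte_version2 is the helper this patch already uses for version gating:

    import tensorflow as tf

    from neural_compressor.tensorflow.utils import version1_gte_version2

    # The refactored Q-layers target the legacy Keras stack (TF <= 2.15), so a
    # check like this could surface an unsupported runtime early.
    if version1_gte_version2(tf.__version__, "2.16.0"):
        raise RuntimeError(
            "Keras 3 (tensorflow>=2.16) is not supported; install tensorflow<=2.15.1"
        )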