diff --git a/neural_compressor/tensorflow/keras/layers/conv2d.py b/neural_compressor/tensorflow/keras/layers/conv2d.py index 3bcf1b07b86..812b6caaa33 100644 --- a/neural_compressor/tensorflow/keras/layers/conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/conv2d.py @@ -23,7 +23,11 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras.src.layers.convolutional.base_conv import BaseConv # pylint: disable=E0401 + + Conv = BaseConv +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_conv import Conv # pylint: disable=E0401 else: from keras.layers.convolutional.base_conv import Conv # pylint: disable=E0401 diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index 1de9d8bf792..eb0e9249c15 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -23,117 +23,202 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras.src import ops + from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv # pylint: disable=E0401 +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.src.utils import conv_utils, tf_utils # pylint: disable=E0401 else: from keras.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.utils import conv_utils, tf_utils # pylint: disable=E0401 +if version1_gte_version2(tf.__version__, "2.16.1"): -class QDepthwiseConv2D(DepthwiseConv): - def __init__( - self, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - depth_multiplier=1, - data_format=None, - dilation_rate=(1, 1), - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super().__init__( - 2, - kernel_size=kernel_size, - strides=strides, - padding=padding, - depth_multiplier=depth_multiplier, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - bias_constraint=bias_constraint, + class QDepthwiseConv2D(BaseDepthwiseConv): + def __init__( + self, + kernel_size, + min_value, + max_value, + strides=(1, 1), + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, **kwargs - ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, 
self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
-        )
-        kernel = quantization.dequantize(
-            kernel,
-            self.min_value,
-            self.max_value,
-            axis=3,
-            mode="SCALED",
-        )
-        outputs = tf.keras.backend.depthwise_conv2d(
-            inputs,
-            kernel,
-            strides=self.strides,
-            padding=self.padding,
-            data_format=self.data_format,
-            dilation_rate=self.dilation_rate,
-        )
-
-        if self.use_bias:
-            outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format)
-
-        if self.activation is not None:
-            return self.activation(outputs)
-
-        return outputs
-
-    @classmethod
-    def from_config(cls, config):
-        return cls(**config)
-
-    @tf_utils.shape_type_conversion
-    def compute_output_shape(self, input_shape):
-        if self.data_format == "channels_first":
-            rows = input_shape[2]
-            cols = input_shape[3]
-            out_filters = input_shape[1] * self.depth_multiplier
-        elif self.data_format == "channels_last":
-            rows = input_shape[1]
-            cols = input_shape[2]
-            out_filters = input_shape[3] * self.depth_multiplier
-
-        rows = conv_utils.conv_output_length(
-            rows,
-            self.kernel_size[0],
-            self.padding,
-            self.strides[0],
-            self.dilation_rate[0],
-        )
-        cols = conv_utils.conv_output_length(
-            cols,
-            self.kernel_size[1],
-            self.padding,
-            self.strides[1],
-            self.dilation_rate[1],
-        )
-        if self.data_format == "channels_first":
-            return (input_shape[0], out_filters, rows, cols)
-        elif self.data_format == "channels_last":
-            return (input_shape[0], rows, cols, out_filters)
+        ):
+            super().__init__(
+                2,
+                kernel_size=kernel_size,
+                strides=strides,
+                padding=padding,
+                depth_multiplier=depth_multiplier,
+                data_format=data_format,
+                dilation_rate=dilation_rate,
+                activation=activation,
+                use_bias=use_bias,
+                depthwise_initializer=depthwise_initializer,
+                bias_initializer=bias_initializer,
+                depthwise_regularizer=depthwise_regularizer,
+                bias_regularizer=bias_regularizer,
+                activity_regularizer=activity_regularizer,
+                depthwise_constraint=depthwise_constraint,
+                bias_constraint=bias_constraint,
+                **kwargs
+            )
+            self.min_value = json.loads(min_value)
+            self.max_value = json.loads(max_value)
+
+        def call(self, inputs):
+            # add the Q/DQ here
+            kernel, _, _ = quantization.quantize(
+                self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
+            )
+            kernel = quantization.dequantize(
+                kernel,
+                self.min_value,
+                self.max_value,
+                axis=3,
+                mode="SCALED",
+            )
+
+            input_channel = self._get_input_channel(inputs.shape)
+            outputs = ops.depthwise_conv(
+                inputs,
+                kernel,
+                strides=self.strides,
+                padding=self.padding,
+                dilation_rate=self.dilation_rate,
+                data_format=self.data_format,
+            )
+
+            if self.use_bias:
+                if self.data_format == "channels_last":
+                    bias_shape = (1,) * (self.rank + 1) + (self.depth_multiplier * input_channel,)
+                else:
+                    bias_shape = (1, self.depth_multiplier * input_channel) + (1,) * self.rank
+                bias = ops.reshape(self.bias, bias_shape)
+                outputs += bias
+
+            if self.activation is not None:
+                return self.activation(outputs)
+            return outputs
+
+else:
+
+    class QDepthwiseConv2D(DepthwiseConv):
+        def __init__(
+            self,
+            kernel_size,
+            min_value,
+            max_value,
+            strides=(1, 1),
+            padding="valid",
+            depth_multiplier=1,
+            data_format=None,
+            dilation_rate=(1, 1),
+            activation=None,
+            use_bias=True,
+            depthwise_initializer="glorot_uniform",
+            bias_initializer="zeros",
+            depthwise_regularizer=None,
+            bias_regularizer=None,
+            activity_regularizer=None,
+            depthwise_constraint=None,
+            bias_constraint=None,
+            **kwargs
+        ):
+            super().__init__(
+                2,
+                kernel_size=kernel_size,
+                
strides=strides, + padding=padding, + depth_multiplier=depth_multiplier, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, + **kwargs + ) + self.min_value = json.loads(min_value) + self.max_value = json.loads(max_value) + + def call(self, inputs): + # add the Q/DQ here + kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + kernel = quantization.dequantize( + kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + outputs = tf.keras.backend.depthwise_conv2d( + inputs, + kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + rows = input_shape[2] + cols = input_shape[3] + out_filters = input_shape[1] * self.depth_multiplier + elif self.data_format == "channels_last": + rows = input_shape[1] + cols = input_shape[2] + out_filters = input_shape[3] * self.depth_multiplier + + rows = conv_utils.conv_output_length( + rows, + self.kernel_size[0], + self.padding, + self.strides[0], + self.dilation_rate[0], + ) + cols = conv_utils.conv_output_length( + cols, + self.kernel_size[1], + self.padding, + self.strides[1], + self.dilation_rate[1], + ) + if self.data_format == "channels_first": + return (input_shape[0], out_filters, rows, cols) + elif self.data_format == "channels_last": + return (input_shape[0], rows, cols, out_filters) diff --git a/neural_compressor/tensorflow/keras/layers/quantizer.py b/neural_compressor/tensorflow/keras/layers/quantizer.py index b395870b48f..bf17933756e 100644 --- a/neural_compressor/tensorflow/keras/layers/quantizer.py +++ b/neural_compressor/tensorflow/keras/layers/quantizer.py @@ -38,6 +38,10 @@ def call(self, inputs): self.max_value = tf.math.reduce_max(inputs, axis=self.axis) return inputs + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + return input_shape + @classmethod def from_config(cls, config): return cls(**config) @@ -87,6 +91,10 @@ def call(self, inputs): ) return outputs + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + return input_shape + def get_config(self): return { "min_range": self.min_range, @@ -122,6 +130,10 @@ def call(self, inputs): axis=self.axis, ) + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + return input_shape + def get_config(self): return { "min_range": self.min_range, diff --git a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py index 5507d2f99d2..07ebc691373 100644 --- a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py @@ 
-23,102 +23,196 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras.src import ops + from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv # pylint: disable=E0401 +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.src.utils import conv_utils # pylint: disable=E0401 else: from keras.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.utils import conv_utils # pylint: disable=E0401 +if version1_gte_version2(tf.__version__, "2.16.1"): -class QSeparableConv2D(SeparableConv): - def __init__( - self, - filters, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - data_format=None, - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - pointwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super().__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activations.get(activation), - use_bias=use_bias, - depthwise_initializer=initializers.get(depthwise_initializer), - pointwise_initializer=initializers.get(pointwise_initializer), - bias_initializer=initializers.get(bias_initializer), - depthwise_regularizer=regularizers.get(depthwise_regularizer), - pointwise_regularizer=regularizers.get(pointwise_regularizer), - bias_regularizer=regularizers.get(bias_regularizer), - activity_regularizer=regularizers.get(activity_regularizer), - depthwise_constraint=constraints.get(depthwise_constraint), - pointwise_constraint=constraints.get(pointwise_constraint), - bias_constraint=constraints.get(bias_constraint), + class QSeparableConv2D(BaseSeparableConv): + def __init__( + self, + filters, + kernel_size, + min_value, + max_value, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, **kwargs - ) - - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - if self.data_format == "channels_last": - strides = (1,) + self.strides + (1,) - else: - strides = (1, 1) + self.strides - # (TODO) it's ugly that we can't get the point_wise min/max here - depthwise_kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - depthwise_kernel = quantization.dequantize( - depthwise_kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - outputs = tf.compat.v1.nn.separable_conv2d( - inputs, - depthwise_kernel, - self.pointwise_kernel, - strides=strides, - padding=self.padding.upper(), - 
rate=self.dilation_rate,
-            data_format=conv_utils.convert_data_format(self.data_format, ndim=4),
-        )
-
-        if self.use_bias:
-            outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format)
-
-        if self.activation is not None:
-            return self.activation(outputs)
-
-        return outputs
-
-    @classmethod
-    def from_config(cls, config):
-        return cls(**config)
+        ):
+            super().__init__(
+                rank=2,
+                filters=filters,
+                kernel_size=kernel_size,
+                strides=strides,
+                padding=padding,
+                data_format=data_format,
+                dilation_rate=dilation_rate,
+                depth_multiplier=depth_multiplier,
+                activation=activations.get(activation),
+                use_bias=use_bias,
+                depthwise_initializer=initializers.get(depthwise_initializer),
+                pointwise_initializer=initializers.get(pointwise_initializer),
+                bias_initializer=initializers.get(bias_initializer),
+                depthwise_regularizer=regularizers.get(depthwise_regularizer),
+                pointwise_regularizer=regularizers.get(pointwise_regularizer),
+                bias_regularizer=regularizers.get(bias_regularizer),
+                activity_regularizer=regularizers.get(activity_regularizer),
+                depthwise_constraint=constraints.get(depthwise_constraint),
+                pointwise_constraint=constraints.get(pointwise_constraint),
+                bias_constraint=constraints.get(bias_constraint),
+                **kwargs
+            )
+
+            self.min_value = json.loads(min_value)
+            self.max_value = json.loads(max_value)
+
+        def call(self, inputs):
+            # (TODO) it's ugly that we can't get the point_wise min/max here
+            depthwise_kernel, _, _ = quantization.quantize(
+                self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
+            )
+            depthwise_kernel = quantization.dequantize(
+                depthwise_kernel,
+                self.min_value,
+                self.max_value,
+                axis=3,
+                mode="SCALED",
+            )
+
+            outputs = ops.separable_conv(
+                inputs,
+                depthwise_kernel,
+                self.pointwise_kernel,
+                strides=self.strides,
+                padding=self.padding,
+                dilation_rate=self.dilation_rate,
+                data_format=self.data_format,
+            )
+
+            if self.use_bias:
+                if self.data_format == "channels_last":
+                    bias_shape = (1,) * (self.rank + 1) + (self.filters,)
+                else:
+                    bias_shape = (1, self.filters) + (1,) * self.rank
+                bias = ops.reshape(self.bias, bias_shape)
+                outputs += bias
+
+            if self.activation is not None:
+                return self.activation(outputs)
+            return outputs
+
+else:
+
+    class QSeparableConv2D(SeparableConv):
+        def __init__(
+            self,
+            filters,
+            kernel_size,
+            min_value,
+            max_value,
+            strides=(1, 1),
+            padding="valid",
+            data_format=None,
+            dilation_rate=(1, 1),
+            depth_multiplier=1,
+            activation=None,
+            use_bias=True,
+            depthwise_initializer="glorot_uniform",
+            pointwise_initializer="glorot_uniform",
+            bias_initializer="zeros",
+            depthwise_regularizer=None,
+            pointwise_regularizer=None,
+            bias_regularizer=None,
+            activity_regularizer=None,
+            depthwise_constraint=None,
+            pointwise_constraint=None,
+            bias_constraint=None,
+            **kwargs
+        ):
+            super().__init__(
+                rank=2,
+                filters=filters,
+                kernel_size=kernel_size,
+                strides=strides,
+                padding=padding,
+                data_format=data_format,
+                dilation_rate=dilation_rate,
+                depth_multiplier=depth_multiplier,
+                activation=activations.get(activation),
+                use_bias=use_bias,
+                depthwise_initializer=initializers.get(depthwise_initializer),
+                pointwise_initializer=initializers.get(pointwise_initializer),
+                bias_initializer=initializers.get(bias_initializer),
+                depthwise_regularizer=regularizers.get(depthwise_regularizer),
+                pointwise_regularizer=regularizers.get(pointwise_regularizer),
+                bias_regularizer=regularizers.get(bias_regularizer),
+                
activity_regularizer=regularizers.get(activity_regularizer), + depthwise_constraint=constraints.get(depthwise_constraint), + pointwise_constraint=constraints.get(pointwise_constraint), + bias_constraint=constraints.get(bias_constraint), + **kwargs + ) + + self.min_value = json.loads(min_value) + self.max_value = json.loads(max_value) + + def call(self, inputs): + if self.data_format == "channels_last": + strides = (1,) + self.strides + (1,) + else: + strides = (1, 1) + self.strides + # (TODO) it's ugly that we can't get the point_wise min/max here + depthwise_kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + depthwise_kernel = quantization.dequantize( + depthwise_kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + + outputs = tf.compat.v1.nn.separable_conv2d( + inputs, + depthwise_kernel, + self.pointwise_kernel, + strides=strides, + padding=self.padding.upper(), + rate=self.dilation_rate, + data_format=conv_utils.convert_data_format(self.data_format, ndim=4), + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) diff --git a/neural_compressor/tensorflow/utils/model_wrappers.py b/neural_compressor/tensorflow/utils/model_wrappers.py index 8781bd2aa8b..88b8871ba81 100644 --- a/neural_compressor/tensorflow/utils/model_wrappers.py +++ b/neural_compressor/tensorflow/utils/model_wrappers.py @@ -32,7 +32,7 @@ from neural_compressor.common import logger from neural_compressor.common.utils import DEFAULT_WORKSPACE -from neural_compressor.tensorflow.utils.utility import version1_lt_version2 +from neural_compressor.tensorflow.utils.utility import version1_gte_version2, version1_lt_version2 tensor_to_node = lambda s: list(set([x.split(":")[0] for x in s])) @@ -91,7 +91,9 @@ def get_model_type(model): return "graph" elif isinstance(model, tf.compat.v1.GraphDef): return "graph_def" - elif isinstance(model, tf.compat.v1.estimator.Estimator): + elif not version1_gte_version2(tf.version.VERSION, "2.16.1") and isinstance( + model, tf.compat.v1.estimator.Estimator + ): return "estimator" elif isinstance(model, str): model = os.path.abspath(os.path.expanduser(model)) diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index a47eab6bf27..0e6d70f75f1 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -69,7 +69,7 @@ def build_model(): _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) print("Baseline test accuracy:", baseline_model_accuracy) - model.save("baseline_model") + model.save("baseline_model.keras") class Dataset(object): @@ -124,7 +124,7 @@ def test_static_quant_from_dict_default(self): from neural_compressor.tensorflow.keras import get_default_static_quant_config calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("./baseline_model") + fp32_model = keras.models.load_model("baseline_model.keras") qmodel = quantize_model(fp32_model, get_default_static_quant_config(), calib_dataloader) self.assertIsNotNone(qmodel) @@ -149,7 +149,7 @@ def test_static_quant_from_dict_beginner(self): } } calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("./baseline_model") + fp32_model = 
keras.models.load_model("baseline_model.keras") qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) @@ -163,7 +163,7 @@ def test_static_quant_from_class_default(self): from neural_compressor.tensorflow.keras import StaticQuantConfig calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("./baseline_model") + fp32_model = keras.models.load_model("baseline_model.keras") quant_config = StaticQuantConfig() qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) @@ -178,7 +178,7 @@ def test_static_quant_from_class_beginner(self): from neural_compressor.tensorflow.keras import StaticQuantConfig calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("./baseline_model") + fp32_model = keras.models.load_model("baseline_model.keras") quant_config = StaticQuantConfig( weight_dtype="int8", weight_sym=True, @@ -199,7 +199,7 @@ def test_static_quant_from_dict_advance(self): from neural_compressor.tensorflow import quantize_model calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("./baseline_model") + fp32_model = keras.models.load_model("baseline_model.keras") quant_config = { "static_quant": { "global": { @@ -245,7 +245,7 @@ def test_static_quant_from_class_advance(self): ) quant_config.set_local("dense", dense_config) # get model and quantize - fp32_model = keras.models.load_model("./baseline_model") + fp32_model = keras.models.load_model("baseline_model.keras") qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) diff --git a/test/3x/tensorflow/quantization/test_smooth_quant.py b/test/3x/tensorflow/quantization/test_smooth_quant.py index 8aac76addbc..ee8f5407d3a 100644 --- a/test/3x/tensorflow/quantization/test_smooth_quant.py +++ b/test/3x/tensorflow/quantization/test_smooth_quant.py @@ -20,18 +20,18 @@ def build_conv_graph(): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - normed = tf.compat.v1.layers.batch_normalization(conv) conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") - normed2 = tf.compat.v1.layers.batch_normalization(conv2) - add = tf.raw_ops.Add(x=normed, y=normed2, name="addv2") + + add = tf.raw_ops.Add(x=conv, y=conv2, name="addv2") relu = tf.nn.relu(add) relu6 = tf.nn.relu6(relu, name="op_to_store") out_name = relu6.name.split(":")[0] + with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) output_graph_def = graph_util.convert_variables_to_constants( @@ -120,11 +120,26 @@ def test_sq_from_dict_beginner(self): self.assertEqual(mul_count, 2) def test_sq_completed_workflow(self): + x_data = np.random.rand(1024, 1024).astype(np.float32) + y_data = np.random.rand(1024, 1024).astype(np.float32) + import tensorflow.compat.v1 as tf + + with tf.Session(graph=tf.Graph()) as sess: + x = tf.placeholder(tf.float32, shape=[1024, 1024], name="x") + y = tf.constant(y_data, dtype=tf.float32, shape=[1024, 1024]) + z = tf.matmul(x, y) + bias = np.random.rand(1024).astype(np.float32) + z = tf.nn.bias_add(z, bias) + z = tf.nn.relu(z, name="op_to_store") + sess.run(z, feed_dict={x: x_data, y: y_data}) + output_graph_def = sess.graph.as_graph_def() + + 
set_random_seed(9527) sq_config = SmoothQuantConfig(alpha=0.5) static_config = StaticQuantConfig() - dataset = DummyDataset(shape=(100, 56, 56, 16), label=True) - calib_dataloader = MyDataLoader(dataset=dataset, batch_size=1) - q_model = quantize_model(self.conv_graph, [sq_config, static_config], calib_dataloader, calib_iteration=500) + dataset = DummyDataset(shape=(1024, 1024), label=True) + calib_dataloader = MyDataLoader(dataset=dataset, batch_size=1024) + q_model = quantize_model(output_graph_def, [sq_config, static_config], calib_dataloader, calib_iteration=500) mul_count = 0 quantized = False @@ -134,7 +149,7 @@ def test_sq_completed_workflow(self): if "quantize" in i.op: quantized = True - self.assertEqual(mul_count, 2) + self.assertEqual(mul_count, 1) self.assertEqual(quantized, True) @disable_random() @@ -143,14 +158,13 @@ def test_matmul(self): y_data = np.random.rand(1024, 1024).astype(np.float32) import tensorflow.compat.v1 as tf - x = tf.placeholder(tf.float32, shape=[1024, 1024], name="x") - y = tf.constant(y_data, dtype=tf.float32, shape=[1024, 1024]) - z = tf.matmul(x, y) - bias = np.random.rand(1024).astype(np.float32) - z = tf.nn.bias_add(z, bias) - z = tf.nn.relu(z, name="op_to_store") - - with tf.Session() as sess: + with tf.Session(graph=tf.Graph()) as sess: + x = tf.placeholder(tf.float32, shape=[1024, 1024], name="x") + y = tf.constant(y_data, dtype=tf.float32, shape=[1024, 1024]) + z = tf.matmul(x, y) + bias = np.random.rand(1024).astype(np.float32) + z = tf.nn.bias_add(z, bias) + z = tf.nn.relu(z, name="op_to_store") sess.run(z, feed_dict={x: x_data, y: y_data}) output_graph_def = sess.graph.as_graph_def() @@ -190,6 +204,7 @@ def test_conv_matmul(self): leaky_relu = tf.nn.leaky_relu(conv2, name="op_to_store") out_name = leaky_relu.name.split(":")[0] + with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) output_graph_def = graph_util.convert_variables_to_constants( diff --git a/test/3x/tensorflow/test_autotune.py b/test/3x/tensorflow/test_autotune.py index 9c89f8cd5fc..d5f83e85c7d 100644 --- a/test/3x/tensorflow/test_autotune.py +++ b/test/3x/tensorflow/test_autotune.py @@ -59,7 +59,7 @@ def build_model(): _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) print("Baseline test accuracy:", baseline_model_accuracy) - model.save("baseline_model") + tf.saved_model.save(model, "baseline_model") class Dataset(object):
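
# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# Every quantized layer touched above follows the same "fake quantization"
# pattern: the float32 kernel is quantized to int8 and immediately dequantized,
# so the conv/matmul still runs in float32 but sees int8-rounded weights, while
# Keras 3 (TF >= 2.16.1) additionally expects custom layers to implement
# compute_output_shape() so functional models can be rebuilt. QDenseSketch and
# its min_value/max_value arguments are hypothetical stand-ins, not classes or
# parameters from neural_compressor.
import tensorflow as tf


class QDenseSketch(tf.keras.layers.Dense):
    """Dense layer whose kernel is passed through a per-tensor int8 Q/DQ pair."""

    def __init__(self, units, min_value, max_value, **kwargs):
        super().__init__(units, **kwargs)
        self.min_value = min_value  # calibrated minimum of the kernel
        self.max_value = max_value  # calibrated maximum of the kernel

    def call(self, inputs):
        # Quantize the kernel to int8 in SCALED mode, then dequantize it back
        # to float32; the dequantized kernel is what the matmul consumes.
        q_kernel, _, _ = tf.quantization.quantize(
            self.kernel, self.min_value, self.max_value, tf.qint8, mode="SCALED"
        )
        dq_kernel = tf.quantization.dequantize(
            q_kernel, self.min_value, self.max_value, mode="SCALED"
        )
        outputs = tf.matmul(inputs, dq_kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        # Dense only replaces the last dimension; the Q/DQ round trip itself is
        # shape-preserving, which is why the quantizer layers in the patch
        # simply return their input shape.
        return tuple(input_shape)[:-1] + (self.units,)


# Usage: the layer consumes and produces float32 tensors; only the weights are
# snapped onto the int8 grid defined by [min_value, max_value].
layer = QDenseSketch(4, min_value=-0.5, max_value=0.5)
_ = layer(tf.random.normal([2, 8]))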
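
# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# The test updates reflect the Keras 3 saving rules that ship with TF 2.16.1:
# model.save() infers the format from the file name and expects a ".keras"
# suffix for the native format, while a TF SavedModel directory is produced
# with tf.saved_model.save(), which is what test_autotune.py switches to.
# The toy model and paths below are hypothetical examples.
import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(4)(inputs)
model = tf.keras.Model(inputs, outputs)

model.save("baseline_model.keras")  # native Keras format, a single archive file
restored = tf.keras.models.load_model("baseline_model.keras")

tf.saved_model.save(model, "baseline_model_savedmodel")  # SavedModel directory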