Skip to content

Commit

Permalink
Fix Keras3 Issues in TF 2.16.1 for 3.0 new API (#1669)
Browse files Browse the repository at this point in the history
Signed-off-by: zehao-intel <[email protected]>
  • Loading branch information
zehao-intel authored Mar 13, 2024
1 parent 62aa85d commit 047560f
Show file tree
Hide file tree
Showing 8 changed files with 435 additions and 223 deletions.
6 changes: 5 additions & 1 deletion neural_compressor/tensorflow/keras/layers/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@

from neural_compressor.tensorflow.utils import version1_gte_version2

if version1_gte_version2(tf.__version__, "2.13.0"):
if version1_gte_version2(tf.__version__, "2.16.1"):
from keras.src.layers.convolutional.base_conv import BaseConv # pylint: disable=E0401

Conv = BaseConv
elif version1_gte_version2(tf.__version__, "2.13.0"):
from keras.src.layers.convolutional.base_conv import Conv # pylint: disable=E0401
else:
from keras.layers.convolutional.base_conv import Conv # pylint: disable=E0401
Expand Down
297 changes: 191 additions & 106 deletions neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,117 +23,202 @@

from neural_compressor.tensorflow.utils import version1_gte_version2

if version1_gte_version2(tf.__version__, "2.13.0"):
if version1_gte_version2(tf.__version__, "2.16.1"):
from keras.src import ops
from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv # pylint: disable=E0401
elif version1_gte_version2(tf.__version__, "2.13.0"):
from keras.src.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401
from keras.src.utils import conv_utils, tf_utils # pylint: disable=E0401
else:
from keras.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401
from keras.utils import conv_utils, tf_utils # pylint: disable=E0401

if version1_gte_version2(tf.__version__, "2.16.1"):

class QDepthwiseConv2D(DepthwiseConv):
def __init__(
self,
kernel_size,
min_value,
max_value,
strides=(1, 1),
padding="valid",
depth_multiplier=1,
data_format=None,
dilation_rate=(1, 1),
activation=None,
use_bias=True,
depthwise_initializer="glorot_uniform",
bias_initializer="zeros",
depthwise_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
depthwise_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
2,
kernel_size=kernel_size,
strides=strides,
padding=padding,
depth_multiplier=depth_multiplier,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
depthwise_initializer=depthwise_initializer,
bias_initializer=bias_initializer,
depthwise_regularizer=depthwise_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
depthwise_constraint=depthwise_constraint,
bias_constraint=bias_constraint,
class QDepthwiseConv2D(BaseDepthwiseConv):
def __init__(
self,
kernel_size,
min_value,
max_value,
strides=(1, 1),
padding="valid",
depth_multiplier=1,
data_format=None,
dilation_rate=(1, 1),
activation=None,
use_bias=True,
depthwise_initializer="glorot_uniform",
bias_initializer="zeros",
depthwise_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
depthwise_constraint=None,
bias_constraint=None,
**kwargs
)
self.min_value = json.loads(min_value)
self.max_value = json.loads(max_value)

def call(self, inputs):
# add the Q/DQ here
kernel, _, _ = quantization.quantize(
self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
)
kernel = quantization.dequantize(
kernel,
self.min_value,
self.max_value,
axis=3,
mode="SCALED",
)
outputs = tf.keras.backend.depthwise_conv2d(
inputs,
kernel,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)

if self.use_bias:
outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format)

if self.activation is not None:
return self.activation(outputs)

return outputs

@classmethod
def from_config(cls, config):
return cls(**config)

@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self.data_format == "channels_first":
rows = input_shape[2]
cols = input_shape[3]
out_filters = input_shape[1] * self.depth_multiplier
elif self.data_format == "channels_last":
rows = input_shape[1]
cols = input_shape[2]
out_filters = input_shape[3] * self.depth_multiplier

rows = conv_utils.conv_output_length(
rows,
self.kernel_size[0],
self.padding,
self.strides[0],
self.dilation_rate[0],
)
cols = conv_utils.conv_output_length(
cols,
self.kernel_size[1],
self.padding,
self.strides[1],
self.dilation_rate[1],
)
if self.data_format == "channels_first":
return (input_shape[0], out_filters, rows, cols)
elif self.data_format == "channels_last":
return (input_shape[0], rows, cols, out_filters)
):
super().__init__(
2,
kernel_size=kernel_size,
strides=strides,
padding=padding,
depth_multiplier=depth_multiplier,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
depthwise_initializer=depthwise_initializer,
bias_initializer=bias_initializer,
depthwise_regularizer=depthwise_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
depthwise_constraint=depthwise_constraint,
bias_constraint=bias_constraint,
**kwargs
)
self.min_value = json.loads(min_value)
self.max_value = json.loads(max_value)

def call(self, inputs):
# add the Q/DQ here
kernel, _, _ = quantization.quantize(
self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
)
kernel = quantization.dequantize(
kernel,
self.min_value,
self.max_value,
axis=3,
mode="SCALED",
)

input_channel = self._get_input_channel(inputs.shape)
outputs = ops.depthwise_conv(
inputs,
self.kernel,
strides=self.strides,
padding=self.padding,
dilation_rate=self.dilation_rate,
data_format=self.data_format,
)

if self.use_bias:
if self.data_format == "channels_last":
bias_shape = (1,) * (self.rank + 1) + (self.depth_multiplier * input_channel,)
else:
bias_shape = (1, self.depth_multiplier * input_channel) + (1,) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias

if self.activation is not None:
return self.activation(outputs)
return outputs

else:

class QDepthwiseConv2D(DepthwiseConv):
def __init__(
self,
kernel_size,
min_value,
max_value,
strides=(1, 1),
padding="valid",
depth_multiplier=1,
data_format=None,
dilation_rate=(1, 1),
activation=None,
use_bias=True,
depthwise_initializer="glorot_uniform",
bias_initializer="zeros",
depthwise_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
depthwise_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
2,
kernel_size=kernel_size,
strides=strides,
padding=padding,
depth_multiplier=depth_multiplier,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
depthwise_initializer=depthwise_initializer,
bias_initializer=bias_initializer,
depthwise_regularizer=depthwise_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
depthwise_constraint=depthwise_constraint,
bias_constraint=bias_constraint,
**kwargs
)
self.min_value = json.loads(min_value)
self.max_value = json.loads(max_value)

def call(self, inputs):
# add the Q/DQ here
kernel, _, _ = quantization.quantize(
self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
)
kernel = quantization.dequantize(
kernel,
self.min_value,
self.max_value,
axis=3,
mode="SCALED",
)
outputs = tf.keras.backend.depthwise_conv2d(
inputs,
kernel,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)

if self.use_bias:
outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format)

if self.activation is not None:
return self.activation(outputs)

return outputs

@classmethod
def from_config(cls, config):
return cls(**config)

@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self.data_format == "channels_first":
rows = input_shape[2]
cols = input_shape[3]
out_filters = input_shape[1] * self.depth_multiplier
elif self.data_format == "channels_last":
rows = input_shape[1]
cols = input_shape[2]
out_filters = input_shape[3] * self.depth_multiplier

rows = conv_utils.conv_output_length(
rows,
self.kernel_size[0],
self.padding,
self.strides[0],
self.dilation_rate[0],
)
cols = conv_utils.conv_output_length(
cols,
self.kernel_size[1],
self.padding,
self.strides[1],
self.dilation_rate[1],
)
if self.data_format == "channels_first":
return (input_shape[0], out_filters, rows, cols)
elif self.data_format == "channels_last":
return (input_shape[0], rows, cols, out_filters)
12 changes: 12 additions & 0 deletions neural_compressor/tensorflow/keras/layers/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def call(self, inputs):
self.max_value = tf.math.reduce_max(inputs, axis=self.axis)
return inputs

def compute_output_shape(self, input_shape):
input_shape = tf.TensorShape(input_shape).as_list()
return input_shape

@classmethod
def from_config(cls, config):
return cls(**config)
Expand Down Expand Up @@ -87,6 +91,10 @@ def call(self, inputs):
)
return outputs

def compute_output_shape(self, input_shape):
input_shape = tf.TensorShape(input_shape).as_list()
return input_shape

def get_config(self):
return {
"min_range": self.min_range,
Expand Down Expand Up @@ -122,6 +130,10 @@ def call(self, inputs):
axis=self.axis,
)

def compute_output_shape(self, input_shape):
input_shape = tf.TensorShape(input_shape).as_list()
return input_shape

def get_config(self):
return {
"min_range": self.min_range,
Expand Down
Loading

0 comments on commit 047560f

Please sign in to comment.