
Commit

Update qseparable_conv2d_transpose to have a single version for all platforms.

Upgrading the CPU version to be used everywhere.

PiperOrigin-RevId: 685794266
Change-Id: I70dd34ee602d862169dbce8e6b7d5ef3ae2ebb9a
Akshaya Purohit authored and copybara-github committed Oct 14, 2024
1 parent cbcc62e commit 1308e5d
Showing 3 changed files with 83 additions and 181 deletions.
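
For orientation, a minimal usage sketch of the unified layer, mirroring the tests added in this commit; the parameter values are illustrative, not prescriptive:

import tensorflow as tf
from qkeras import QSeparableConv2DTranspose, quantized_bits

# Build a small model around the unified layer (runs on TPU, GPU and CPU).
x = img_input = tf.keras.layers.Input(shape=(4, 4, 3))
x = QSeparableConv2DTranspose(
    filters=2,
    kernel_size=(2, 2),
    strides=(2, 2),
    padding="same",
    depthwise_activation=None,
    pointwise_activation=None,
    depthwise_kernel_quantizer=quantized_bits(1, 0, 1, alpha=1.0),
    pointwise_kernel_quantizer=quantized_bits(1, 0, 1, alpha=1.0),
    bias_quantizer=quantized_bits(2, 2, 1, alpha=1.0),
)(x)
model = tf.keras.Model(inputs=img_input, outputs=x)

# Depthwise transpose then pointwise: (1, 4, 4, 3) -> (1, 8, 8, 2).
print(model.predict(tf.zeros((1, 4, 4, 3))).shape)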
15 changes: 7 additions & 8 deletions qkeras/__init__.py
@@ -21,21 +21,20 @@

from .b2t import * # pylint: disable=wildcard-import
from .estimate import * # pylint: disable=wildcard-import
from .qlayers import * # pylint: disable=wildcard-import
from .quantizers import * # pylint: disable=wildcard-import
from .qconv2d_batchnorm import QConv2DBatchnorm
from .qconvolutional import * # pylint: disable=wildcard-import
from .qdepthwiseconv2d_batchnorm import QDepthwiseConv2DBatchnorm
from .qlayers import * # pylint: disable=wildcard-import
from .qmac import * # pylint: disable=wildcard-import
from .qrecurrent import * # pylint: disable=wildcard-import
from .qnormalization import * # pylint: disable=wildcard-import
from .qoctave import * # pylint: disable=wildcard-import
from .qpooling import * # pylint: disable=wildcard-import
from .safe_eval import * # pylint: disable=wildcard-import
from .qrecurrent import * # pylint: disable=wildcard-import
from .qseparable_conv2d_transpose import QSeparableConv2DTranspose
#from .qtools.run_qtools import QTools
#from .qtools.settings import cfg
from .qconv2d_batchnorm import QConv2DBatchnorm
from .qdepthwiseconv2d_batchnorm import QDepthwiseConv2DBatchnorm
from .qseparable_conv2d_transpose import QSeparableConv2DTransposeTPU
from .qseparable_conv2d_transpose import QSeparableConv2DTransposeCPU
from .quantizers import * # pylint: disable=wildcard-import
from .safe_eval import * # pylint: disable=wildcard-import


assert tf.executing_eagerly(), "QKeras requires TF with eager execution mode on"
151 changes: 39 additions & 112 deletions qkeras/qseparable_conv2d_transpose.py
@@ -27,8 +27,8 @@
from tensorflow.python.ops import array_ops


class QSeparableConv2DTransposeTPU(Conv2DTranspose):
"""Quantized Separable Conv2DTranspose layer for TPU and GPU."""
class QSeparableConv2DTranspose(Conv2DTranspose):
"""Quantized Separable Conv2DTranspose layer."""

# Most of these parameters follow the implementation of Conv2DTranspose
# in Keras, with the exception of the following parameters.
@@ -42,17 +42,6 @@ class QSeparableConv2DTransposeTPU(Conv2DTranspose):
# we refer the reader to the documentation of Conv2DTranspose in Keras for
# the other parameters.

# Important Notes:
# This implementation requires the use of grouped convolution, which is only
# supported in TPU/GPU, not in CPU.
# When running in CPU, it gives the following error:
# "Gradients for grouped convolutions are not supported on CPU.
# Please file a feature request if you run into this issue."
# For now we can train with this implementation in TPU/GPU,
# for inference in CPU, we will convert the layer to an equivalent
# QSeparableConv2DTransposeCPU layer, which is slow in training,
# but should suffice in inference.

def __init__(self,
filters,
kernel_size,
@@ -268,21 +257,48 @@ def conv_transpose_op(self, inputs, filters, strides, padding,
else:
quantized_kernel = kernel_weights

output_filters = 1 if is_depthwise else filters

if self.data_format == "channels_first":
output_shape = (batch_size, filters, out_height, out_width)
output_shape = (batch_size, output_filters, out_height, out_width)
else:
output_shape = (batch_size, out_height, out_width, filters)
output_shape = (batch_size, out_height, out_width, output_filters)

output_shape_tensor = array_ops.stack(output_shape)

outputs = tf.keras.backend.conv2d_transpose(
inputs,
quantized_kernel,
output_shape_tensor,
strides=strides,
padding=padding,
data_format=self.data_format,
dilation_rate=dilation_rate)
# Split the input channels into groups.
x = tf.split(inputs, self._input_shape[-1], axis=-1)

if is_depthwise:
# For depthwise convolution, since CPU doesn't support grouped
# convolution, we run convolution on each slice of inputs and concat
# the results.
outputs = [
tf.keras.backend.conv2d_transpose(
x=x[i],
kernel=quantized_kernel[:, :, :, i : i + 1],
output_shape=output_shape_tensor,
strides=strides,
padding=padding,
data_format=self.data_format,
dilation_rate=dilation_rate,
)
for i in range(len(x))
]

# Concat the channels.
outputs = tf.concat(outputs, axis=-1)

else:
outputs = tf.keras.backend.conv2d_transpose(
inputs,
quantized_kernel,
output_shape_tensor,
strides=strides,
padding=padding,
data_format=self.data_format,
dilation_rate=dilation_rate,
)

if not context.executing_eagerly():
# Infer the static output shape:
@@ -386,92 +402,3 @@ def get_prunable_weights(self):
w.append(self.bias)

return w


class QSeparableConv2DTransposeCPU(QSeparableConv2DTransposeTPU):
"""CPU version of Quantized Separable Conv2DTranspose layer.
Important Notes:
* This implementation can run on TPU, GPU and CPU. But the training speed can
be significantly slower than the TPU/GPU version.
* QSeparableConv2DTransposeCPU and QSeparableConv2DTransposeTPU layers have
the same shape on kernel and bias variables. With the same input and the same
weights, the outputs of the two layers are the same.
"""

def conv_transpose_op(self, inputs, filters, strides, padding,
output_padding, dilation_rate,
kernel_quantizer, kernel_weights, use_bias,
bias_quantizer, bias, activation, is_depthwise):
"""Transpose convolution op that shared by both depthwise and pointwise."""

batch_size, out_height, out_width, kernel_h, kernel_w = (
self._get_output_size(inputs, output_padding, padding, strides,
dilation_rate, kernel_weights))

if kernel_quantizer:
quantized_kernel = kernel_quantizer(kernel_weights)
else:
quantized_kernel = kernel_weights

output_filters = 1 if is_depthwise else filters

if self.data_format == "channels_first":
output_shape = (batch_size, output_filters, out_height, out_width)
else:
output_shape = (batch_size, out_height, out_width, output_filters)

output_shape_tensor = array_ops.stack(output_shape)

# Split the input channels into groups.
x = tf.split(inputs, self._input_shape[-1], axis=-1)

if is_depthwise:
# For depthwise convolution, since CPU doesn't support grouped
# convolution, we run convolution on each slice of inputs and concat
# the results.
outputs = [
tf.keras.backend.conv2d_transpose(
x=x[i],
kernel=quantized_kernel[:, :, :, i : i + 1],
output_shape=output_shape_tensor,
strides=strides,
padding=padding,
data_format=self.data_format,
dilation_rate=dilation_rate) for i in range(len(x))]

# Concat the channels.
outputs = tf.concat(outputs, axis=-1)

else:
outputs = tf.keras.backend.conv2d_transpose(
inputs,
quantized_kernel,
output_shape_tensor,
strides=strides,
padding=padding,
data_format=self.data_format,
dilation_rate=dilation_rate)

if not context.executing_eagerly():
# Infer the static output shape:
out_shape = self.compute_final_output_shape(
input_shape=inputs.shape,
kernel_size=(kernel_h, kernel_w),
strides=strides,
is_depthwise=is_depthwise)
outputs.set_shape(out_shape)

if use_bias:
quantized_bias = bias_quantizer(bias) if bias_quantizer else bias
outputs = tf.keras.backend.bias_add(
outputs,
quantized_bias,
data_format=self.data_format)

if activation is not None:
return activation(outputs)

return outputs
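
Not part of the commit, but useful context for the unified conv_transpose_op above: the depthwise stage is emulated by running one single-channel transposed convolution per input channel and concatenating the results, sidestepping the missing CPU gradient support for grouped convolutions. A standalone sketch of that per-channel trick, assuming channels_last data and a (kh, kw, 1, channels) kernel; the helper name is hypothetical:

import tensorflow as tf

def depthwise_conv2d_transpose(inputs, kernel, strides=(2, 2), padding="same"):
  # inputs: (batch, h, w, c); kernel: (kh, kw, 1, c), one filter per input channel.
  batch = tf.shape(inputs)[0]
  _, h, w, c = inputs.shape
  out_h, out_w = h * strides[0], w * strides[1]  # holds for "same" padding
  output_shape = tf.stack([batch, out_h, out_w, 1])

  # One single-channel transposed convolution per input channel.
  slices = tf.split(inputs, c, axis=-1)
  outputs = [
      tf.keras.backend.conv2d_transpose(
          slices[i], kernel[:, :, :, i:i + 1], output_shape,
          strides=strides, padding=padding)
      for i in range(c)
  ]
  # Concatenate the per-channel results back into a single tensor.
  return tf.concat(outputs, axis=-1)

# Example: upsample (1, 4, 4, 3) to (1, 8, 8, 3) with one 2x2 filter per channel.
y = depthwise_conv2d_transpose(tf.random.uniform((1, 4, 4, 3)),
                               tf.random.uniform((2, 2, 1, 3)))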
98 changes: 37 additions & 61 deletions tests/qseparable_conv2d_transpose_test.py
@@ -17,75 +17,53 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

import numpy as np
from numpy.testing import assert_allclose, assert_equal
import pytest
import tempfile
import os

import tensorflow as tf

from qkeras import QSeparableConv2DTranspose
from qkeras import quantized_bits
from qkeras import QSeparableConv2DTransposeTPU
from qkeras import QSeparableConv2DTransposeCPU


def create_model(for_tpu=True):
def create_model():
x = img_input = tf.keras.layers.Input(shape=(4, 4, 3))

if for_tpu:
x = QSeparableConv2DTransposeTPU(
filters=2, kernel_size=(2, 2),
strides=(2, 2),
padding="same", name="conv2d_tran",
depthwise_activation=None,
pointwise_activation=None,
depthwise_kernel_quantizer=None,
pointwise_kernel_quantizer=None,
bias_quantizer=None,
)(x)
else:
x = QSeparableConv2DTransposeCPU(
filters=2, kernel_size=(2, 2),
strides=(2, 2),
padding="same", name="conv2d_tran",
depthwise_activation=None,
pointwise_activation=None,
depthwise_kernel_quantizer=None,
pointwise_kernel_quantizer=None,
bias_quantizer=None,
)(x)
x = QSeparableConv2DTranspose(
filters=2,
kernel_size=(2, 2),
strides=(2, 2),
padding="same",
name="conv2d_tran",
depthwise_activation=None,
pointwise_activation=None,
depthwise_kernel_quantizer=None,
pointwise_kernel_quantizer=None,
bias_quantizer=None,
)(x)

model = tf.keras.Model(inputs=img_input, outputs=x)

return model


def create_quantized_model(for_tpu=True):
def create_quantized_model():
x = img_input = tf.keras.layers.Input(shape=(4, 4, 3))

if for_tpu:
x = QSeparableConv2DTransposeTPU(
filters=2, kernel_size=(2, 2),
strides=(1, 1),
padding="same", name="conv2d_tran",
depthwise_activation="quantized_bits(10, 6, 1)",
pointwise_activation="quantized_bits(5, 3, 1)",
depthwise_kernel_quantizer=quantized_bits(1, 0, 1, alpha=1.0),
pointwise_kernel_quantizer=quantized_bits(1, 0, 1, alpha=1.0),
bias_quantizer=quantized_bits(2, 2, 1, alpha=1.0)
)(x)
else:
x = QSeparableConv2DTransposeCPU(
filters=2, kernel_size=(2, 2),
strides=(1, 1),
padding="same", name="conv2d_tran",
depthwise_activation="quantized_bits(10, 6, 1)",
pointwise_activation="quantized_bits(5, 3, 1)",
depthwise_kernel_quantizer=quantized_bits(1, 0, 1, alpha=1.0),
pointwise_kernel_quantizer=quantized_bits(1, 0, 1, alpha=1.0),
bias_quantizer=quantized_bits(2, 2, 1, alpha=1.0)
)(x)
x = QSeparableConv2DTranspose(
filters=2,
kernel_size=(2, 2),
strides=(1, 1),
padding="same",
name="conv2d_tran",
depthwise_activation="quantized_bits(10, 6, 1)",
pointwise_activation="quantized_bits(5, 3, 1)",
depthwise_kernel_quantizer=quantized_bits(1, 0, 1, alpha=1.0),
pointwise_kernel_quantizer=quantized_bits(1, 0, 1, alpha=1.0),
bias_quantizer=quantized_bits(2, 2, 1, alpha=1.0),
)(x)

model = tf.keras.Model(inputs=img_input, outputs=x)

@@ -102,8 +80,8 @@ def test_qseparable_conv2d_transpose():
# mapped from input channel(3) to output channel (2) by pointwise conv.
# Pointwise conv output is (1, 8, 8, 2).

# Create model using CPU version: QSeparableConv2DTransposeCPU.
model = create_model(for_tpu=False)
# Create model.
model = create_model()

output_shape = model.output_shape
ws = model.layers[1].weights
@@ -161,9 +139,8 @@ def test_qseparable_conv2d_transpose():
def test_quantization_in_separable_conv2d_transpose():
# Test if quantization is applied correctly.

# Create model using CPU version: QSeparableConv2DTransposeCPU
# with quantization.
model = create_quantized_model(for_tpu=False)
# Create model with quantization.
model = create_quantized_model()

x = np.array([[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]])
inputs = np.concatenate([x, x, x], axis=-1)
@@ -201,14 +178,13 @@ def test_quantization_in_separable_conv2d_transpose():

def test_save_and_load_model():
# Test if the model can be loaded from a saved model.
model = create_quantized_model(for_tpu=True)
model = create_quantized_model()

fd, fname = tempfile.mkstemp(".hdf5")
model.save(fname)

custom_object = {
"QSeparableConv2DTransposeTPU": QSeparableConv2DTransposeTPU,
"QSeparableConv2DTransposeCPU": QSeparableConv2DTransposeCPU,
"QSeparableConv2DTranspose": QSeparableConv2DTranspose,
}

model_loaded = tf.keras.models.load_model(
