Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Conv2d transpose (#47)
Browse files Browse the repository at this point in the history
* add conv2d_transpose

* conv2d transpose for both channels, enabled test case

* add detailed comments and examples, fix style issue

* enable test case in topology
  • Loading branch information
roywei authored and sandeep-krishnamurthy committed Mar 19, 2018
1 parent c97096e commit c3c450f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 10 deletions.
66 changes: 60 additions & 6 deletions keras/backend/mxnet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3010,8 +3010,55 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
# Raises
ValueError: if `data_format` is neither `channels_last` or `channels_first`.
# Example (detailed example refer to mnist_denoising_autoencoder.py):
```
>>>from keras.models import Sequential
>>>from keras.layers import Conv2DTranspose
>>>model = Sequential()
>>>model.add(Conv2DTranspose(32, (3, 3), activation='relu',
>>> input_shape=(100, 100, 3)))
>>>model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_transpose_2 (Conv2DTr (None, 32, 102, 5) 28832
=================================================================
```
"""
raise NotImplementedError('MXNet Backend: conv2d_transpose operator is not supported yet.')
if data_format is None:
data_format = image_data_format()
_validate_data_format(data_format)

if padding not in {'same', 'valid'}:
raise ValueError('`padding` should be either `same` or `valid`.')

# Handle Data Format
x = _preprocess_convnd_input(x, data_format)
kernel = _preprocess_convnd_kernel(kernel)

# We have already converted kernel to match MXNet required shape:
# (depth, input_depth, rows, cols)
kernel_shape = kernel.shape
layout_kernel = tuple(kernel_shape[2:])
nb_filter = kernel_shape[1]

# Handle output shape to suit mxnet input format
if data_format == 'channels_first':
output_shape = output_shape[2:]
else:
output_shape = output_shape[1:-1]

# Performance transpose convolution
deconv = mx.sym.Deconvolution(data=x.symbol, name=kernel.name,
kernel=layout_kernel, stride=strides,
num_filter=nb_filter, weight=kernel.symbol,
no_bias=True, target_shape=output_shape)

# Handle original Data Format
result = _postprocess_convnd_output(KerasSymbol(deconv, is_var=True),
data_format)
return result


def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
Expand Down Expand Up @@ -3986,6 +4033,14 @@ def _preprocess_convnd_kernel(kernel):
return kernel


@keras_mxnet_symbol
def _preprocess_convnd_transpose_output(output_shape, data_format):
if data_format == 'channels_last':
output_shape = output_shape[1:-1]
elif data_format == 'channels_first':
output_shape = output_shape[2:]
return output_shape

def _validate_conv_input_shape(input_shape):
# MXNet convolution operator cannot automatically infer shape.
# Feature requirement -
Expand Down Expand Up @@ -4172,8 +4227,8 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None,

trainable_weights = set([x.name for x in self.trainable_weights])
self._fixed_weights = [x for x in self._arg_names if x not in trainable_weights]
self._args = {x: bind_values[x] for x in self._arg_names}
self._auxs = {x: bind_values[x] for x in self._aux_names}
self._args = {x: bind_values[x] for x in self._arg_names if x in bind_values}
self._auxs = {x: bind_values[x] for x in self._aux_names if x in bind_values}
self._weights_dirty = False

# set the module
Expand Down Expand Up @@ -4323,7 +4378,6 @@ def _create_predict_module(self):
self._data_names = [x.name for x in self.inputs if x]

state_updates = [x[1] for x in self.state_updates]

# set for prediction
self._npred = len(self.outputs)
pred_keras_symbol = group(
Expand All @@ -4341,8 +4395,8 @@ def _create_predict_module(self):

trainable_weights = set([x.name for x in self.trainable_weights])
self._fixed_weights = [x for x in self._arg_names if x not in trainable_weights]
self._args = {x: bind_values[x] for x in self._arg_names}
self._auxs = {x: bind_values[x] for x in self._aux_names}
self._args = {x: bind_values[x] for x in self._arg_names if x in bind_values}
self._auxs = {x: bind_values[x] for x in self._aux_names if x in bind_values}
self._weights_dirty = False

# set module for prediction only
Expand Down
2 changes: 0 additions & 2 deletions tests/keras/engine/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,6 @@ def test_preprocess_weights_for_loading(layer):
for (x, y) in zip(weights1, weights2)])


@pytest.mark.skipif(K.backend() == 'mxnet',
reason='MXNet backend does not support Conv2D Transpose')
@keras_test
@pytest.mark.parametrize("layer", [
layers.Conv2D(2, (3, 3), input_shape=[5, 5, 3]),
Expand Down
2 changes: 0 additions & 2 deletions tests/keras/layers/convolutional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,6 @@ def test_convolution_2d():
batch_input_shape=(None, None, 5, None))])


@pytest.mark.skipif((K.backend() == 'mxnet'),
reason='MXNet backend does not support conv2d_transpose yet.')
@keras_test
def test_conv2d_transpose():
num_samples = 2
Expand Down

0 comments on commit c3c450f

Please sign in to comment.