diff --git a/keras/backend/mxnet_backend.py b/keras/backend/mxnet_backend.py index 088f5558123f..1a20f5a4533b 100644 --- a/keras/backend/mxnet_backend.py +++ b/keras/backend/mxnet_backend.py @@ -3010,8 +3010,55 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1), # Raises ValueError: if `data_format` is neither `channels_last` or `channels_first`. + + # Example (detailed example refer to mnist_denoising_autoencoder.py): + ``` + >>>from keras.models import Sequential + >>>from keras.layers import Conv2DTranspose + >>>model = Sequential() + >>>model.add(Conv2DTranspose(32, (3, 3), activation='relu', + >>> input_shape=(100, 100, 3))) + >>>model.summary() + _________________________________________________________________ + Layer (type) Output Shape Param # + ================================================================= + conv2d_transpose_2 (Conv2DTr (None, 32, 102, 5) 28832 + ================================================================= + ``` """ - raise NotImplementedError('MXNet Backend: conv2d_transpose operator is not supported yet.') + if data_format is None: + data_format = image_data_format() + _validate_data_format(data_format) + + if padding not in {'same', 'valid'}: + raise ValueError('`padding` should be either `same` or `valid`.') + + # Handle Data Format + x = _preprocess_convnd_input(x, data_format) + kernel = _preprocess_convnd_kernel(kernel) + + # We have already converted kernel to match MXNet required shape: + # (depth, input_depth, rows, cols) + kernel_shape = kernel.shape + layout_kernel = tuple(kernel_shape[2:]) + nb_filter = kernel_shape[1] + + # Handle output shape to suit mxnet input format + if data_format == 'channels_first': + output_shape = output_shape[2:] + else: + output_shape = output_shape[1:-1] + + # Performance transpose convolution + deconv = mx.sym.Deconvolution(data=x.symbol, name=kernel.name, + kernel=layout_kernel, stride=strides, + num_filter=nb_filter, weight=kernel.symbol, + no_bias=True, 
@keras_mxnet_symbol
def _preprocess_convnd_transpose_output(output_shape, data_format):
    """Slice a transposed-convolution `output_shape` according to `data_format`.

    The same two slices are applied inline by `conv2d_transpose` before the
    shape is handed to MXNet as `target_shape`.

    # Arguments
        output_shape: Sliceable sequence holding the full output shape.
            Presumably `(batch, rows, cols, channels)` for `channels_last`
            and `(batch, channels, rows, cols)` for `channels_first`, so
            that only the spatial extents remain - TODO confirm with callers.
        data_format: Either `'channels_last'` or `'channels_first'`.

    # Returns
        `output_shape[1:-1]` when `data_format` is `channels_last`,
        `output_shape[2:]` when it is `channels_first`.

    # Raises
        ValueError: if `data_format` is neither `channels_last` or
            `channels_first`.
    """
    # NOTE(review): the decorator is kept from the original definition. This
    # helper only slices a shape, so confirm the decorator is required here.
    if data_format == 'channels_last':
        return output_shape[1:-1]
    if data_format == 'channels_first':
        return output_shape[2:]
    # Previously an unrecognised format fell through and returned the unsliced
    # shape; fail loudly instead so a bad value cannot be used as a target
    # shape further down.
    raise ValueError('Unknown data_format: ' + str(data_format))
self._weights_dirty = False # set module for prediction only diff --git a/tests/keras/engine/test_topology.py b/tests/keras/engine/test_topology.py index 4102fda641dc..c62889f473bf 100644 --- a/tests/keras/engine/test_topology.py +++ b/tests/keras/engine/test_topology.py @@ -618,8 +618,6 @@ def test_preprocess_weights_for_loading(layer): for (x, y) in zip(weights1, weights2)]) -@pytest.mark.skipif(K.backend() == 'mxnet', - reason='MXNet backend does not support Conv2D Transpose') @keras_test @pytest.mark.parametrize("layer", [ layers.Conv2D(2, (3, 3), input_shape=[5, 5, 3]), diff --git a/tests/keras/layers/convolutional_test.py b/tests/keras/layers/convolutional_test.py index 47e6bd9824eb..e357cfd20136 100644 --- a/tests/keras/layers/convolutional_test.py +++ b/tests/keras/layers/convolutional_test.py @@ -197,8 +197,6 @@ def test_convolution_2d(): batch_input_shape=(None, None, 5, None))]) -@pytest.mark.skipif((K.backend() == 'mxnet'), - reason='MXNet backend does not support conv2d_transpose yet.') @keras_test def test_conv2d_transpose(): num_samples = 2