diff --git a/keras/backend/mxnet_backend.py b/keras/backend/mxnet_backend.py
index 07a692c58e33..df1664ae4169 100644
--- a/keras/backend/mxnet_backend.py
+++ b/keras/backend/mxnet_backend.py
@@ -1600,12 +1600,20 @@ def clip(x, min_value, max_value):
 def equal(x, y):
     """Element-wise equality between two tensors.
 
+    For all element-wise comparison operators:
+    use broadcasting to do the element-wise comparison if both x and y are MXNet symbols,
+    use native comparison operators when comparing a symbol with a scalar,
+    use numpy operators if both x and y are numbers or numpy arrays.
+
     # Arguments
         x: Tensor or variable.
         y: Tensor or variable.
 
     # Returns
         A bool tensor.
+
+    # Raises
+        TypeError: if inputs are not valid.
     """
     scalar = False
     if isinstance(x, KerasSymbol):
@@ -1614,12 +1622,16 @@ def equal(x, y):
     if isinstance(y, KerasSymbol):
         y = y.symbol
         scalar = True
-    if scalar:
-        return KerasSymbol(mx.sym.Cast(x == y, dtype='uint8'))
     if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
-        return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_equal(lhs=x, rhs=y), dtype='uint8'))
+        out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_equal(lhs=x, rhs=y), dtype='uint8'))
+    elif scalar:
+        out = KerasSymbol(mx.sym.Cast(x == y, dtype='uint8'))
     else:
-        raise TypeError('MXNet Backend: The inputs are not valid for equal operation.')
+        try:
+            out = np.equal(x, y)
+        except:
+            raise TypeError('MXNet Backend: The inputs are not valid for equal operation.')
+    return out
 
 
 @keras_mxnet_symbol
@@ -1632,6 +1644,9 @@ def not_equal(x, y):
 
     # Returns
         A bool tensor.
+
+    # Raises
+        TypeError: if inputs are not valid.
     """
     scalar = False
     if isinstance(x, KerasSymbol):
@@ -1640,12 +1655,16 @@ def not_equal(x, y):
     if isinstance(y, KerasSymbol):
         y = y.symbol
         scalar = True
-    if scalar:
-        return KerasSymbol(mx.sym.Cast(x != y, dtype='uint8'))
     if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
-        return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_not_equal(lhs=x, rhs=y), dtype='uint8'))
+        out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_not_equal(lhs=x, rhs=y), dtype='uint8'))
+    elif scalar:
+        out = KerasSymbol(mx.sym.Cast(x != y, dtype='uint8'))
     else:
-        raise TypeError('MXNet Backend: The inputs are not valid for not_equal operation.')
+        try:
+            out = np.not_equal(x, y)
+        except:
+            raise TypeError('MXNet Backend: The inputs are not valid for not_equal operation.')
+    return out
 
 
 @keras_mxnet_symbol
@@ -1658,6 +1677,9 @@ def greater(x, y):
 
     # Returns
         A bool tensor.
+
+    # Raises
+        TypeError: if inputs are not valid.
     """
     scalar = False
     if isinstance(x, KerasSymbol):
@@ -1666,12 +1688,16 @@ def greater(x, y):
     if isinstance(y, KerasSymbol):
         y = y.symbol
         scalar = True
-    if scalar:
-        return KerasSymbol(mx.sym.Cast(x > y, dtype='uint8'))
     if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
-        return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater(lhs=x, rhs=y), dtype='uint8'))
+        out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater(lhs=x, rhs=y), dtype='uint8'))
+    elif scalar:
+        out = KerasSymbol(mx.sym.Cast(x > y, dtype='uint8'))
     else:
-        raise TypeError('MXNet Backend: The inputs are not valid for greater operation.')
+        try:
+            out = np.greater(x, y)
+        except:
+            raise TypeError('MXNet Backend: The inputs are not valid for greater operation.')
+    return out
 
 
 @keras_mxnet_symbol
@@ -1684,6 +1710,9 @@ def greater_equal(x, y):
 
     # Returns
         A bool tensor.
+
+    # Raises
+        TypeError: if inputs are not valid.
""" scalar = False if isinstance(x, KerasSymbol): @@ -1692,12 +1721,16 @@ def greater_equal(x, y): if isinstance(y, KerasSymbol): y = y.symbol scalar = True - if scalar: - return KerasSymbol(mx.sym.Cast(x >= y, dtype='uint8')) if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): - return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater_equal(lhs=x, rhs=y), dtype='uint8')) + out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater_equal(lhs=x, rhs=y), dtype='uint8')) + elif scalar: + out = KerasSymbol(mx.sym.Cast(x >= y, dtype='uint8')) else: - raise TypeError('MXNet Backend: The inputs are not valid for greater_equal operation.') + try: + out = np.greater_equal(x, y) + except: + raise TypeError('MXNet Backend: The inputs are not valid for greater_equal operation.') + return out @keras_mxnet_symbol @@ -1710,6 +1743,9 @@ def less(x, y): # Returns A bool tensor. + + # Raise + TypeError: if inputs are not valid. """ scalar = False if isinstance(x, KerasSymbol): @@ -1718,9 +1754,16 @@ def less(x, y): if isinstance(y, KerasSymbol): y = y.symbol scalar = True - if scalar: - return KerasSymbol(mx.sym.Cast(x < y, dtype='uint8')) - return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser(lhs=x, rhs=y), dtype='uint8')) + if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): + out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser(lhs=x, rhs=y), dtype='uint8')) + elif scalar: + out = KerasSymbol(mx.sym.Cast(x < y, dtype='uint8')) + else: + try: + out = np.less(x, y) + except: + raise TypeError('MXNet Backend: The inputs are not valid for less operation.') + return out @keras_mxnet_symbol @@ -1733,6 +1776,9 @@ def less_equal(x, y): # Returns A bool tensor. + + # Raise + TypeError: if inputs are not valid. """ scalar = False if isinstance(x, KerasSymbol): @@ -1741,9 +1787,16 @@ def less_equal(x, y): if isinstance(y, KerasSymbol): y = y.symbol scalar = True - if scalar: - return KerasSymbol(mx.sym.Cast(x <= y, dtype='uint8')) - return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser_equal(lhs=x, rhs=y), dtype='uint8')) + if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): + out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser_equal(lhs=x, rhs=y), dtype='uint8')) + elif scalar: + out = KerasSymbol(mx.sym.Cast(x <= y, dtype='uint8')) + else: + try: + out = np.less(x, y) + except: + raise TypeError('MXNet Backend: The inputs are not valid for less_equal operation.') + return out @keras_mxnet_symbol @@ -2864,7 +2917,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): # Returns Output tensor. """ - output_dimensions = list(range(len(int_shape(output)))) + output_dimensions = list(range(ndim(output))) if axis != -1 and axis not in output_dimensions: raise ValueError( '{}{}{}'.format( @@ -2873,7 +2926,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): 'which has {} dimensions.'.format(len(int_shape(output))))) mx_output = output.symbol - # scale predictions so that the class probas of each sample sum to 1 + # scale predictions so that the class probabilities of each sample sum to 1 if from_logits: mx_output = mx.sym.softmax(mx_output, axis=axis) else: @@ -2888,6 +2941,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): return KerasSymbol(mx_output) +@keras_mxnet_symbol def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): """Categorical crossentropy with integer targets. 
@@ -2902,7 +2956,29 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
     # Returns
         Output tensor.
     """
-    raise NotImplementedError('MXNet Backend: Sparse operations are not supported yet.')
+    output_dimensions = list(range(ndim(output)))
+    if axis != -1 and axis not in output_dimensions:
+        raise ValueError(
+            '{}{}{}'.format(
+                'Unexpected channels axis {}. '.format(axis),
+                'Expected to be -1 or one of the axes of `output`, ',
+                'which has {} dimensions.'.format(len(int_shape(output)))))
+
+    mx_output = output.symbol
+    # scale predictions so that the class probabilities of each sample sum to 1
+    if from_logits:
+        mx_output = mx.sym.softmax(mx_output, axis=axis)
+    else:
+        mx_output = mx.sym.broadcast_div(mx_output, mx.sym.sum(mx_output,
+                                                               axis=axis,
+                                                               keepdims=True))
+    # clip to prevent NaN's and Inf's
+    mx_output = mx.sym.clip(mx_output, a_min=epsilon(), a_max=1.0 - epsilon())
+    # pick the probability assigned to each sample's true class (integer labels are mutually exclusive)
+    mx_output = mx.sym.pick(mx_output, target.symbol, axis=axis, keepdims=True)
+    mx_output = - mx.sym.log(mx_output, axis=axis)
+    # reshape to input's shape
+    return reshape(KerasSymbol(mx_output), target.shape)
 
 
 @keras_mxnet_symbol
diff --git a/tests/keras/engine/test_training.py b/tests/keras/engine/test_training.py
index f3e7515ccc4e..a44d1bee14a2 100644
--- a/tests/keras/engine/test_training.py
+++ b/tests/keras/engine/test_training.py
@@ -1400,12 +1400,8 @@ def prepare_simple_model(input_tensor, loss_name, target):
         simple_model.compile(optimizer='rmsprop', loss=loss)
         return simple_model
 
-    # MXNet backend does not support Sparse Categorical Crossentropy yet.
-    if K.backend() == 'mxnet':
-        losses_to_test = ['categorical_crossentropy', 'binary_crossentropy']
-    else:
-        losses_to_test = ['sparse_categorical_crossentropy',
-                          'categorical_crossentropy', 'binary_crossentropy']
+    losses_to_test = ['sparse_categorical_crossentropy',
+                      'categorical_crossentropy', 'binary_crossentropy']
     data_channels_first = np.array([[[[8., 7.1, 0.],
                                       [4.5, 2.6, 0.55],
                                       [0.9, 4.2, 11.2]]]], dtype=np.float32)
diff --git a/tests/keras/losses_test.py b/tests/keras/losses_test.py
index 01623ce597f3..046698041190 100644
--- a/tests/keras/losses_test.py
+++ b/tests/keras/losses_test.py
@@ -59,8 +59,6 @@ def test_objective_shapes_2d():
     assert K.eval(objective_output).shape == (6,)
 
 
-@pytest.mark.skipif(K.backend() == 'mxnet',
-                    reason='MXNet backend does not support `sparse` yet.')
 def test_cce_one_hot():
     y_a = K.variable(np.random.randint(0, 7, (5, 6)))
     y_b = K.variable(np.random.random((5, 6, 7)))
@@ -82,8 +80,6 @@ def test_categorical_hinge():
     assert np.isclose(expected_loss, np.mean(loss))
 
 
-@pytest.mark.skipif(K.backend() == 'mxnet',
-                    reason='MXNet backend does not support `sparse` yet.')
 def test_sparse_categorical_crossentropy():
     y_pred = K.variable(np.array([[0.3, 0.6, 0.1],
                                   [0.1, 0.2, 0.7]]))
@@ -93,8 +89,6 @@ def test_sparse_categorical_crossentropy():
     assert np.isclose(expected_loss, np.mean(loss))
 
 
-@pytest.mark.skipif(K.backend() == 'mxnet',
-                    reason='MXNet backend does not support `sparse` yet.')
 def test_sparse_categorical_crossentropy_4d():
     y_pred = K.variable(np.array([[[[0.7, 0.1, 0.2],
                                     [0.0, 0.3, 0.7],
diff --git a/tests/keras/metrics_test.py b/tests/keras/metrics_test.py
index b2e8f48fbe18..de955877f486 100644
--- a/tests/keras/metrics_test.py
+++ b/tests/keras/metrics_test.py
@@ -57,8 +57,6 @@ def test_metrics():
     assert K.eval(output).shape == (6,)
 
 
-@pytest.mark.skipif(K.backend() == 'mxnet',
-                    reason='MXNet backend does not support `sparse` yet.')
 def test_sparse_metrics():
     for metric in all_sparse_metrics:
         y_a = K.variable(np.random.randint(0, 7, (6,)), dtype=K.floatx())
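
A minimal usage sketch of the rewritten comparison helpers, assuming the MXNet backend is active. It is not part of the patch; the input values are illustrative, and the expected outputs follow from the dispatch described in the new docstring (broadcast comparison for two symbols, native operator for symbol vs. scalar, numpy for plain numbers or arrays).

import numpy as np
from keras import backend as K

a = K.variable(np.array([[1., 2.], [3., 4.]]))
b = K.variable(np.array([[1., 0.], [3., 5.]]))

# both inputs are Keras/MXNet symbols -> broadcast_equal path, uint8 result
print(K.eval(K.equal(a, b)))        # expected [[1 0] [1 0]]

# symbol vs. scalar -> native comparison on the underlying symbol
print(K.eval(K.greater(a, 2.)))     # expected [[0 0] [1 1]]

# plain numpy inputs -> numpy path, returns a numpy bool array
print(K.less_equal(np.array([1, 2, 3]), 2))   # expected [ True  True False]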
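
For reference, a small numpy sketch of the arithmetic the new sparse_categorical_crossentropy hunk performs symbolically: normalize, clip, pick the probability of the integer target (the role of mx.sym.pick), then negate the log. The helper name sparse_cce_reference is illustrative and not part of the backend; the y_pred values mirror the 2D case in losses_test.py, while the targets here are chosen for the example.

import numpy as np

def sparse_cce_reference(target, output, eps=1e-7):
    # output: class probabilities, shape (batch, classes); target: integer class ids, shape (batch,)
    probs = output / output.sum(axis=-1, keepdims=True)   # class probabilities of each sample sum to 1
    probs = np.clip(probs, eps, 1.0 - eps)                 # avoid log(0), like mx.sym.clip with epsilon()
    picked = probs[np.arange(len(target)), target]         # same role as mx.sym.pick along the class axis
    return -np.log(picked)

y_true = np.array([1, 2])
y_pred = np.array([[0.3, 0.6, 0.1],
                   [0.1, 0.2, 0.7]])
print(sparse_cce_reference(y_true, y_pred))   # approx. [0.5108 0.3567], i.e. -log(0.6) and -log(0.7)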