Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Implement sparse categorical crossentropy, enable unit tests #145

Merged
merged 10 commits into from
Jul 29, 2018
125 changes: 101 additions & 24 deletions keras/backend/mxnet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,12 +1614,20 @@ def equal(x, y):
if isinstance(y, KerasSymbol):
y = y.symbol
scalar = True
if scalar:
return KerasSymbol(mx.sym.Cast(x == y, dtype='uint8'))
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_equal(lhs=x, rhs=y), dtype='uint8'))
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_equal(lhs=x, rhs=y), dtype='uint8'))
elif scalar:
# directly use '==' operator for element-wise comparison
out = KerasSymbol(mx.sym.Cast(x == y, dtype='uint8'))
else:
raise TypeError('MXNet Backend: The inputs are not valid for equal operation.')
# use numpy if x and y are all numbers or numpy arrays
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if x and y*

# raise type error when inputs are invalid
try:
out = np.equal(x, y)
except:
raise TypeError('MXNet Backend: The inputs are not valid for equal operation.')
return out


@keras_mxnet_symbol
Expand All @@ -1640,12 +1648,20 @@ def not_equal(x, y):
if isinstance(y, KerasSymbol):
y = y.symbol
scalar = True
if scalar:
return KerasSymbol(mx.sym.Cast(x != y, dtype='uint8'))
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_not_equal(lhs=x, rhs=y), dtype='uint8'))
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_not_equal(lhs=x, rhs=y), dtype='uint8'))
elif scalar:
# directly use '!=' operator for element-wise comparison
out = KerasSymbol(mx.sym.Cast(x != y, dtype='uint8'))
else:
raise TypeError('MXNet Backend: The inputs are not valid for not_equal operation.')
# use numpy if x and y are all numbers or numpy arrays
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if x and y*

# raise type error when inputs are invalid
try:
out = np.not_equal(x, y)
except:
raise TypeError('MXNet Backend: The inputs are not valid for not_equal operation.')
return out


@keras_mxnet_symbol
Expand All @@ -1666,12 +1682,20 @@ def greater(x, y):
if isinstance(y, KerasSymbol):
y = y.symbol
scalar = True
if scalar:
return KerasSymbol(mx.sym.Cast(x > y, dtype='uint8'))
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater(lhs=x, rhs=y), dtype='uint8'))
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually it is best practice not to have multiple returns in the same function (they are hard to maintain and debug). Can you restructure? The same applies in other places.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please follow docstring conventions as mentioned here - https://www.python.org/dev/peps/pep-0257/#one-line-docstrings

Also the comments are not necessary per operator, you can consider adding this blurb about implementation of operators at the beginning of the file depending on data type

out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater(lhs=x, rhs=y), dtype='uint8'))
elif scalar:
# directly use '>' operator for element-wise comparison
out = KerasSymbol(mx.sym.Cast(x > y, dtype='uint8'))
else:
raise TypeError('MXNet Backend: The inputs are not valid for greater operation.')
# use numpy if x and y are all numbers or numpy arrays
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if x and y*

# raise type error when inputs are invalid
try:
out = np.greater(x, y)
except:
raise TypeError('MXNet Backend: The inputs are not valid for greater operation.')
return out


@keras_mxnet_symbol
Expand All @@ -1692,12 +1716,20 @@ def greater_equal(x, y):
if isinstance(y, KerasSymbol):
y = y.symbol
scalar = True
if scalar:
return KerasSymbol(mx.sym.Cast(x >= y, dtype='uint8'))
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater_equal(lhs=x, rhs=y), dtype='uint8'))
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater_equal(lhs=x, rhs=y), dtype='uint8'))
elif scalar:
# directly use '>=' operator for element-wise comparison
out = KerasSymbol(mx.sym.Cast(x >= y, dtype='uint8'))
else:
raise TypeError('MXNet Backend: The inputs are not valid for greater_equal operation.')
# use numpy if x and y are all numbers or numpy arrays
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if x & y*

# raise type error when inputs are invalid
try:
out = np.greater_equal(x, y)
except:
raise TypeError('MXNet Backend: The inputs are not valid for greater_equal operation.')
return out


@keras_mxnet_symbol
Expand All @@ -1718,9 +1750,20 @@ def less(x, y):
if isinstance(y, KerasSymbol):
y = y.symbol
scalar = True
if scalar:
return KerasSymbol(mx.sym.Cast(x < y, dtype='uint8'))
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser(lhs=x, rhs=y), dtype='uint8'))
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser(lhs=x, rhs=y), dtype='uint8'))
elif scalar:
# directly use '<' operator for element-wise comparison
out = KerasSymbol(mx.sym.Cast(x < y, dtype='uint8'))
else:
# use numpy if x and y are all numbers or numpy arrays
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if x & y*

# raise type error when inputs are invalid
try:
out = np.less(x, y)
except:
raise TypeError('MXNet Backend: The inputs are not valid for less operation.')
return out


@keras_mxnet_symbol
Expand All @@ -1741,9 +1784,20 @@ def less_equal(x, y):
if isinstance(y, KerasSymbol):
y = y.symbol
scalar = True
if scalar:
return KerasSymbol(mx.sym.Cast(x <= y, dtype='uint8'))
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser_equal(lhs=x, rhs=y), dtype='uint8'))
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol):
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser_equal(lhs=x, rhs=y), dtype='uint8'))
elif scalar:
# directly use '<=' operator for element-wise comparison
out = KerasSymbol(mx.sym.Cast(x <= y, dtype='uint8'))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[minor] check style

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[adjust spaces around '=']

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kalyc usually there are no spaces around '=' for optional parameter values; refer to any of the other operators.

else:
# use numpy if x and y are all numbers or numpy arrays
# raise type error when inputs are invalid
try:
out = np.less(x, y)
except:
raise TypeError('MXNet Backend: The inputs are not valid for less_equal operation.')
return out


@keras_mxnet_symbol
Expand Down Expand Up @@ -2864,7 +2918,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
# Returns
Output tensor.
"""
output_dimensions = list(range(len(int_shape(output))))
output_dimensions = list(range(ndim(output)))
if axis != -1 and axis not in output_dimensions:
raise ValueError(
'{}{}{}'.format(
Expand All @@ -2888,6 +2942,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
return KerasSymbol(mx_output)


@keras_mxnet_symbol
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy with integer targets.

Expand All @@ -2902,7 +2957,29 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
# Returns
Output tensor.
"""
raise NotImplementedError('MXNet Backend: Sparse operations are not supported yet.')
output_dimensions = list(range(ndim(output)))
if axis != -1 and axis not in output_dimensions:
raise ValueError(
'{}{}{}'.format(
'Unexpected channels axis {}. '.format(axis),
'Expected to be -1 or one of the axes of `output`, ',
'which has {} dimensions.'.format(len(int_shape(output)))))

mx_output = output.symbol
# scale predictions so that the class probabilities of each sample sum to 1
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment above about writing the docstring - https://www.python.org/dev/peps/pep-0257/#one-line-docstrings

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kalyc compared to the element-wise comparison operators (e.g. K.equal), sparse_categorical_crossentropy and categorical_crossentropy are more complicated, and each backend requires different processing logic. It's important to keep the inline comments. Placing these comments in the docstring would confuse users. See tensorflow_backend.py for reference.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation!
[minor] change probas to probabilities

if from_logits:
mx_output = mx.sym.softmax(mx_output, axis=axis)
else:
mx_output = mx.sym.broadcast_div(mx_output, mx.sym.sum(mx_output,
axis=axis,
keepdims=True))
# clip to prevent NaN's and Inf's
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'infinities'

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'nan' and 'inf' are python values

mx_output = mx.sym.clip(mx_output, a_min=epsilon(), a_max=1.0 - epsilon())
# For this operation, the probability of a given label is considered exclusive.
mx_output = mx.sym.pick(mx_output, target.symbol, axis=axis, keepdims=True)
mx_output = - mx.sym.log(mx_output, axis=axis)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this do? Why are you making this value negative? Why not use mx.sym here instead of -?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# reshape to input's shape
return reshape(KerasSymbol(mx_output), target.shape)


@keras_mxnet_symbol
Expand Down
8 changes: 2 additions & 6 deletions tests/keras/engine/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,12 +1400,8 @@ def prepare_simple_model(input_tensor, loss_name, target):
simple_model.compile(optimizer='rmsprop', loss=loss)
return simple_model

# MXNet backend does not support Sparse Categorical Crossentropy yet.
if K.backend() == 'mxnet':
losses_to_test = ['categorical_crossentropy', 'binary_crossentropy']
else:
losses_to_test = ['sparse_categorical_crossentropy',
'categorical_crossentropy', 'binary_crossentropy']
losses_to_test = ['sparse_categorical_crossentropy',
'categorical_crossentropy', 'binary_crossentropy']

data_channels_first = np.array([[[[8., 7.1, 0.], [4.5, 2.6, 0.55],
[0.9, 4.2, 11.2]]]], dtype=np.float32)
Expand Down
6 changes: 0 additions & 6 deletions tests/keras/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def test_objective_shapes_2d():
assert K.eval(objective_output).shape == (6,)


@pytest.mark.skipif(K.backend() == 'mxnet',
reason='MXNet backend does not support `sparse` yet.')
def test_cce_one_hot():
y_a = K.variable(np.random.randint(0, 7, (5, 6)))
y_b = K.variable(np.random.random((5, 6, 7)))
Expand All @@ -82,8 +80,6 @@ def test_categorical_hinge():
assert np.isclose(expected_loss, np.mean(loss))


@pytest.mark.skipif(K.backend() == 'mxnet',
reason='MXNet backend does not support `sparse` yet.')
def test_sparse_categorical_crossentropy():
y_pred = K.variable(np.array([[0.3, 0.6, 0.1],
[0.1, 0.2, 0.7]]))
Expand All @@ -93,8 +89,6 @@ def test_sparse_categorical_crossentropy():
assert np.isclose(expected_loss, np.mean(loss))


@pytest.mark.skipif(K.backend() == 'mxnet',
reason='MXNet backend does not support `sparse` yet.')
def test_sparse_categorical_crossentropy_4d():
y_pred = K.variable(np.array([[[[0.7, 0.1, 0.2],
[0.0, 0.3, 0.7],
Expand Down
2 changes: 0 additions & 2 deletions tests/keras/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def test_metrics():
assert K.eval(output).shape == (6,)


@pytest.mark.skipif(K.backend() == 'mxnet',
reason='MXNet backend does not support `sparse` yet.')
def test_sparse_metrics():
for metric in all_sparse_metrics:
y_a = K.variable(np.random.randint(0, 7, (6,)), dtype=K.floatx())
Expand Down