-
Notifications
You must be signed in to change notification settings - Fork 65
implement sparse categorical crossentropy, enable unitests #145
Changes from 8 commits
386b663
92d03f2
9ee1801
2c1f0f4
037424c
2e46670
ddf4769
1938a08
84628e8
991ab74
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1614,12 +1614,20 @@ def equal(x, y): | |
if isinstance(y, KerasSymbol): | ||
y = y.symbol | ||
scalar = True | ||
if scalar: | ||
return KerasSymbol(mx.sym.Cast(x == y, dtype='uint8')) | ||
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): | ||
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_equal(lhs=x, rhs=y), dtype='uint8')) | ||
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol | ||
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_equal(lhs=x, rhs=y), dtype='uint8')) | ||
elif scalar: | ||
# directly use '==' operator for element-wise comparison | ||
out = KerasSymbol(mx.sym.Cast(x == y, dtype='uint8')) | ||
else: | ||
raise TypeError('MXNet Backend: The inputs are not valid for equal operation.') | ||
# use numpy if x and x are all numbers or numpy arrays | ||
# raise type error when inputs are invalid | ||
try: | ||
out = np.equal(x, y) | ||
except: | ||
raise TypeError('MXNet Backend: The inputs are not valid for equal operation.') | ||
return out | ||
|
||
|
||
@keras_mxnet_symbol | ||
|
@@ -1640,12 +1648,20 @@ def not_equal(x, y): | |
if isinstance(y, KerasSymbol): | ||
y = y.symbol | ||
scalar = True | ||
if scalar: | ||
return KerasSymbol(mx.sym.Cast(x != y, dtype='uint8')) | ||
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): | ||
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_not_equal(lhs=x, rhs=y), dtype='uint8')) | ||
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol | ||
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_not_equal(lhs=x, rhs=y), dtype='uint8')) | ||
elif scalar: | ||
# directly use '!=' operator for element-wise comparison | ||
out = KerasSymbol(mx.sym.Cast(x != y, dtype='uint8')) | ||
else: | ||
raise TypeError('MXNet Backend: The inputs are not valid for not_equal operation.') | ||
# use numpy if x and x are all numbers or numpy arrays | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if x and y* |
||
# raise type error when inputs are invalid | ||
try: | ||
out = np.not_equal(x, y) | ||
except: | ||
raise TypeError('MXNet Backend: The inputs are not valid for not_equal operation.') | ||
return out | ||
|
||
|
||
@keras_mxnet_symbol | ||
|
@@ -1666,12 +1682,20 @@ def greater(x, y): | |
if isinstance(y, KerasSymbol): | ||
y = y.symbol | ||
scalar = True | ||
if scalar: | ||
return KerasSymbol(mx.sym.Cast(x > y, dtype='uint8')) | ||
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): | ||
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater(lhs=x, rhs=y), dtype='uint8')) | ||
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Usually it is best practice not to have multiple returns in the same function. (hard to maintain and debug). Can you restructure? Same in other places. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please follow docstring conventions as mentioned here - https://www.python.org/dev/peps/pep-0257/#one-line-docstrings Also the comments are not necessary per operator, you can consider adding this blurb about implementation of operators at the beginning of the file depending on data type |
||
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater(lhs=x, rhs=y), dtype='uint8')) | ||
elif scalar: | ||
# directly use '>' operator for element-wise comparison | ||
out = KerasSymbol(mx.sym.Cast(x > y, dtype='uint8')) | ||
else: | ||
raise TypeError('MXNet Backend: The inputs are not valid for greater operation.') | ||
# use numpy if x and x are all numbers or numpy arrays | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if x and y* |
||
# raise type error when inputs are invalid | ||
try: | ||
out = np.greater(x, y) | ||
except: | ||
raise TypeError('MXNet Backend: The inputs are not valid for greater operation.') | ||
return out | ||
|
||
|
||
@keras_mxnet_symbol | ||
|
@@ -1692,12 +1716,20 @@ def greater_equal(x, y): | |
if isinstance(y, KerasSymbol): | ||
y = y.symbol | ||
scalar = True | ||
if scalar: | ||
return KerasSymbol(mx.sym.Cast(x >= y, dtype='uint8')) | ||
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): | ||
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater_equal(lhs=x, rhs=y), dtype='uint8')) | ||
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol | ||
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_greater_equal(lhs=x, rhs=y), dtype='uint8')) | ||
elif scalar: | ||
# directly use '>=' operator for element-wise comparison | ||
out = KerasSymbol(mx.sym.Cast(x >= y, dtype='uint8')) | ||
else: | ||
raise TypeError('MXNet Backend: The inputs are not valid for greater_equal operation.') | ||
# use numpy if x and x are all numbers or numpy arrays | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if x & y* |
||
# raise type error when inputs are invalid | ||
try: | ||
out = np.greater_equal(x, y) | ||
except: | ||
raise TypeError('MXNet Backend: The inputs are not valid for greater_equal operation.') | ||
return out | ||
|
||
|
||
@keras_mxnet_symbol | ||
|
@@ -1718,9 +1750,20 @@ def less(x, y): | |
if isinstance(y, KerasSymbol): | ||
y = y.symbol | ||
scalar = True | ||
if scalar: | ||
return KerasSymbol(mx.sym.Cast(x < y, dtype='uint8')) | ||
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser(lhs=x, rhs=y), dtype='uint8')) | ||
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): | ||
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol | ||
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser(lhs=x, rhs=y), dtype='uint8')) | ||
elif scalar: | ||
# directly use '<' operator for element-wise comparison | ||
out = KerasSymbol(mx.sym.Cast(x < y, dtype='uint8')) | ||
else: | ||
# use numpy if x and x are all numbers or numpy arrays | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if x & y* |
||
# raise type error when inputs are invalid | ||
try: | ||
out = np.less(x, y) | ||
except: | ||
raise TypeError('MXNet Backend: The inputs are not valid for less operation.') | ||
return out | ||
|
||
|
||
@keras_mxnet_symbol | ||
|
@@ -1741,9 +1784,20 @@ def less_equal(x, y): | |
if isinstance(y, KerasSymbol): | ||
y = y.symbol | ||
scalar = True | ||
if scalar: | ||
return KerasSymbol(mx.sym.Cast(x <= y, dtype='uint8')) | ||
return KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser_equal(lhs=x, rhs=y), dtype='uint8')) | ||
if isinstance(x, mx.sym.Symbol) and isinstance(y, mx.sym.Symbol): | ||
# use broadcasting to do element-wise comparison if both x and y are mxnet symbol | ||
out = KerasSymbol(mx.sym.Cast(mx.sym.broadcast_lesser_equal(lhs=x, rhs=y), dtype='uint8')) | ||
elif scalar: | ||
# directly use '<=' operator for element-wise comparison | ||
out = KerasSymbol(mx.sym.Cast(x <= y, dtype='uint8')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [minor] check style There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [adjust spaces around '='] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kalyc usually there is no space for optional params values, refer to any other operators. |
||
else: | ||
# use numpy if x and x are all numbers or numpy arrays | ||
# raise type error when inputs are invalid | ||
try: | ||
out = np.less(x, y) | ||
except: | ||
raise TypeError('MXNet Backend: The inputs are not valid for less_equal operation.') | ||
return out | ||
|
||
|
||
@keras_mxnet_symbol | ||
|
@@ -2864,7 +2918,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): | |
# Returns | ||
Output tensor. | ||
""" | ||
output_dimensions = list(range(len(int_shape(output)))) | ||
output_dimensions = list(range(ndim(output))) | ||
if axis != -1 and axis not in output_dimensions: | ||
raise ValueError( | ||
'{}{}{}'.format( | ||
|
@@ -2888,6 +2942,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): | |
return KerasSymbol(mx_output) | ||
|
||
|
||
@keras_mxnet_symbol | ||
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): | ||
"""Categorical crossentropy with integer targets. | ||
|
||
|
@@ -2902,7 +2957,29 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): | |
# Returns | ||
Output tensor. | ||
""" | ||
raise NotImplementedError('MXNet Backend: Sparse operations are not supported yet.') | ||
output_dimensions = list(range(ndim(output))) | ||
if axis != -1 and axis not in output_dimensions: | ||
raise ValueError( | ||
'{}{}{}'.format( | ||
'Unexpected channels axis {}. '.format(axis), | ||
'Expected to be -1 or one of the axes of `output`, ', | ||
'which has {} dimensions.'.format(len(int_shape(output))))) | ||
|
||
mx_output = output.symbol | ||
# scale predictions so that the class probas of each sample sum to 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see comment above about writing the docstring - https://www.python.org/dev/peps/pep-0257/#one-line-docstrings There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kalyc compare to element-wise comparison operators (e.g. K.equal), sparse_categorical_crossentropy and categorical_crossentropy are more complicated, and it requires different backend to do different processing logic. It's important to place the inline comments. Placing these comments in doc string will confuse users. See tensorflow_backend.py for reference. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation! |
||
if from_logits: | ||
mx_output = mx.sym.softmax(mx_output, axis=axis) | ||
else: | ||
mx_output = mx.sym.broadcast_div(mx_output, mx.sym.sum(mx_output, | ||
axis=axis, | ||
keepdims=True)) | ||
# clip to prevent NaN's and Inf's | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 'infinities' There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 'nan' and 'inf' are python values |
||
mx_output = mx.sym.clip(mx_output, a_min=epsilon(), a_max=1.0 - epsilon()) | ||
# For this operation, the probability of a given label is considered exclusive. | ||
mx_output = mx.sym.pick(mx_output, target.symbol, axis=axis, keepdims=True) | ||
mx_output = - mx.sym.log(mx_output, axis=axis) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what does this do? Why are you making this value negative? Why not use mx.sym here instead of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please refer to loss function equations: http://ml-cheatsheet.readthedocs.io/en/latest/loss_functions.html#cross-entropy |
||
# reshape to input's shape | ||
return reshape(KerasSymbol(mx_output), target.shape) | ||
|
||
|
||
@keras_mxnet_symbol | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if x and y*