-
Notifications
You must be signed in to change notification settings - Fork 65
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution, see comments inline
keras/backend/mxnet_backend.py
Outdated
symbols = [t.symbol for t in tensors] | ||
|
||
if py_all([is_sparse(t) for t in tensors]): | ||
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should still check if axis !=0, even if all tensors are sparse, if user want to concat on axis other than 0, we use dense concat
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we have hard-coded the dimension - why do we need to check the axis!=0 here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed hard-coded value & added unit test
keras/backend/mxnet_backend.py
Outdated
if py_all([is_sparse(t) for t in tensors]): | ||
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=0)) | ||
|
||
return KerasSymbol(mx.sym.concat(*symbols, dim=axis)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If part of tensors are sparse, so it will fail the check at py_all([is_sparse(t) for t in tensors]):
, what will happen if you pass them to dense concat?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As mentioned in the doc - it will convert the concat output to be dense. To retrieve concat output as sparse all input tensors should be sparse. Added a unit test to check the same
@@ -104,6 +104,35 @@ def test_sparse_dot(self): | |||
assert k_s.shape == k_d.shape | |||
assert_allclose(k_s, k_d, atol=1e-05) | |||
|
|||
def test_sparse_concat(self): | |||
x_d = np.array([0, 7, 2, 3], dtype=np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please use sparse data generator
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -17,6 +17,8 @@ | |||
_REENTRY = False | |||
NAME_SCOPE_STACK = [] | |||
|
|||
py_all = all |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to define py_all here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
following the convention in tensorflow_backend
file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. LGTM. Few minor comments.
return KerasSymbol(mx.sym.concat(*tensors, dim=axis)) | ||
symbols = [t.symbol for t in tensors] | ||
|
||
if axis == 0 and py_all([is_sparse(t) for t in tensors]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it is preferrable to avoid multiple returns in a func.
if axis == 0 and py_all([is_sparse(t) for t in tensors]): | ||
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=axis)) | ||
|
||
return KerasSymbol(mx.sym.concat(*symbols, dim=axis)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should show warning if axis !=0 and is_sparse(t) for t in tensors? I am assuming you do dense concat automatically if axis !=0. This has performance implication and important for users to know.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is replicating the behavior for sparse concat in mxnet - http://mxnet.apache.org/api/python/symbol/symbol.html#mxnet.symbol.concat - do we still want to add a warning there?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think yes. Users should know that tensor is being converted to dense.
Summary
Add sparse support for
concat
operatorRelated Issues
Missing sparse operators
PR Overview