-
Notifications
You must be signed in to change notification settings - Fork 65
Add sparse concat operator support #167
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ | |
_REENTRY = False | ||
NAME_SCOPE_STACK = [] | ||
|
||
py_all = all | ||
|
||
|
||
class name_scope(object): | ||
def __init__(self, name): | ||
|
@@ -2030,6 +2032,10 @@ def concatenate(tensors, axis=-1): | |
|
||
# Returns | ||
A tensor. | ||
|
||
Note: | ||
- MXNet supports sparse concat only for dim=0 | ||
- https://mxnet.apache.org/api/python/symbol/sparse.html#mxnet.symbol.sparse.concat | ||
""" | ||
if axis < 0: | ||
rank = ndim(tensors[0]) | ||
|
@@ -2038,8 +2044,12 @@ def concatenate(tensors, axis=-1): | |
else: | ||
axis = 0 | ||
|
||
tensors = [t.symbol for t in tensors] | ||
return KerasSymbol(mx.sym.concat(*tensors, dim=axis)) | ||
symbols = [t.symbol for t in tensors] | ||
|
||
if py_all([is_sparse(t) for t in tensors]): | ||
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=0)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You should still check if axis != 0, even if all tensors are sparse; if the user wants to concat on an axis other than 0, we use dense concat. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We have hard-coded the dimension - why do we need to check axis != 0 here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Removed the hard-coded value & added a unit test. |
||
|
||
return KerasSymbol(mx.sym.concat(*symbols, dim=axis)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If only part of the tensors are sparse, it will fail the check at There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As mentioned in the doc - it will convert the concat output to be dense. To retrieve the concat output as sparse, all input tensors should be sparse. Added a unit test to check the same. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should show a warning if axis != 0 and is_sparse(t) for t in tensors? I am assuming you do dense concat automatically if axis != 0. This has performance implications and is important for users to know. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is replicating the behavior for sparse concat in mxnet - http://mxnet.apache.org/api/python/symbol/symbol.html#mxnet.symbol.concat - do we still want to add a warning there? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think yes. Users should know that the tensor is being converted to dense. |
||
|
||
|
||
@keras_mxnet_symbol | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,6 +104,35 @@ def test_sparse_dot(self): | |
assert k_s.shape == k_d.shape | ||
assert_allclose(k_s, k_d, atol=1e-05) | ||
|
||
def test_sparse_concat(self): | ||
x_d = np.array([0, 7, 2, 3], dtype=np.float32) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please use the sparse data generator. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
x_r = np.array([0, 2, 2, 3], dtype=np.int64) | ||
x_c = np.array([4, 3, 2, 3], dtype=np.int64) | ||
|
||
x_sparse_1 = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5)) | ||
|
||
x_d = np.array([0, 7, 2, 3], dtype=np.float32) | ||
x_r = np.array([0, 2, 2, 3], dtype=np.int64) | ||
x_c = np.array([4, 3, 2, 3], dtype=np.int64) | ||
|
||
x_sparse_2 = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5)) | ||
|
||
assert K.is_sparse(K.variable(x_sparse_1)) | ||
assert K.is_sparse(K.variable(x_sparse_2)) | ||
x_dense_1 = x_sparse_1.toarray() | ||
x_dense_2 = x_sparse_2.toarray() | ||
|
||
k_s = K.concatenate(tensors=[K.variable(x_sparse_1), K.variable(x_sparse_2)]) | ||
assert K.is_sparse(k_s) | ||
|
||
k_s_d = K.eval(k_s) | ||
|
||
# mx.sym.sparse.concat only supported for axis=0 | ||
k_d = K.eval(K.concatenate(tensors=[K.variable(x_dense_1), K.variable(x_dense_2)], axis=0)) | ||
|
||
assert k_s_d.shape == k_d.shape | ||
assert_allclose(k_s_d, k_d, atol=1e-05) | ||
|
||
|
||
if __name__ == '__main__': | ||
pytest.main([__file__]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to define py_all here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Following the convention in the
tensorflow_backend
file.