Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Add sparse concat operator support #167

Merged
merged 3 commits into from
Sep 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions keras/backend/mxnet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
_REENTRY = False
NAME_SCOPE_STACK = []

py_all = all
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to define py_all here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

following the convention in tensorflow_backend file



class name_scope(object):
def __init__(self, name):
Expand Down Expand Up @@ -2030,6 +2032,10 @@ def concatenate(tensors, axis=-1):

# Returns
A tensor.

Note:
- MXNet supports sparse concat only for dim=0
- https://mxnet.apache.org/api/python/symbol/sparse.html#mxnet.symbol.sparse.concat
"""
if axis < 0:
rank = ndim(tensors[0])
Expand All @@ -2038,8 +2044,12 @@ def concatenate(tensors, axis=-1):
else:
axis = 0

tensors = [t.symbol for t in tensors]
return KerasSymbol(mx.sym.concat(*tensors, dim=axis))
symbols = [t.symbol for t in tensors]

if axis == 0 and py_all([is_sparse(t) for t in tensors]):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it is preferrable to avoid multiple returns in a func.

return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=axis))

return KerasSymbol(mx.sym.concat(*symbols, dim=axis))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should show warning if axis !=0 and is_sparse(t) for t in tensors? I am assuming you do dense concat automatically if axis !=0. This has performance implication and important for users to know.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is replicating the behavior for sparse concat in mxnet - http://mxnet.apache.org/api/python/symbol/symbol.html#mxnet.symbol.concat - do we still want to add a warning there?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think yes. Users should know that tensor is being converted to dense.



@keras_mxnet_symbol
Expand Down
58 changes: 58 additions & 0 deletions tests/keras/backend/mxnet_sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,64 @@ def test_sparse_dot(self):
assert k_s.shape == k_d.shape
assert_allclose(k_s, k_d, atol=1e-05)

def test_sparse_concat(self):
x_sparse_1 = self.generate_test_sparse_matrix()
x_sparse_2 = self.generate_test_sparse_matrix()

assert K.is_sparse(K.variable(x_sparse_1))
assert K.is_sparse(K.variable(x_sparse_2))
x_dense_1 = x_sparse_1.toarray()
x_dense_2 = x_sparse_2.toarray()

k_s = K.concatenate(tensors=[K.variable(x_sparse_1), K.variable(x_sparse_2)], axis=0)
assert K.is_sparse(k_s)

k_s_d = K.eval(k_s)

# mx.sym.sparse.concat only supported for axis=0
k_d = K.eval(K.concatenate(tensors=[K.variable(x_dense_1), K.variable(x_dense_2)], axis=0))

assert k_s_d.shape == k_d.shape
assert_allclose(k_s_d, k_d, atol=1e-05)

def test_sparse_concat_partial_dense(self):
x_sparse_1 = self.generate_test_sparse_matrix()
x_sparse_2 = self.generate_test_sparse_matrix()

assert K.is_sparse(K.variable(x_sparse_1))
x_dense_1 = x_sparse_1.toarray()
x_dense_2 = x_sparse_2.toarray()

k_s = K.concatenate(tensors=[K.variable(x_sparse_1), K.variable(x_dense_2)], axis=0)
assert not(K.is_sparse(k_s))

k_s_d = K.eval(k_s)

# mx.sym.sparse.concat only supported for axis=0
k_d = K.eval(K.concatenate(tensors=[K.variable(x_dense_1), K.variable(x_dense_2)], axis=0))

assert k_s_d.shape == k_d.shape
assert_allclose(k_s_d, k_d, atol=1e-05)

def test_sparse_concat_axis_non_zero(self):
x_sparse_1 = self.generate_test_sparse_matrix()
x_sparse_2 = self.generate_test_sparse_matrix()

assert K.is_sparse(K.variable(x_sparse_1))
x_dense_1 = x_sparse_1.toarray()
x_dense_2 = x_sparse_2.toarray()

k_s = K.concatenate(tensors=[K.variable(x_sparse_1), K.variable(x_dense_2)])
assert not (K.is_sparse(k_s))

k_s_d = K.eval(k_s)

# mx.sym.sparse.concat only supported for axis=0
k_d = K.eval(K.concatenate(tensors=[K.variable(x_dense_1), K.variable(x_dense_2)]))

assert k_s_d.shape == k_d.shape
assert_allclose(k_s_d, k_d, atol=1e-05)


if __name__ == '__main__':
pytest.main([__file__])