This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Add sparse concat operator support #167

Merged
merged 3 commits on Sep 11, 2018
Changes from 1 commit
14 changes: 12 additions & 2 deletions keras/backend/mxnet_backend.py
@@ -17,6 +17,8 @@
_REENTRY = False
NAME_SCOPE_STACK = []

py_all = all

why do we need to define py_all here?

Author

Following the convention in the tensorflow_backend file.
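For context, a minimal sketch of the aliasing pattern being followed (assuming, as in tensorflow_backend.py, that the module later defines a backend all() reduction op that shadows the Python builtin):

```python
# Keep a handle to the Python builtin before it is shadowed at module scope.
py_all = all

def all(x, axis=None, keepdims=False):
    # Backend reduction op: logical AND over an axis of a tensor.
    # (Signature shown for illustration; body omitted.)
    ...
```

Without the alias, list-level checks such as py_all([...]) inside the backend would otherwise call the shadowing tensor op instead of the builtin.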



class name_scope(object):
def __init__(self, name):
@@ -2030,6 +2032,10 @@ def concatenate(tensors, axis=-1):

# Returns
A tensor.

Note:
- MXNet supports sparse concat only for dim=0
- https://mxnet.apache.org/api/python/symbol/sparse.html#mxnet.symbol.sparse.concat
"""
if axis < 0:
rank = ndim(tensors[0])
@@ -2038,8 +2044,12 @@
else:
axis = 0

tensors = [t.symbol for t in tensors]
return KerasSymbol(mx.sym.concat(*tensors, dim=axis))
symbols = [t.symbol for t in tensors]

if py_all([is_sparse(t) for t in tensors]):
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=0))

You should still check whether axis != 0: even if all the tensors are sparse, if the user wants to concat on an axis other than 0 we should use the dense concat.

Author

We have hard-coded the dimension - why do we need to check axis != 0 here?

Author

Removed the hard-coded value and added a unit test.


return KerasSymbol(mx.sym.concat(*symbols, dim=axis))

If only some of the tensors are sparse, the py_all([is_sparse(t) for t in tensors]) check will fail - what will happen when they are passed to the dense concat?

Author

As mentioned in the doc, it will convert the concat output to dense. To get the concat output as sparse, all input tensors must be sparse. Added a unit test to check this.


We should show a warning if axis != 0 and the tensors are sparse. I am assuming you fall back to dense concat automatically when axis != 0. This has performance implications and is important for users to know.

Author

This replicates the behavior of sparse concat in MXNet - http://mxnet.apache.org/api/python/symbol/symbol.html#mxnet.symbol.concat - do we still want to add a warning there?


I think yes. Users should know that the tensor is being converted to dense.
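For illustration, a rough sketch of the behavior being suggested here - warn and fall back to dense concat whenever the sparse path cannot be used - written against the helpers shown in the diff (a sketch of the reviewer's suggestion, not the merged code):

```python
import warnings  # assumed available at module level

symbols = [t.symbol for t in tensors]
num_sparse = len([t for t in tensors if is_sparse(t)])

if num_sparse == len(tensors) and axis == 0:
    # MXNet supports sparse concat only along dim=0.
    return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=axis))

if num_sparse > 0:
    # Sparse inputs are about to go through the dense concat path, which has
    # memory/performance implications - surface that to the user.
    warnings.warn('Sparse tensors will be concatenated as dense tensors '
                  '(MXNet supports sparse concat only on axis 0).')
return KerasSymbol(mx.sym.concat(*symbols, dim=axis))
```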



@keras_mxnet_symbol
29 changes: 29 additions & 0 deletions tests/keras/backend/mxnet_sparse_test.py
@@ -104,6 +104,35 @@ def test_sparse_dot(self):
assert k_s.shape == k_d.shape
assert_allclose(k_s, k_d, atol=1e-05)

def test_sparse_concat(self):
x_d = np.array([0, 7, 2, 3], dtype=np.float32)

please use sparse data generator

Author

done

x_r = np.array([0, 2, 2, 3], dtype=np.int64)
x_c = np.array([4, 3, 2, 3], dtype=np.int64)

x_sparse_1 = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))

x_d = np.array([0, 7, 2, 3], dtype=np.float32)
x_r = np.array([0, 2, 2, 3], dtype=np.int64)
x_c = np.array([4, 3, 2, 3], dtype=np.int64)

x_sparse_2 = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))

assert K.is_sparse(K.variable(x_sparse_1))
assert K.is_sparse(K.variable(x_sparse_2))
x_dense_1 = x_sparse_1.toarray()
x_dense_2 = x_sparse_2.toarray()

k_s = K.concatenate(tensors=[K.variable(x_sparse_1), K.variable(x_sparse_2)])
assert K.is_sparse(k_s)

k_s_d = K.eval(k_s)

# mx.sym.sparse.concat only supported for axis=0
k_d = K.eval(K.concatenate(tensors=[K.variable(x_dense_1), K.variable(x_dense_2)], axis=0))

assert k_s_d.shape == k_d.shape
assert_allclose(k_s_d, k_d, atol=1e-05)


if __name__ == '__main__':
pytest.main([__file__])
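A note on the "please use sparse data generator" comment above: the generator adopted in the follow-up commit is not visible in this one-commit view, but such a helper could look roughly like the following (hypothetical name and signature, assuming scipy.sparse):

```python
import numpy as np
from scipy import sparse

def generate_sparse_matrix(shape=(4, 5), density=0.2, dtype=np.float32, seed=None):
    """Hypothetical test helper: return a random CSR matrix together with its
    dense copy, so sparse and dense backend code paths can be compared."""
    rng = np.random.RandomState(seed)
    x_sparse = sparse.random(shape[0], shape[1], density=density,
                             format='csr', dtype=dtype, random_state=rng)
    return x_sparse, x_sparse.toarray()

# Usage in a test:
# x_sparse_1, x_dense_1 = generate_sparse_matrix()
# x_sparse_2, x_dense_2 = generate_sparse_matrix()
```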