Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Add sparse concat operator support #167

Merged
merged 3 commits into from
Sep 11, 2018

Conversation

kalyc
Copy link

@kalyc kalyc commented Sep 7, 2018

Summary

Add sparse support for concat operator

Related Issues

Missing sparse operators

PR Overview

  • This PR requires new unit tests [y/n] (make sure tests are included)
  • This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
  • This PR is backwards compatible [y/n]
  • This PR changes the current API [y/n]

Copy link

@roywei roywei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution, see comments inline

symbols = [t.symbol for t in tensors]

if py_all([is_sparse(t) for t in tensors]):
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=0))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should still check if axis !=0, even if all tensors are sparse, if user want to concat on axis other than 0, we use dense concat

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have hard-coded the dimension - why do we need to check the axis!=0 here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed hard-coded value & added unit test

if py_all([is_sparse(t) for t in tensors]):
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=0))

return KerasSymbol(mx.sym.concat(*symbols, dim=axis))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If part of tensors are sparse, so it will fail the check at py_all([is_sparse(t) for t in tensors]):, what will happen if you pass them to dense concat?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the doc - it will convert the concat output to be dense. To retrieve concat output as sparse all input tensors should be sparse. Added a unit test to check the same

@@ -104,6 +104,35 @@ def test_sparse_dot(self):
assert k_s.shape == k_d.shape
assert_allclose(k_s, k_d, atol=1e-05)

def test_sparse_concat(self):
x_d = np.array([0, 7, 2, 3], dtype=np.float32)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use sparse data generator

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -17,6 +17,8 @@
_REENTRY = False
NAME_SCOPE_STACK = []

py_all = all
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to define py_all here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

following the convention in tensorflow_backend file

Copy link

@roywei roywei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link

@sandeep-krishnamurthy sandeep-krishnamurthy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. LGTM. Few minor comments.

return KerasSymbol(mx.sym.concat(*tensors, dim=axis))
symbols = [t.symbol for t in tensors]

if axis == 0 and py_all([is_sparse(t) for t in tensors]):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it is preferrable to avoid multiple returns in a func.

if axis == 0 and py_all([is_sparse(t) for t in tensors]):
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=axis))

return KerasSymbol(mx.sym.concat(*symbols, dim=axis))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should show warning if axis !=0 and is_sparse(t) for t in tensors? I am assuming you do dense concat automatically if axis !=0. This has performance implication and important for users to know.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is replicating the behavior for sparse concat in mxnet - http://mxnet.apache.org/api/python/symbol/symbol.html#mxnet.symbol.concat - do we still want to add a warning there?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think yes. Users should know that tensor is being converted to dense.

@sandeep-krishnamurthy sandeep-krishnamurthy merged commit 5e0c10b into awslabs:dev Sep 11, 2018
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants