-
Notifications
You must be signed in to change notification settings - Fork 65
Add sparse concat operator support #167
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ | |
_REENTRY = False | ||
NAME_SCOPE_STACK = [] | ||
|
||
py_all = all | ||
|
||
|
||
class name_scope(object): | ||
def __init__(self, name): | ||
|
@@ -2030,6 +2032,10 @@ def concatenate(tensors, axis=-1): | |
|
||
# Returns | ||
A tensor. | ||
|
||
Note: | ||
- MXNet supports sparse concat only for dim=0 | ||
- https://mxnet.apache.org/api/python/symbol/sparse.html#mxnet.symbol.sparse.concat | ||
""" | ||
if axis < 0: | ||
rank = ndim(tensors[0]) | ||
|
@@ -2038,8 +2044,12 @@ def concatenate(tensors, axis=-1): | |
else: | ||
axis = 0 | ||
|
||
tensors = [t.symbol for t in tensors] | ||
return KerasSymbol(mx.sym.concat(*tensors, dim=axis)) | ||
symbols = [t.symbol for t in tensors] | ||
|
||
if axis == 0 and py_all([is_sparse(t) for t in tensors]): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: it is preferrable to avoid multiple returns in a func. |
||
return KerasSymbol(mx.sym.sparse.concat(*symbols, dim=axis)) | ||
|
||
return KerasSymbol(mx.sym.concat(*symbols, dim=axis)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should show warning if axis !=0 and is_sparse(t) for t in tensors? I am assuming you do dense concat automatically if axis !=0. This has performance implication and important for users to know. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is replicating the behavior for sparse concat in mxnet - http://mxnet.apache.org/api/python/symbol/symbol.html#mxnet.symbol.concat - do we still want to add a warning there? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think yes. Users should know that tensor is being converted to dense. |
||
|
||
|
||
@keras_mxnet_symbol | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to define py_all here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
following the convention in
tensorflow_backend
file