-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] Add GPU operator: sparse_fill_empty_rows and sparse_reshape #87
Conversation
385de86
to
f8d71f1
Compare
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cu.cc
Outdated
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cu.cc
Outdated
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cu.cc
Outdated
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cu.cc
Outdated
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cc
Outdated
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.h
Outdated
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cu.cc
Outdated
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cc
Outdated
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cc
Show resolved
Hide resolved
tensorflow_recommenders_addons/dynamic_embedding/core/ops/math_ops.cc
Outdated
Show resolved
Hide resolved
225a2aa
to
a9d93b1
Compare
@@ -115,4 +115,65 @@ REGISTER_OP("TFRA>SparseSegmentSumWithNumSegments") | |||
.SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); | |||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM | |||
|
|||
#if GOOGLE_CUDA | |||
REGISTER_OP("TFRA>SparseFillEmptyRows") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pls replace ‘TFRA>’ by 'Tfra', for bad compatiblity with TF1.x
.Device(DEVICE_GPU) \ | ||
.TypeConstraint<type>("T"), \ | ||
SparseFillEmptyRowsOp<GPUDevice, type>) | ||
TF_CALL_int32(REGISTER_KERNELS); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need float, half, int32 and int8 to work with GPU hash table
#define DEFINE_GPU_KERNELS(type) \ | ||
template struct SparseFillEmptyRowsFunctor<GPUDevice, type>; | ||
|
||
TF_CALL_int32(DEFINE_GPU_KERNELS); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same problem
}; | ||
|
||
#if GOOGLE_CUDA | ||
REGISTER_KERNEL_BUILDER(Name("TFRA>SparseReshape").Device(DEVICE_GPU), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tfra
return Status::OK(); | ||
}); | ||
|
||
REGISTER_OP("TFRA>SparseReshape") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tfra
def sparse_fill_empty_rows(sp_input, default_value, name=None): | ||
"""Fills empty rows in the input 2-D `SparseTensor` with a default value. | ||
|
||
It do same things as `tf.sparse.fill_empty_rows`. Here we provide GPU impl. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does
|
||
It do same things as `tf.sparse.fill_empty_rows`. Here we provide GPU impl. | ||
|
||
Go [tf api](https://www.tensorflow.org/api_docs/python/tf/sparse/fill_empty_rows) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TF API
Args: | ||
sp_input: A `SparseTensor` with shape `[N, M]`. | ||
default_value: The value to fill for empty rows, with the same type as | ||
`sp_input.` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sp_input.
->sp_input
|
||
It do same things as `tf.sparse.reshape`. Here we provide GPU impl. | ||
|
||
Go [tf api](https://www.tensorflow.org/api_docs/python/tf/sparse/reshape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TF API
def sparse_reshape(sp_input, shape, name=None): | ||
"""Reshapes a `SparseTensor` to represent values in a new dense shape. | ||
|
||
It do same things as `tf.sparse.reshape`. Here we provide GPU impl. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
implement
try: | ||
return _sparse_reshape_gpu(sp_input, shape, name=name) | ||
except errors.NotFoundError: | ||
tf_logging.warn('`tfra.dynamic_embedding.sparse_reshape` is not' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this log will be printed too frequently on eager mode ? @Lifann
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It only happens when trying to run the TFRA CPU version on GPU enabled environment. It reminds someone to install the GPU version when it is enabled. But you're right, it could be reconsidered to get better warning info.
@@ -23,15 +23,23 @@ custom_op_library( | |||
) | |||
|
|||
custom_op_library( | |||
name = "_segment_reduction_ops.so", | |||
name = "_math_ops.so", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change the so name, and also change the readme.
af8c11e
to
7177565
Compare
LGTM |
Description
Add GPU operator: sparse_fill_empty_rows and sparse_reshape
Fixes # (issue)
Type of change
Checklist:
How Has This Been Tested?
If you're adding a bugfix or new feature please describe the tests that you ran to verify your changes:
*