-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Gather_op with python op passed #3540
Conversation
@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like); | |||
USE_OP_ITSELF(recurrent_op); | |||
USE_OP(gaussian_random); | |||
USE_OP(uniform_random); | |||
USE_CPU_ONLY_OP(gather); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why USE_CPU_ONLY_OP
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GPU test and cuda context not very ready. Easy to add it later, maybe next time :)
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto *X_grad
auto *X
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Done changed.
paddle/operators/gather_op.h
Outdated
class GatherOpKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
auto X = ctx.Input<Tensor>("X"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto *X
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx. Changed.
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个Copyright的格式貌似和别的不太一样
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Every our Cuda .cu file uses this, while every .cc file uses another one. Slightly different. Which should we use?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM except for some code style problem
paddle/operators/CMakeLists.txt
Outdated
@@ -43,6 +43,9 @@ endfunction() | |||
|
|||
add_subdirectory(math) | |||
cc_test(gather_test SRCS gather_test.cc DEPS tensor) | |||
op_library(gather_op SRCS gather_op.cc gather_op.cu) | |||
# DEPS op_registry) | |||
# cc_test(gather_op_test SRCS gather_op_test.cc DEPS gather_op) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove needless lines. Don't comment them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Good catch. Already removed :)
paddle/operators/gather_op.cc
Outdated
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
int batch_size = ctx.Input<Tensor>("Index")->dims()[0]; | ||
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); | ||
paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
framework::DDim
is enough. Prefix paddle::
can be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we'd better use more PADDLE_ENFORCE
check's to forbid user's misuse.
paddle/operators/gather_op.cc
Outdated
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "The source input of gather op"); | ||
AddInput("Index", "The index input of gather op"); | ||
AddOutput("Y", "The output of add op"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we'd better use 'Out' as name here.
see the draft rule Output. we want to unify the operator input/output name for clear.
paddle/operators/gather_op.h
Outdated
void Compute(const framework::ExecutionContext& ctx) const override { | ||
auto X = ctx.Input<Tensor>("X"); | ||
auto Index = ctx.Input<Tensor>("Index"); | ||
auto Y = ctx.Output<Tensor>("Y"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also here, better use auto*
instead of auto
. which indicates that it is a pointer. we have a discussion on it in our Hi group.
'Index': numpy.array([1, 3, 5]).astype("int32") | ||
} | ||
self.outputs = {'Y': self.inputs['X'][self.inputs['Index']]} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here needs the check gradient of Gather
op.
paddle/operators/gather_op.h
Outdated
auto Index = ctx.Input<Tensor>("Index"); | ||
auto dX = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y")); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dX
, dY
need to get mutable_data
to allocate memory space. right?
paddle/operators/gather_op.cc
Outdated
int batch_size = ctx.Input<Tensor>("Index")->dims()[0]; | ||
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); | ||
paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims()); | ||
output_dims[0] = batch_size; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we'd better do some checks before here, the largest element in Index
must be less than ctx.Input<Tensor>("X")->dims()[0]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM++
No description provided.