
Gather_op with python op passed #3540

Merged: 15 commits merged into PaddlePaddle:develop on Aug 22, 2017
Conversation

zchen0211 (Contributor):
No description provided.

@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
USE_CPU_ONLY_OP(gather);
Member:
Why USE_CPU_ONLY_OP?

Contributor Author:
The GPU test and CUDA context are not quite ready yet. It's easy to add later, maybe next time :)
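For context, a minimal sketch of what the later GPU follow-up might look like (the registration below is an assumption based on the registry macros of that era, not part of this change):

// gather_op.cu (hypothetical follow-up): register a GPU kernel for gather
REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpKernel<float>);

// the pybind registration could then pull in both device kernels:
// USE_CPU_ONLY_OP(gather);  ->  USE_OP(gather);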


protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Member:
Use auto *X_grad (and auto *X) so the pointer type is explicit.
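In other words, the quoted line would become:

auto *X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));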

Contributor Author:
Thanks. Done, changed.

class GatherOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
Member:
Use auto *X here as well.

Contributor Author:
Thanks. Changed.

distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Member:
The format of this copyright header seems different from the other files.

Contributor Author:
All of our CUDA .cu files use this header, while every .cc file uses a slightly different one. Which should we use?

Member (@jacquesqiao) left a comment:
LGTM except for some code style problems.

@@ -43,6 +43,9 @@ endfunction()

add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
op_library(gather_op SRCS gather_op.cc gather_op.cu)
# DEPS op_registry)
# cc_test(gather_op_test SRCS gather_op_test.cc DEPS gather_op)
Collaborator:
Remove the needless lines; don't just comment them out.

Contributor Author:
Thanks, good catch. Already removed :)

void InferShape(const framework::InferShapeContext &ctx) const override {
int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
Collaborator (@JiayiFeng, Aug 18, 2017):
framework::DDim is enough; the paddle:: prefix can be removed.
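In other words, the quoted line would become:

framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());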

Contributor (@dzhwinter) left a comment:
We'd better use more PADDLE_ENFORCE checks to prevent user misuse.
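A hedged sketch of the kind of checks meant here, placed in InferShape (assuming the PADDLE_ENFORCE_NOT_NULL macro of that era; the exact messages are illustrative):

// fail fast with a clear message instead of dereferencing null inputs
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                        "Input(X) of gather op must not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"),
                        "Input(Index) of gather op must not be null");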

: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op");
AddOutput("Y", "The output of add op");
Contributor:
Maybe we'd better use 'Out' as the name here. See the draft naming rule for Output; we want to unify operator input/output names for clarity.
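Under that naming rule the output declaration would become the line below (this also fixes the copy-pasted "add op" in the description; the wording is illustrative):

AddOutput("Out", "The output of gather op");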

void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
auto Index = ctx.Input<Tensor>("Index");
auto Y = ctx.Output<Tensor>("Y");
Contributor:
Also here, better to use auto* instead of auto, which indicates that it is a pointer. We had a discussion on this in our Hi group.
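Applied to the quoted lines:

auto *X = ctx.Input<Tensor>("X");
auto *Index = ctx.Input<Tensor>("Index");
auto *Y = ctx.Output<Tensor>("Y");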

'Index': numpy.array([1, 3, 5]).astype("int32")
}
self.outputs = {'Y': self.inputs['X'][self.inputs['Index']]}

Contributor:
A gradient check for the gather op is still needed here.

auto Index = ctx.Input<Tensor>("Index");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));

Contributor:
dX and dY need to call mutable_data to allocate memory space, right?
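A short sketch of the allocation; strictly only the output gradient dX needs it, since dY is an input whose memory was already allocated upstream. Assuming a float kernel:

auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
// allocate dX's memory on the kernel's device before writing to it
dX->mutable_data<float>(ctx.GetPlace());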

int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
output_dims[0] = batch_size;
Collaborator (@JiayiFeng, Aug 18, 2017):
I think we'd better do some checks before this point: the largest element in Index must be less than ctx.Input<Tensor>("X")->dims()[0].
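A hedged sketch of such a bounds check (assumes int32 index data resident on CPU and the PADDLE_ENFORCE_LT variant of the macro):

const int *index_data = ctx.Input<Tensor>("Index")->data<int>();
int x_first_dim = ctx.Input<Tensor>("X")->dims()[0];
for (int i = 0; i < batch_size; ++i) {
  // every gathered row must exist in X
  PADDLE_ENFORCE_LT(index_data[i], x_first_dim,
                    "Index element out of range of X's first dimension");
}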

Contributor (@dzhwinter) left a comment:
LGTM++

zchen0211 merged commit a0aa907 into PaddlePaddle:develop on Aug 22, 2017.