Skip to content

Commit

Permalink
Merge pull request #3540 from zchen0211/develop
Browse files Browse the repository at this point in the history
Gather_op with python op passed
  • Loading branch information
zchen0211 authored Aug 22, 2017
2 parents 5810d63 + 0a0f194 commit a0aa907
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 3 deletions.
1 change: 1 addition & 0 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
sgd_op
gather_op
add_op
mul_op
rowwise_add_op
Expand Down
1 change: 1 addition & 0 deletions paddle/framework/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
USE_CPU_ONLY_OP(gather);

namespace paddle {
namespace framework {
Expand Down
1 change: 1 addition & 0 deletions paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ endfunction()

add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
op_library(gather_op SRCS gather_op.cc gather_op.cu)

cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)

Expand Down
7 changes: 4 additions & 3 deletions paddle/operators/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <cstring>

#include "paddle/framework/ddim.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"

Expand All @@ -25,13 +26,13 @@ namespace operators {

// Implementation of CPU copy
template <typename T>
void CPUGather(const T* params, const int* indices, const int slice_size,
void CPUGather(const T* src, const int* indices, const int slice_size,
const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T);

for (int i = 0; i < index_size; ++i) {
int index_ = indices[i];
memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes);
}
}

Expand All @@ -55,7 +56,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
int index_size = index->dims()[0];

auto src_dims = src->dims();
paddle::framework::DDim output_dims(src_dims);
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;

// slice size
Expand Down
72 changes: 72 additions & 0 deletions paddle/operators/gather_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/gather_op.h"
#include "paddle/framework/ddim.h"

namespace paddle {
namespace operators {

class GatherOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
output_dims[0] = batch_size;
ctx.Output<Tensor>("Out")->Resize(output_dims);
}
};

class GatherGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");

X_grad->Resize(X->dims());
}
};

class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of add op");
AddComment(R"DOC(
Gather Operator by selecting from the first axis,
Out = X[Index]
)DOC");
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather,
ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
gather_grad,
ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
20 changes: 20 additions & 0 deletions paddle/operators/gather_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/gather_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gather,
ops::GatherOpKernel<paddle::platform::GPUPlace, float>);
53 changes: 53 additions & 0 deletions paddle/operators/gather_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "gather.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "scatter.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename Place, typename T>
class GatherOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
auto *Index = ctx.Input<Tensor>("Index");
auto *Y = ctx.Output<Tensor>("Out");

Y->mutable_data<T>(ctx.GetPlace());
Gather<T>(ctx.GetPlace(), X, Index, Y);
}
};

template <typename Place, typename T>
class GatherGradientOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *Index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));

dX->mutable_data<T>(ctx.GetPlace());
ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
}
};

} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions python/paddle/v2/framework/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py)
py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_gather_op SRCS test_gather_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)

py_test(gradient_checker SRCS gradient_checker.py)
Expand Down
34 changes: 34 additions & 0 deletions python/paddle/v2/framework/tests/test_gather_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import numpy
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator


class TestGatherOp(unittest.TestCase):
__metaclass__ = OpTestMeta

def setUp(self):
self.type = "gather"
xnp = numpy.random.random((10, 20)).astype("float32")
self.inputs = {
'X': xnp,
'Index': numpy.array([1, 3, 5]).astype("int32")
}
self.outputs = {'Out': self.inputs['X'][self.inputs['Index']]}


class TestGatherGradOp(GradientChecker):
def test_gather_grad(self):
print 'creating op'
op = create_op("gather")
print 'creating op done'
xnp = numpy.random.random((10, 20)).astype("float32")
inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")}
print 'correct before check gradient'
self.check_grad(op, inputs, set("X"), "Out")


if __name__ == "__main__":
unittest.main()

0 comments on commit a0aa907

Please sign in to comment.