Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine elementwise_op #8091

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
10 changes: 9 additions & 1 deletion paddle/operators/compare_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,15 @@ class CompareOpKernel
public:
// Compute: applies the comparison Functor elementwise to inputs "X" and "Y"
// through ElementwiseComputeEx (instantiated with bool as its output element
// type) and stores the result in output "Out".
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; the annotations below are a reading aid and
// should be confirmed against the actual file.
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
// Presumably the pre-change call (context is the only argument); the lines
// that follow appear to be its replacement.
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context);
using Tensor = framework::Tensor;

// The tensors and the "axis" attribute are looked up here and passed to the
// helper explicitly.
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
// NOTE(review): Out's buffer is requested with element type T, while the
// helper below is instantiated with bool as its output type — confirm
// whether this should be mutable_data<bool>.
z->mutable_data<T>(context.GetPlace());
int axis = context.Attr<int>("axis");
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
z);
}
};

Expand Down
21 changes: 19 additions & 2 deletions paddle/operators/elementwise_add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ template <typename DeviceContext, typename T>
// Forward kernel for the elementwise add operator: combines inputs "X" and
// "Y" with AddFunctor<T> through ElementwiseComputeEx and stores the result
// in output "Out".
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseAddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Presumably the pre-change call (ctx is the only argument); the lines that
// follow appear to be its replacement.
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx);
using Tensor = framework::Tensor;

// The kernel looks up its tensors and the "axis" attribute itself and hands
// them to the helper explicitly.
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
// Request Out's buffer with element type T on the context's place before it
// is passed to the helper.
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
}
};

Expand Down Expand Up @@ -92,9 +99,19 @@ template <typename DeviceContext, typename T>
// Backward kernel for the elementwise add operator: passes X, Y, Out, the
// gradient variable of Out, the "axis" attribute and the gradient outputs for
// X and Y to ElementwiseGradCompute, instantiated with the three add-gradient
// functors.
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
// dout is read as an input; dx and dy are the outputs this kernel produces.
// All three names are derived with framework::GradVarName.
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
ElementwiseAddBroadCastGradFunctor<T>,
// Presumably the pre-change end of the call (ctx only); the two lines after
// it appear to be its replacement with explicit arguments.
ElementwiseAddBroadCast2GradFunctor<T>>(ctx);
ElementwiseAddBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
}
};

Expand Down
21 changes: 19 additions & 2 deletions paddle/operators/elementwise_div_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ template <typename DeviceContext, typename T>
// Forward kernel for the elementwise div operator: combines inputs "X" and
// "Y" with DivFunctor<T> through ElementwiseComputeEx and stores the result
// in output "Out".
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseDivKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Presumably the pre-change call (ctx is the only argument); the lines that
// follow appear to be its replacement.
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx);
using Tensor = framework::Tensor;

// The kernel looks up its tensors and the "axis" attribute itself and hands
// them to the helper explicitly.
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
// Request Out's buffer with element type T on the context's place before it
// is passed to the helper.
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
}
};

Expand Down Expand Up @@ -111,9 +118,19 @@ template <typename DeviceContext, typename T>
// Backward kernel for the elementwise div operator: passes X, Y, Out, the
// gradient variable of Out, the "axis" attribute and the gradient outputs for
// X and Y to ElementwiseGradCompute, instantiated with the three div-gradient
// functors.
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
// dout is read as an input; dx and dy are the outputs this kernel produces.
// All three names are derived with framework::GradVarName.
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>,
ElementwiseDivBroadCastGradFunctor<T>,
// Presumably the pre-change end of the call (ctx only); the two lines after
// it appear to be its replacement with explicit arguments.
ElementwiseDivBroadCast2GradFunctor<T>>(ctx);
ElementwiseDivBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
}
};

Expand Down
21 changes: 19 additions & 2 deletions paddle/operators/elementwise_max_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ template <typename DeviceContext, typename T>
// Forward kernel for the elementwise max operator: combines inputs "X" and
// "Y" with MaxFunctor<T> through ElementwiseComputeEx and stores the result
// in output "Out".
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Presumably the pre-change call (ctx is the only argument); the lines that
// follow appear to be its replacement.
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx);
using Tensor = framework::Tensor;

// The kernel looks up its tensors and the "axis" attribute itself and hands
// them to the helper explicitly.
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
// Request Out's buffer with element type T on the context's place before it
// is passed to the helper.
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
}
};

Expand Down Expand Up @@ -110,9 +117,19 @@ template <typename DeviceContext, typename T>
// Backward kernel for the elementwise max operator: passes X, Y, Out, the
// gradient variable of Out, the "axis" attribute and the gradient outputs for
// X and Y to ElementwiseGradCompute, instantiated with the three max-gradient
// functors.
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
// dout is read as an input; dx and dy are the outputs this kernel produces.
// All three names are derived with framework::GradVarName.
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseMaxGradFunctor<T>,
ElementwiseMaxBroadCastGradFunctor<T>,
// Presumably the pre-change end of the call (ctx only); the two lines after
// it appear to be its replacement with explicit arguments.
ElementwiseMaxBroadCast2GradFunctor<T>>(ctx);
ElementwiseMaxBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
}
};

Expand Down
21 changes: 19 additions & 2 deletions paddle/operators/elementwise_min_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ template <typename DeviceContext, typename T>
// Forward kernel for the elementwise min operator: combines inputs "X" and
// "Y" with MinFunctor<T> through ElementwiseComputeEx and stores the result
// in output "Out".
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseMinKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Presumably the pre-change call (ctx is the only argument); the lines that
// follow appear to be its replacement.
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx);
using Tensor = framework::Tensor;

// The kernel looks up its tensors and the "axis" attribute itself and hands
// them to the helper explicitly.
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
// Request Out's buffer with element type T on the context's place before it
// is passed to the helper.
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
}
};

Expand Down Expand Up @@ -110,9 +117,19 @@ template <typename DeviceContext, typename T>
// Backward kernel for the elementwise min operator: passes X, Y, Out, the
// gradient variable of Out, the "axis" attribute and the gradient outputs for
// X and Y to ElementwiseGradCompute, instantiated with the three min-gradient
// functors.
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseMinGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
// dout is read as an input; dx and dy are the outputs this kernel produces.
// All three names are derived with framework::GradVarName.
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseMinGradFunctor<T>,
ElementwiseMinBroadCastGradFunctor<T>,
// Presumably the pre-change end of the call (ctx only); the two lines after
// it appear to be its replacement with explicit arguments.
ElementwiseMinBroadCast2GradFunctor<T>>(ctx);
ElementwiseMinBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
}
};

Expand Down
21 changes: 19 additions & 2 deletions paddle/operators/elementwise_mul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ template <typename DeviceContext, typename T>
// Forward kernel for the elementwise mul operator: combines inputs "X" and
// "Y" with MulFunctor<T> through ElementwiseComputeEx and stores the result
// in output "Out".
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Presumably the pre-change call (ctx is the only argument); the lines that
// follow appear to be its replacement.
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx);
using Tensor = framework::Tensor;

// The kernel looks up its tensors and the "axis" attribute itself and hands
// them to the helper explicitly.
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
// Request Out's buffer with element type T on the context's place before it
// is passed to the helper.
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
}
};

Expand Down Expand Up @@ -110,9 +117,19 @@ template <typename DeviceContext, typename T>
// Backward kernel for the elementwise mul operator: passes X, Y, Out, the
// gradient variable of Out, the "axis" attribute and the gradient outputs for
// X and Y to ElementwiseGradCompute, instantiated with the three mul-gradient
// functors.
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
// dout is read as an input; dx and dy are the outputs this kernel produces.
// All three names are derived with framework::GradVarName.
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>,
ElementwiseMulBroadCastGradFunctor<T>,
// Presumably the pre-change end of the call (ctx only); the two lines after
// it appear to be its replacement with explicit arguments.
ElementwiseMulBroadCast2GradFunctor<T>>(ctx);
ElementwiseMulBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
}
};

Expand Down
28 changes: 10 additions & 18 deletions paddle/operators/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,21 +313,18 @@ EIGEN_FUNCTOR(Div, EIGEN_DIV);

template <typename DeviceContext, typename T, typename functor,
typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,

const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, int axis,
framework::Tensor* dx, framework::Tensor* dy) {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

auto x_dims = x->dims();
auto y_dims = y->dims();

auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
}
Expand All @@ -348,7 +345,6 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
x_dims = framework::make_ddim(extended_dims);
}

int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);

int pre, n, post;
Expand All @@ -367,13 +363,10 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {

template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<OutType>(ctx.GetPlace());
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, int axis,
framework::Tensor* z) {
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), Functor());

Expand All @@ -394,7 +387,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
x_dims = framework::make_ddim(extended_dims);
}

int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
Expand Down
9 changes: 8 additions & 1 deletion paddle/operators/elementwise_pow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ template <typename DeviceContext, typename T>
// Forward kernel for the elementwise pow operator: combines inputs "X" and
// "Y" with PowFunctor<T> through ElementwiseComputeEx and stores the result
// in output "Out".
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwisePowKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Presumably the pre-change call (ctx is the only argument); the lines that
// follow appear to be its replacement.
ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx);
using Tensor = framework::Tensor;

// The kernel looks up its tensors and the "axis" attribute itself and hands
// them to the helper explicitly.
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
// Request Out's buffer with element type T on the context's place before it
// is passed to the helper.
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
}
};

Expand Down
21 changes: 19 additions & 2 deletions paddle/operators/elementwise_sub_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ template <typename DeviceContext, typename T>
// Forward kernel for the elementwise sub operator: combines inputs "X" and
// "Y" with SubFunctor<T> through ElementwiseComputeEx and stores the result
// in output "Out".
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseSubKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Presumably the pre-change call (ctx is the only argument); the lines that
// follow appear to be its replacement.
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx);
using Tensor = framework::Tensor;

// The kernel looks up its tensors and the "axis" attribute itself and hands
// them to the helper explicitly.
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
// Request Out's buffer with element type T on the context's place before it
// is passed to the helper.
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
}
};

Expand Down Expand Up @@ -93,9 +100,19 @@ template <typename DeviceContext, typename T>
// Backward kernel for the elementwise sub operator: passes X, Y, Out, the
// gradient variable of Out, the "axis" attribute and the gradient outputs for
// X and Y to ElementwiseGradCompute, instantiated with the three sub-gradient
// functors.
// NOTE(review): this text is a rendered diff without +/- markers, so old and
// new lines are interleaved; confirm the annotations against the actual file.
class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
// dout is read as an input; dx and dy are the outputs this kernel produces.
// All three names are derived with framework::GradVarName.
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>,
ElementwiseSubBroadCastGradFunctor<T>,
// Presumably the pre-change end of the call (ctx only); the two lines after
// it appear to be its replacement with explicit arguments.
ElementwiseSubBroadCast2GradFunctor<T>>(ctx);
ElementwiseSubBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
}
};

Expand Down