Add clip op #3937
Changes from 3 commits
paddle/operators/clip_op.cc
@@ -0,0 +1,73 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/clip_op.h"
namespace paddle {
namespace operators {

using framework::Tensor;

class ClipOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto x_dims = ctx.Input<Tensor>("X")->dims();
    auto max = Attr<float>("max");
    auto min = Attr<float>("min");
    PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
    ctx.Output<Tensor>("Out")->Resize(x_dims);
  }
};

class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of clip op");
    AddOutput("Out", "The output of clip op");
    AddComment(R"DOC(
Clip Operator.
Review: A better comment example:
Reply: Fixed.
)DOC"); | ||
AddAttr<float>("min", "min value to be clipped."); | ||
AddAttr<float>("max", "max value to be clipped."); | ||
Review: Put AddAttr before the DOC comment; min and max also need the type template. See: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L36
Reply: Must the multiple attrs of one op all share the same type?
  }
};

class ClipOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) should not be null");
Review: The check is also needed in …
Reply: Fixed.
    auto x_dims = ctx.Input<Tensor>("X")->dims();
    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Review: Output<framework::LoDTensor>
Reply: Fixed.
    x_grad->Resize(x_dims);
Review: Check whether …
Reply: Fixed.
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker, clip_grad, ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip,
                       ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
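
A minimal sketch of the type-templated Maker the reviewer asks for, following the linked scale_op pattern; the AttrType template parameter is an assumption drawn from that pattern, not code in this diff:

// Sketch only: attribute type lifted into a template parameter, and attrs
// declared before the DOC comment, per the review thread above.
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of clip op");
    AddOutput("Out", "The output of clip op");
    AddAttr<AttrType>("min", "min value to be clipped.");
    AddAttr<AttrType>("max", "max value to be clipped.");
    AddComment(R"DOC(
Clip Operator.
)DOC");
  }
};
// Registration would then name the concrete attribute type, e.g.
// REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad, ops::ClipOpGrad);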
paddle/operators/clip_op.cu
@@ -0,0 +1,67 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
Review: Remove this line.
Reply: Fixed.
#include "paddle/operators/clip_op.h"

#define CUDA_1D_KERNEL_LOOP(i, n)                            \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
       i += blockDim.x * gridDim.x)

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
__global__ void ClipGradientKernel(const int N, const T min, const T max,
                                   const T* Y, const T* dY, T* dX) {
  CUDA_1D_KERNEL_LOOP(i, N) { dX[i] = dY[i] * (Y[i] > min && Y[i] < max); }
Review: This implementation depends on …
}

template <typename T>
class ClipGradientOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto max = context.op().Attr<float>("max");
    auto min = context.op().Attr<float>("min");
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* x = context.Input<Tensor>("X");
    auto dims = d_x->dims();
    size_t count = 1;
    for (int i = 0; i < dims.size(); ++i) {
      count *= dims[i];
    }
    auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
Review: Why not just …
Reply: I mean …
    auto d_out_data = d_out->data<T>();
    auto x_data = x->data<T>();

    int N = d_x->dims()[0];
    int D = d_x->dims()[1];
    int block = 512;
    int grid = (N * D + block - 1) / block;

    ClipGradientKernel<T><<<grid, block>>>(count, min, max, x_data, d_out_data,
Review: Need to use a CUDA stream when launching the kernel. Please refer to …
Reply: Fixed.
                                           d_x_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip,
                       ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradientOpCUDAKernel<float>);
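
On the stream review comment: a hedged sketch of what a stream-aware launch could look like. The cast to platform::CUDADeviceContext and its stream() accessor are assumptions about the framework API of this era, not code from this diff:

// Sketch, assuming the execution context exposes its device context and that
// on GPU it is a platform::CUDADeviceContext owning a cudaStream_t:
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                  context.device_context())
                  .stream();
// Launching on that stream keeps the kernel ordered with the op's other GPU
// work instead of running on the default stream.
ClipGradientKernel<T><<<grid, block, 0, stream>>>(count, min, max, x_data,
                                                  d_out_data, d_x_data);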
paddle/operators/clip_op.h
@@ -0,0 +1,70 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename Place, typename T>
class ClipKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto max = context.op().Attr<float>("max");
    auto min = context.op().Attr<float>("min");
    auto* x = context.Input<Tensor>("X");
    auto* out = context.Output<Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());
    auto x_tensor = EigenTensor<T, 2>::From(*x);
    auto out_tensor = EigenTensor<T, 2>::From(*out);
    auto place = context.GetEigenDevice<Place>();
    out_tensor.device(place) = x_tensor.cwiseMin(max).cwiseMax(min);
  }
};
template <typename T>
class ClipGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto max = context.op().Attr<float>("max");
    auto min = context.op().Attr<float>("min");
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* x = context.Input<Tensor>("X");
    auto dims = d_x->dims();
    size_t count = 1;
    for (int i = 0; i < dims.size(); ++i) {
      count *= dims[i];
    }
Review: Same as above.
Reply: Fixed.
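    // Sketch (not in this diff): one plausible reading of the count-loop
    // review threads is to use the framework's DDim helper instead of the
    // hand-rolled product -- assuming framework::product(d_x->dims()) exists,
    // as other operators of this period use:
    // size_t count = framework::product(d_x->dims());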
    auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
    auto d_out_data = d_out->data<T>();
    auto x_data = x->data<T>();
    for (size_t i = 0; i < count; ++i) {
      d_x_data[i] = d_out_data[i] * (x_data[i] > min && x_data[i] < max);
    }
  }
};

}  // namespace operators
}  // namespace paddle
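
For reference, the backward loop above implements the usual clip subgradient: the upstream gradient passes through only where the input lies strictly inside the clipping range,

\[
\frac{\partial\,\mathrm{Out}_i}{\partial X_i} =
\begin{cases}
1, & \min < X_i < \max \\
0, & \text{otherwise,}
\end{cases}
\qquad
\frac{\partial L}{\partial X_i}
  = \frac{\partial L}{\partial\,\mathrm{Out}_i}\cdot[\min < X_i < \max].
\]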
python/paddle/v2/framework/tests/test_clip_op.py
@@ -0,0 +1,39 @@
import unittest
import numpy as np
from paddle.v2.framework.op import Operator
from gradient_checker import GradientChecker
from op_test_util import OpTestMeta
Review: Use the new unit testing framework.
Reply: Fixed.
class TestClipOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        input = np.random.random((16, 16)).astype("float32")
        print "input: %s" % input
        self.type = "clip"
        self.inputs = {'X': input, }
        self.attrs = {}
        self.attrs['min'] = 0.1
        self.attrs['max'] = 0.9
        self.outputs = {
            'Out': np.clip(self.inputs['X'], self.attrs['min'],
                           self.attrs['max'])
        }
class TestClipGradOp(GradientChecker):
    def setUp(self):
        self.op = Operator(type="clip", X="X", Out="Out", min=0.1, max=0.9)
        self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }

    def test_normal(self):
        self.check_grad(
            self.op, self.inputs, set(["X"]), "Out", max_relative_error=0.5)

    def test_cpu_gpu_compare(self):
        self.compare_grad(self.op, self.inputs)


if __name__ == '__main__':
    unittest.main()
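
As a sketch of the "new unit testing framework" the reviewer asks for: this assumes an OpTest base class (from an op_test module) with check_output/check_grad helpers, an assumption about the newer framework rather than code in this diff:

import unittest

import numpy as np

from op_test import OpTest  # assumed helper of the newer test framework


class TestClipOpNewStyle(OpTest):
    def setUp(self):
        self.op_type = "clip"
        x = np.random.random((16, 16)).astype("float32")
        self.attrs = {'min': 0.1, 'max': 0.9}
        self.inputs = {'X': x}
        self.outputs = {
            'Out': np.clip(x, self.attrs['min'], self.attrs['max'])
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # clip has zero gradient outside (min, max), hence the loose tolerance.
        self.check_grad(['X'], 'Out', max_relative_error=0.5)


if __name__ == '__main__':
    unittest.main()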