Merge pull request #4604 from kavyasrinet/activations
Added Leaky Relu activation
kavyasrinet authored Oct 5, 2017
2 parents 828c5b3 + ba9a0bc commit 3e2be06
Showing 3 changed files with 65 additions and 1 deletion.
19 changes: 19 additions & 0 deletions paddle/operators/activation_op.cc
@@ -69,6 +69,22 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
  }
};

template <typename AttrType>
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LeakyReluOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Y", "Output of LeakyRelu operator");
    AddComment(
        "LeakyRelu activation operator, "
        "leaky_relu = max(x, alpha * x)");
    AddAttr<AttrType>("alpha", "The small negative slope")
        .SetDefault(static_cast<AttrType>(0.02f));
  }
};

class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -240,6 +256,9 @@ REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad,
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
            ops::ActivationOpGrad);

REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
            leaky_relu_grad, ops::ActivationOpGrad);

REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
            soft_relu_grad, ops::ActivationOpGrad);

30 changes: 29 additions & 1 deletion paddle/operators/activation_op.h
@@ -309,6 +309,33 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  }
};

template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.cwiseMax(alpha * x);
  }
};
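For reference, the Eigen expression above computes an elementwise max(x, alpha * x). A minimal NumPy sketch of the same forward rule (illustrative only; the helper name below is not part of the patch):

import numpy as np

def leaky_relu(x, alpha=0.02):
    # Keeps positive entries unchanged and scales negative entries by alpha.
    return np.maximum(x, alpha * x)

# Example: leaky_relu(np.array([-1.0, 0.0, 2.0])) -> [-0.02, 0.0, 2.0]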

template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto temp1 = alpha * (x < static_cast<T>(0)).template cast<T>().eval();
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
    dx.device(d) = dy * (temp1 + temp2).template cast<T>();
  }
};
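The backward pass follows from the forward definition: for alpha in (0, 1), the derivative of max(x, alpha * x) is 1 where x >= 0 and alpha where x < 0, which is exactly what temp2 and temp1 encode above. A NumPy sketch of the same rule, for illustration only (the function name below is not part of the patch):

import numpy as np

def leaky_relu_grad(x, dy, alpha=0.02):
    # dL/dx = dL/dy * d(max(x, alpha * x))/dx; the local derivative is
    # alpha where x < 0 and 1 where x >= 0, mirroring temp1 and temp2
    # in LeakyReluGradFunctor.
    return dy * (alpha * (x < 0) + (x >= 0))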

template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  float factor;
@@ -379,4 +406,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor)
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor)
17 changes: 17 additions & 0 deletions python/paddle/v2/framework/tests/test_activation_op.py
@@ -122,6 +122,23 @@ def test_check_grad(self):
        self.check_grad(['X'], 'Y', max_relative_error=0.02)


class TestLeakyRelu(OpTest):
    def setUp(self):
        self.op_type = "leaky_relu"
        alpha = 0.02
        self.attrs = {'alpha': alpha}
        self.inputs = {'X': np.random.uniform(-3, 3, [4, 4]).astype("float32")}
        self.outputs = {
            'Y': np.maximum(self.inputs['X'], alpha * self.inputs['X'])
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y', max_relative_error=0.007)
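check_grad compares the operator's analytic gradient against a numerical estimate inside the OpTest harness. For intuition only, a standalone central-difference check of the same gradient could look roughly like this (a sketch, not the OpTest implementation, and assuming samples stay away from the kink at 0):

import numpy as np

def numeric_grad(f, x, eps=1e-4):
    # Central-difference estimate of d(sum f(x))/dx, one element at a time.
    grad = np.zeros_like(x)
    for i in np.ndindex(x.shape):
        orig = x[i]
        x[i] = orig + eps
        plus = f(x).sum()
        x[i] = orig - eps
        minus = f(x).sum()
        x[i] = orig
        grad[i] = (plus - minus) / (2 * eps)
    return grad

alpha = 0.02
x = np.random.uniform(-3, 3, [4, 4]).astype("float64")
analytic = alpha * (x < 0) + (x >= 0)  # derivative of max(x, alpha * x)
numeric = numeric_grad(lambda v: np.maximum(v, alpha * v), x)
assert np.allclose(analytic, numeric, atol=1e-3)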


class TestSoftRelu(OpTest):
    def setUp(self):
        self.op_type = "soft_relu"
