
Commit

add prelu, test=develop
qili93 committed Aug 19, 2020
1 parent fef0dae commit d426d38
Showing 5 changed files with 289 additions and 56 deletions.
133 changes: 122 additions & 11 deletions python/paddle/fluid/tests/unittests/test_prelu_op.py
@@ -18,23 +18,134 @@
import numpy as np
import paddle.fluid as fluid
import six
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.nn.functional as F


def ref_prelu(x, weight):
x_t = x.copy()
weight = weight.reshape(1, -1, 1, 1)
neg_indices = x <= 0
assert x.shape == neg_indices.shape
x_t[neg_indices] = (x_t * weight)[neg_indices]
return (x_t, )


def ref_prelu_nn(x, num_parameters, init):
weight_np = np.full((num_parameters), init)
return ref_prelu(x, weight_np)
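
def _ref_prelu_sanity_sketch():
    # Hedged editorial sketch; the helper name is illustrative and not part of
    # the commit. It checks that the reference helpers above implement
    # prelu(x) = max(0, x) + weight * min(0, x): on an all-negative input with
    # a per-channel weight of 0.25, every output entry is 0.25 * x.
    sample = np.full((1, 2, 3, 4), -1.0).astype('float32')
    ref_out = ref_prelu_nn(sample, 2, 0.25)[0]
    assert np.allclose(ref_out, -0.25 * np.ones_like(sample))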


class TestPReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
class TestFunctionalPReluAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
self.x_np = np.random.uniform(-1., 1., [1, 2, 3, 4]).astype('float32')
self.weight_np_0 = np.random.randn(1).astype('float32')
self.weight_np_1 = np.random.randn(self.x_np.shape[1]).astype('float32')

def static_check(self, weight_np):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_np.shape, 'float32')
weight = paddle.data('Alpha', weight_np.shape, 'float32')
out = F.prelu(x, weight)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np,
'Alpha': weight_np},
fetch_list=[out])
out_ref = ref_prelu(self.x_np, weight_np)
self.assertEqual(np.allclose(out_ref, res[0]), True)

def dygraph_check(self, weight_np):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
weight = paddle.to_tensor(weight_np)
out = F.prelu(x, weight)
out_ref = ref_prelu(self.x_np, weight_np)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)
paddle.enable_static()

def test_static_api(self):
self.static_check(self.weight_np_0)
self.static_check(self.weight_np_1)

def test_dygraph_api(self):
self.dygraph_check(self.weight_np_0)
self.dygraph_check(self.weight_np_1)

def test_error(self):
with paddle.static.program_guard(paddle.static.Program()):
weight_fp32 = paddle.data(
name='weight_fp32', shape=[1], dtype='float32')
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.prelu, 0.1, 'all')
self.assertRaises(TypeError, F.prelu, x=1, weight=weight_fp32)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.prelu, x_int32, 'all')
# the float32 input dtype is supported
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float32')
fluid.layers.prelu(x_fp16, 'all')
x_int32 = paddle.data(name='x_int32', shape=[2, 3], dtype='int32')
self.assertRaises(TypeError, F.prelu, x=x_int32, weight=weight_fp32)
# the float16 input dtype is supported
x_fp16 = paddle.data(name='x_fp16', shape=[2, 3], dtype='float16')
F.prelu(x=x_fp16, weight=weight_fp32)


class TestNNPReluAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
self.x_np = np.ones([1, 2, 3, 4]).astype('float32')

def test_static_api(self):
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
x = paddle.data(name='X', shape=self.x_np.shape, dtype='float32')
m = paddle.nn.PReLU()
out = m(x)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
res = exe.run(train_program,
feed={'X': self.x_np},
fetch_list=[out])
out_ref = ref_prelu_nn(self.x_np, 1, 0.25)
self.assertEqual(np.allclose(out_ref, res[0]), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU()
out = m(x)
out_ref = ref_prelu_nn(self.x_np, 1, 0.25)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU(num_parameters=self.x_np.shape[1])
out = m(x)
out_ref = ref_prelu_nn(self.x_np, self.x_np.shape[1], 0.25)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU(init=0.5)
out = m(x)
out_ref = ref_prelu_nn(self.x_np, 1, 0.5)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU(weight_attr=fluid.ParamAttr(name="weight"))
out = m(x)
out_ref = ref_prelu_nn(self.x_np, 1, 0.25)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU(weight_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.5)))
out = m(x)
out_ref = ref_prelu_nn(self.x_np, 1, 0.5)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

paddle.enable_static()
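
def _nn_vs_functional_prelu_sketch():
    # Hedged editorial sketch; the helper name is illustrative and not part of
    # the commit. The cases above suggest that paddle.nn.PReLU(num_parameters=C,
    # init=a) matches F.prelu called with an explicit per-channel weight tensor
    # filled with a.
    paddle.disable_static()
    x = paddle.to_tensor(np.ones([1, 2, 3, 4]).astype('float32'))
    layer_out = paddle.nn.PReLU(num_parameters=2, init=0.25)(x)
    weight = paddle.to_tensor(np.full((2, ), 0.25).astype('float32'))
    func_out = F.prelu(x, weight)
    assert np.allclose(layer_out.numpy(), func_out.numpy())
    paddle.enable_static()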


class PReluTest(OpTest):
2 changes: 1 addition & 1 deletion python/paddle/nn/__init__.py
@@ -55,7 +55,7 @@
from .layer.activation import GELU
from .layer.activation import Hardshrink
from .layer.activation import HardTanh
# from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import PReLU
from .layer.activation import ReLU
from .layer.activation import LeakyReLU #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/__init__.py
@@ -37,7 +37,7 @@
from .activation import leaky_relu #DEFINE_ALIAS
from .activation import logsigmoid #DEFINE_ALIAS
from .activation import maxout #DEFINE_ALIAS
# from .activation import prelu #DEFINE_ALIAS
from .activation import prelu #DEFINE_ALIAS
from .activation import relu #DEFINE_ALIAS
from .activation import relu6 #DEFINE_ALIAS
from .activation import selu #DEFINE_ALIAS
95 changes: 76 additions & 19 deletions python/paddle/nn/functional/activation.py
Expand Up @@ -43,7 +43,7 @@
'leaky_relu',
'logsigmoid',
'maxout',
# 'prelu',
'prelu',
'relu',
'relu6',
'selu',
@@ -56,7 +56,7 @@
'swish',
'tanh_shrink',
'thresholded_relu',
'log_softmax'
'log_softmax',
]

import warnings
@@ -73,7 +73,7 @@ def elu(x, alpha=1.0, name=None):
.. math::
elu(x) = max(0, x) + min(0, \\alpha * (e^{x}-1))
elu(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
@@ -119,10 +119,10 @@ def gelu(x, approximate=False, name=None):
if approximate is True
.. math::
gelu(x) = 0.5 * x * (1 + tanh(\\sqrt{\\frac{2}{\\pi}} * (x + 0.044715x^{3})))
gelu(x) = 0.5 * x * (1 + tanh(\sqrt{\frac{2}{\pi}} * (x + 0.044715x^{3})))
else
.. math::
gelu(x) = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}}))
gelu(x) = 0.5 * x * (1 + erf(\frac{x}{\sqrt{2}}))
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
@@ -142,17 +142,9 @@ def gelu(x, approximate=False, name=None):
paddle.disable_static()
data = np.random.randn(2, 3).astype("float32")
x = paddle.to_tensor(data)
out = F.gelu(x)
data
# array([[ 0.87165993, -1.0541513 , -0.37214822],
# [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32)
out
# array([[ 0.70456535, -0.15380788, -0.13207214],
# [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32)
x = paddle.to_tensor(np.array([[-1, 0.5],[1, 1.5]]))
out1 = F.gelu(x) # [-0.158655 0.345731 0.841345 1.39979]
out2 = F.gelu(x, True) # [-0.158808 0.345714 0.841192 1.39957]
"""

if in_dygraph_mode():
@@ -203,7 +195,7 @@ def hardshrink(x, threshold=0.5, name=None):
paddle.disable_static()
x = paddle.to_variable(np.array([-1, 0.3, 2.5]))
x = paddle.to_tensor(np.array([-1, 0.3, 2.5]))
out = F.hardshrink(x) # [-1., 0., 2.5]
"""
@@ -401,6 +393,71 @@ def hsigmoid(input,
return out


def prelu(x, weight, name=None):
"""
prelu activation.
.. math::
prelu(x) = max(0, x) + weight * min(0, x)
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
weight (Tensor): The learnable parameter with the same data type as ``x``.
The weight shape is [1] or [in], where `in` is the number of input channels of ``x``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([[-1,6],[1,15.6]]))
weight = paddle.to_tensor(np.array([0.2]))
out = F.prelu(x, weight)
# [[-0.2  6.  ]
#  [ 1.  15.6 ]]
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu')
check_variable_and_dtype(weight, 'weight',
['float16', 'float32', 'float64'], 'prelu')

helper = LayerHelper('prelu', **locals())
assert len(weight.shape
) == 1, "The dim count of weight shape should be 1 in prelu()."

# NOTE: The input of this API should be in ``N, C, ...`` format,
# which means x.shape[0] is the batch size and x.shape[1] is the channel.
mode = 'all'
if weight.shape[0] > 1:
assert len(
x.shape
) > 1, "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."
assert weight.shape[0] == x.shape[
1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
mode = 'channel'

if in_dygraph_mode():
return core.ops.prelu(x, weight, 'mode', mode)

out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type="prelu",
inputs={"X": x,
"Alpha": weight},
outputs={"Out": out},
attrs={"mode": mode})
return out
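
def _prelu_mode_sketch():
    # Hedged editorial sketch; the helper name is illustrative and not part of
    # this module. It shows how the weight shape selects the op mode above: a
    # [1]-shaped weight applies one slope to every element ('all'), while a
    # weight whose length equals x.shape[1] applies one slope per channel
    # ('channel').
    import numpy as np
    import paddle
    paddle.disable_static()
    x = paddle.to_tensor(np.random.randn(2, 3, 4, 5).astype('float32'))
    w_all = paddle.to_tensor(np.array([0.25]).astype('float32'))
    w_channel = paddle.to_tensor(np.full((3, ), 0.25).astype('float32'))
    return prelu(x, w_all), prelu(x, w_channel)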


def relu(x, name=None):
"""
ReLU Activation.
@@ -507,7 +564,7 @@ def softmax(x, axis=-1, dtype=None, name=None):
.. math::
softmax[i, j] = \\frac{\exp(x[i, j])}{\sum_j(exp(x[i, j])}
softmax[i, j] = \frac{\exp(x[i, j])}{\sum_j \exp(x[i, j])}
Example:
@@ -650,7 +707,7 @@ def log_softmax(x, axis=-1, dtype=None, name=None):
.. math::
Out[i, j] = log(softmax(x))
= log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])})
= log(\frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])})
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
