Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

elu gelu relu logsigmoid, test=develop #26304

Merged
merged 9 commits into from
Aug 19, 2020
5 changes: 1 addition & 4 deletions python/paddle/fluid/layers/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def thresholded_relu(x, threshold=None):
_gelu_ = generate_layer_fn('gelu')


@deprecated(since="2.0.0", update_to="paddle.nn.functional.gelu")
def gelu(x, approximate=False):
locals_var = locals().copy()
kwargs = dict()
Expand All @@ -655,10 +656,6 @@ def gelu(x, approximate=False):


gelu.__doc__ = """
:alias_main: paddle.nn.functional.gelu
:alias: paddle.nn.functional.gelu,paddle.nn.functional.activation.gelu
:old_api: paddle.fluid.layers.gelu

:strong:`GeLU Activation Operator`
For more details, see [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).

Expand Down
273 changes: 183 additions & 90 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def setUp(self):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = np.log(1 / (1 + np.exp(-x)))

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}

def test_check_grad(self):
Expand All @@ -127,6 +127,48 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.008)


class TestLogSigmoidAPI(unittest.TestCase):
# test paddle.nn.LogSigmoid, paddle.nn.functional.logsigmoid
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [11, 17])
out1 = F.logsigmoid(x)
m = paddle.nn.LogSigmoid()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = np.log(1 / (1 + np.exp(-self.x_np)))
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.logsigmoid(x)
m = paddle.nn.LogSigmoid()
out2 = m(x)
out_ref = np.log(1 / (1 + np.exp(-self.x_np)))
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.logsigmoid, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[11, 17], dtype='int32')
self.assertRaises(TypeError, F.logsigmoid, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[11, 17], dtype='float16')
F.logsigmoid(x_fp16)


class TestTanh(TestActivation, TestParameter):
def setUp(self):
self.op_type = "tanh"
Expand Down Expand Up @@ -644,7 +686,7 @@ def setUp(self):
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}

def test_check_grad(self):
Expand All @@ -653,18 +695,46 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestReluOpError(unittest.TestCase):
class TestReluAPI(unittest.TestCase):
# test paddle.nn.ReLU, paddle.nn.functional.relu
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12])
out1 = F.relu(x)
m = paddle.nn.ReLU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = np.maximum(self.x_np, 0)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.relu(x)
m = paddle.nn.ReLU()
out2 = m(x)
out_ref = np.maximum(self.x_np, 0)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with program_guard(Program()):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.relu, 1)
self.assertRaises(TypeError, F.relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.relu, x_int32)
x_int32 = paddle.data(name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, F.relu, x_int32)
# support the input dtype is float16
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.relu(x_fp16)
x_fp16 = paddle.data(name='x_fp16', shape=[10, 12], dtype='float16')
F.relu(x_fp16)


class TestLeakyRelu(TestActivation):
Expand Down Expand Up @@ -717,7 +787,7 @@ def setUp(self):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = gelu(x, approximate)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {"approximate": approximate}

Expand All @@ -735,7 +805,7 @@ def setUp(self):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = gelu(x, approximate)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {"approximate": approximate}

Expand All @@ -745,6 +815,55 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestGELUAPI(unittest.TestCase):
# test paddle.nn.GELU, paddle.nn.functional.gelu
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [11, 17])
out1 = F.gelu(x)
m = paddle.nn.GELU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = gelu(self.x_np, False)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.gelu(x)
m = paddle.nn.GELU()
out2 = m(x)
out_ref = gelu(self.x_np, False)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)

out1 = F.gelu(x, True)
m = paddle.nn.GELU(True)
out2 = m(x)
out_ref = gelu(self.x_np, True)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.gelu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[11, 17], dtype='int32')
self.assertRaises(TypeError, F.gelu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[11, 17], dtype='float16')
F.gelu(x_fp16)


class TestBRelu(TestActivation):
def setUp(self):
self.op_type = "brelu"
Expand Down Expand Up @@ -894,14 +1013,19 @@ def test_errors(self):
fluid.layers.soft_relu(x_fp16)


def elu(x, alpha):
out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
return out_ref.astype(x.dtype)


class TestELU(TestActivation):
def setUp(self):
self.op_type = "elu"
self.init_dtype()

x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype)
alpha = 1.
out = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
out = elu(x, alpha)
# Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1)
# is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
self.inputs = {'X': x}
Expand All @@ -914,16 +1038,53 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestELUOpError(unittest.TestCase):
class TestELUAPI(unittest.TestCase):
# test paddle.nn.ELU, paddle.nn.functional.elu
def setUp(self):
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12])
out1 = F.elu(x)
m = paddle.nn.ELU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = elu(self.x_np, 1.0)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.elu(x)
m = paddle.nn.ELU()
out2 = m(x)
out_ref = elu(self.x_np, 1.0)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)

out1 = F.elu(x, 0.2)
m = paddle.nn.ELU(0.2)
out2 = m(x)
out_ref = elu(self.x_np, 0.2)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with program_guard(Program(), Program()):
# The input type of elu_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.elu, x1)
# The input dtype of elu_op must be float16 float32 or float64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.elu, x2)
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.elu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, F.elu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[10, 12], dtype='float16')
F.elu(x_fp16)


class TestReciprocal(TestActivation):
Expand Down Expand Up @@ -1422,73 +1583,5 @@ def test_check_grad(self):
create_test_act_fp16_class(TestSwish)
create_test_act_fp16_class(TestHardSwish)


class TestNNReluAPI(unittest.TestCase):
def setUp(self):
self.init_data()

def init_data(self):
self.x_shape = [10, 12]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.y = self.ref_forward(self.x)

def ref_forward(self, x):
return np.maximum(x, 0)

def ref_backward(self, y, dy):
y_t = y.copy()
y_t[y_t > 0] = 1
return y_t * dy

def check_api(self, place=fluid.CPUPlace()):
main_program = Program()
myrelu = nn.ReLU()
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
x.stop_gradient = False
y = myrelu(x)
fluid.backward.append_backward(fluid.layers.mean(y))
exe = fluid.Executor(place)
out = exe.run(main_program,
feed={'x': self.x},
fetch_list=[y, y.grad_name, x.grad_name])
self.assertTrue(np.allclose(out[0], self.y))
self.assertTrue(np.allclose(out[2], self.ref_backward(self.y, out[1])))

with fluid.dygraph.guard(place):
x = fluid.dygraph.to_variable(self.x)
y = myrelu(x)
self.assertTrue(np.allclose(y.numpy(), self.y))

def test_check_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_api(place)


class TestNNFunctionalReluAPI(unittest.TestCase):
def setUp(self):
self.init_data()

def init_data(self):
self.x_shape = [10, 12]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.y = self.ref_forward(self.x)

def ref_forward(self, x):
return np.maximum(x, 0)

def test_check_api(self):
main_program = Program()
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
y = F.relu(x)
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], self.y))


if __name__ == "__main__":
unittest.main()
5 changes: 4 additions & 1 deletion python/paddle/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,14 @@
from .decode import gather_tree #DEFINE_ALIAS
from .input import data #DEFINE_ALIAS
# from .input import Input #DEFINE_ALIAS
from .layer.activation import ELU
from .layer.activation import GELU
from .layer.activation import Hardshrink
# from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import ReLU #DEFINE_ALIAS
from .layer.activation import ReLU
from .layer.activation import LeakyReLU #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
from .layer.activation import LogSigmoid
# from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import LogSoftmax #DEFINE_ALIAS
from .layer.activation import HSigmoid #DEFINE_ALIAS
Expand Down
Loading