elu gelu relu logsigmoid, test=develop (#26304)
* logsigmoid and LogSigmoid, test=develop

* add elu gelu relu, test=develop

* update to_variable to to_tensor, test=develop

* address review comments, test=develop

* address review comments, test=develop

* change to_variable to to_tensor in test, test=develop
qili93 authored Aug 19, 2020
1 parent 0a461ac commit 61800f4
Showing 5 changed files with 503 additions and 151 deletions.
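
Taken together, the commit exposes ELU, GELU, ReLU, and LogSigmoid both as paddle.nn layers and as paddle.nn.functional functions, and rewrites the dygraph tests around paddle.to_tensor. The sketch below strings those calls together the way the new tests in test_activation_op.py do; it is a minimal usage illustration, not part of the patch, and the input shape and values are arbitrary.

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()  # dygraph mode, as the new tests use
x = paddle.to_tensor(np.random.uniform(-1, 1, [3, 4]).astype('float32'))

out = F.logsigmoid(x)             # functional form: log(1 / (1 + exp(-x)))
out = paddle.nn.LogSigmoid()(x)   # equivalent Layer form

out = F.elu(x, 0.2)               # alpha defaults to 1.0
out = paddle.nn.ELU(0.2)(x)

out = F.gelu(x, True)             # approximate=True selects the tanh approximation
out = paddle.nn.GELU(True)(x)

out = F.relu(x)
out = paddle.nn.ReLU()(x)
paddle.enable_static()
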
5 changes: 1 addition & 4 deletions python/paddle/fluid/layers/ops.py
@@ -645,6 +645,7 @@ def thresholded_relu(x, threshold=None):
_gelu_ = generate_layer_fn('gelu')


@deprecated(since="2.0.0", update_to="paddle.nn.functional.gelu")
def gelu(x, approximate=False):
locals_var = locals().copy()
kwargs = dict()
@@ -655,10 +656,6 @@ def gelu(x, approximate=False):


gelu.__doc__ = """
:alias_main: paddle.nn.functional.gelu
:alias: paddle.nn.functional.gelu,paddle.nn.functional.activation.gelu
:old_api: paddle.fluid.layers.gelu
:strong:`GeLU Activation Operator`
For more details, see [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
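The approximate attribute carried through gelu picks between the exact erf-based GELU and the tanh approximation from the Gaussian Error Linear Units paper cited in the docstring above. A minimal NumPy reference sketch of the two forms follows; the helper name gelu_ref is hypothetical, SciPy's erf is assumed to be available, and the test file's own gelu(x, approximate) reference helper may differ in detail.

import numpy as np
from scipy.special import erf  # assumption: SciPy provides the exact error function

def gelu_ref(x, approximate=False):
    # hypothetical reference helper, not the one used by the test file
    if approximate:
        # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + np.tanh(
            np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))))
    # exact form: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
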
273 changes: 183 additions & 90 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -118,7 +118,7 @@ def setUp(self):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = np.log(1 / (1 + np.exp(-x)))

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}

def test_check_grad(self):
@@ -127,6 +127,48 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.008)


class TestLogSigmoidAPI(unittest.TestCase):
# test paddle.nn.LogSigmoid, paddle.nn.functional.logsigmoid
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [11, 17])
out1 = F.logsigmoid(x)
m = paddle.nn.LogSigmoid()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = np.log(1 / (1 + np.exp(-self.x_np)))
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.logsigmoid(x)
m = paddle.nn.LogSigmoid()
out2 = m(x)
out_ref = np.log(1 / (1 + np.exp(-self.x_np)))
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.logsigmoid, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[11, 17], dtype='int32')
self.assertRaises(TypeError, F.logsigmoid, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[11, 17], dtype='float16')
F.logsigmoid(x_fp16)


class TestTanh(TestActivation, TestParameter):
def setUp(self):
self.op_type = "tanh"
@@ -644,7 +686,7 @@ def setUp(self):
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}

def test_check_grad(self):
@@ -653,18 +695,46 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestReluOpError(unittest.TestCase):
class TestReluAPI(unittest.TestCase):
# test paddle.nn.ReLU, paddle.nn.functional.relu
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12])
out1 = F.relu(x)
m = paddle.nn.ReLU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = np.maximum(self.x_np, 0)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.relu(x)
m = paddle.nn.ReLU()
out2 = m(x)
out_ref = np.maximum(self.x_np, 0)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with program_guard(Program()):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.relu, 1)
self.assertRaises(TypeError, F.relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.relu, x_int32)
x_int32 = paddle.data(name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, F.relu, x_int32)
# support the input dtype is float16
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.relu(x_fp16)
x_fp16 = paddle.data(name='x_fp16', shape=[10, 12], dtype='float16')
F.relu(x_fp16)


class TestLeakyRelu(TestActivation):
@@ -717,7 +787,7 @@ def setUp(self):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = gelu(x, approximate)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {"approximate": approximate}

@@ -735,7 +805,7 @@ def setUp(self):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = gelu(x, approximate)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {"approximate": approximate}

@@ -745,6 +815,55 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestGELUAPI(unittest.TestCase):
# test paddle.nn.GELU, paddle.nn.functional.gelu
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [11, 17])
out1 = F.gelu(x)
m = paddle.nn.GELU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = gelu(self.x_np, False)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.gelu(x)
m = paddle.nn.GELU()
out2 = m(x)
out_ref = gelu(self.x_np, False)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)

out1 = F.gelu(x, True)
m = paddle.nn.GELU(True)
out2 = m(x)
out_ref = gelu(self.x_np, True)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.gelu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[11, 17], dtype='int32')
self.assertRaises(TypeError, F.gelu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[11, 17], dtype='float16')
F.gelu(x_fp16)


class TestBRelu(TestActivation):
def setUp(self):
self.op_type = "brelu"
@@ -894,14 +1013,19 @@ def test_errors(self):
fluid.layers.soft_relu(x_fp16)


def elu(x, alpha):
out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
return out_ref.astype(x.dtype)


class TestELU(TestActivation):
def setUp(self):
self.op_type = "elu"
self.init_dtype()

x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype)
alpha = 1.
out = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
out = elu(x, alpha)
# Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1)
# is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
self.inputs = {'X': x}
@@ -914,16 +1038,53 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestELUOpError(unittest.TestCase):
class TestELUAPI(unittest.TestCase):
# test paddle.nn.ELU, paddle.nn.functional.elu
def setUp(self):
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12])
out1 = F.elu(x)
m = paddle.nn.ELU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = elu(self.x_np, 1.0)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.elu(x)
m = paddle.nn.ELU()
out2 = m(x)
out_ref = elu(self.x_np, 1.0)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)

out1 = F.elu(x, 0.2)
m = paddle.nn.ELU(0.2)
out2 = m(x)
out_ref = elu(self.x_np, 0.2)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with program_guard(Program(), Program()):
# The input type of elu_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.elu, x1)
# The input dtype of elu_op must be float16 float32 or float64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.elu, x2)
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.elu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, F.elu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[10, 12], dtype='float16')
F.elu(x_fp16)


class TestReciprocal(TestActivation):
@@ -1422,73 +1583,5 @@ def test_check_grad(self):
create_test_act_fp16_class(TestSwish)
create_test_act_fp16_class(TestHardSwish)


class TestNNReluAPI(unittest.TestCase):
def setUp(self):
self.init_data()

def init_data(self):
self.x_shape = [10, 12]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.y = self.ref_forward(self.x)

def ref_forward(self, x):
return np.maximum(x, 0)

def ref_backward(self, y, dy):
y_t = y.copy()
y_t[y_t > 0] = 1
return y_t * dy

def check_api(self, place=fluid.CPUPlace()):
main_program = Program()
myrelu = nn.ReLU()
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
x.stop_gradient = False
y = myrelu(x)
fluid.backward.append_backward(fluid.layers.mean(y))
exe = fluid.Executor(place)
out = exe.run(main_program,
feed={'x': self.x},
fetch_list=[y, y.grad_name, x.grad_name])
self.assertTrue(np.allclose(out[0], self.y))
self.assertTrue(np.allclose(out[2], self.ref_backward(self.y, out[1])))

with fluid.dygraph.guard(place):
x = fluid.dygraph.to_variable(self.x)
y = myrelu(x)
self.assertTrue(np.allclose(y.numpy(), self.y))

def test_check_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_api(place)


class TestNNFunctionalReluAPI(unittest.TestCase):
def setUp(self):
self.init_data()

def init_data(self):
self.x_shape = [10, 12]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.y = self.ref_forward(self.x)

def ref_forward(self, x):
return np.maximum(x, 0)

def test_check_api(self):
main_program = Program()
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
y = F.relu(x)
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], self.y))


if __name__ == "__main__":
unittest.main()
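
The test rewrite above replaces the 1.x fluid idioms with their 2.0-style counterparts, as the commit messages note: fluid.dygraph.to_variable becomes paddle.to_tensor, fluid.layers.relu becomes F.relu, and fluid.data/Program/Executor become their paddle.data/paddle.static equivalents. A rough mapping sketch, assuming the 2.0-beta API exercised by the new tests; the array and shapes are arbitrary and the old calls are shown only as comments.

import numpy as np
import paddle
import paddle.nn.functional as F

x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')

# old: with fluid.dygraph.guard(place): x = fluid.dygraph.to_variable(x_np)
paddle.disable_static()
x = paddle.to_tensor(x_np)

# old: y = fluid.layers.relu(x)
y = F.relu(x)

paddle.enable_static()

# old: with fluid.program_guard(Program()): x = fluid.data(name='x', shape=[10, 12])
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.data('X', [10, 12])
    out = F.relu(x)
    # old: exe = fluid.Executor(place)
    exe = paddle.static.Executor(paddle.CPUPlace())
    res = exe.run(feed={'X': x_np}, fetch_list=[out])
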
5 changes: 4 additions & 1 deletion python/paddle/nn/__init__.py
@@ -51,11 +51,14 @@
from .decode import gather_tree #DEFINE_ALIAS
from .input import data #DEFINE_ALIAS
# from .input import Input #DEFINE_ALIAS
from .layer.activation import ELU
from .layer.activation import GELU
from .layer.activation import Hardshrink
# from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import ReLU #DEFINE_ALIAS
from .layer.activation import ReLU
from .layer.activation import LeakyReLU #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
from .layer.activation import LogSigmoid
# from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import LogSoftmax #DEFINE_ALIAS
from .layer.activation import HSigmoid #DEFINE_ALIAS
(The diffs for the remaining two changed files are not rendered here.)
