diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index 6816a353ce5042..65bde5601128f8 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -307,7 +307,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
+                                                SoftplusGradKernel)
 
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad,
                                           ReluDoubleGradKernel)
@@ -320,8 +321,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
                                           SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
                                           RsqrtDoubleGradKernel)
-PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad,
-                                          SoftplusDoubleGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
+                                                        SoftplusDoubleGradKernel)
 
 PD_REGISTER_KERNEL(tanh_triple_grad,
                    CPU,
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 813a7ffc7ba422..a8169df1021d2b 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -201,7 +201,7 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
-PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
 
 PD_REGISTER_KERNEL(exp,
                    CPU,
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index c50194bfaf009a..b2c2d493c48ad3 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -799,6 +799,31 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct SoftplusGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float beta;
+  float threshold;
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    auto x_beta = static_cast<ComplexType<T>>(beta) * x;  // NOLINT
+    dx.device(d) =
+        (x_beta > static_cast<ComplexType<T>>(threshold))
+            .select(dout,
+                    dout / (static_cast<ComplexType<T>>(1) + (-x_beta).exp())
+                               .unaryExpr(Conj<T>()));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct SoftplusDoubleGradFunctor : public BaseActivationFunctor<T> {
   float beta;
@@ -3681,7 +3706,7 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
     MPType x = static_cast<MPType>(arg_x);
     MPType b = static_cast<MPType>(beta);
     MPType t = static_cast<MPType>(threshold);
-    MPType x_beta = x * beta;
+    MPType x_beta = x * static_cast<MPType>(beta);
     return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
   }
 };
@@ -3711,6 +3736,34 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaSoftplusGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  using MPType = typename phi::dtype::MPTypeTrait<ComplexType<T>>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float beta;
+  float threshold;
+
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
+  // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType x = static_cast<MPType>(arg_x);
+    MPType b = static_cast<MPType>(beta);
+    MPType t = static_cast<MPType>(threshold);
+    MPType x_beta = x * static_cast<MPType>(beta);
+    return x_beta > t
+               ? dout
+               : static_cast<ComplexType<T>>(dout / conj(one + exp(-x_beta)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index b65eaa5d7757d1..c67864bc13f573 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -381,9 +381,10 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
-                                   SoftplusDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
+                                                SoftplusGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
+                                                SoftplusDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index acfe4dd5a2941b..6eeba717ece0dd 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -250,7 +250,7 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
-PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
 
 PD_REGISTER_KERNEL(exp,
                    GPU,
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 7acafa290f7e0b..c74748793a4e9d 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -1277,7 +1277,7 @@ def softplus(x, beta=1, threshold=20, name=None):
             \end{cases}
 
     Parameters:
-        x (Tensor): The input Tensor with data type float32, float64.
+        x (Tensor): The input Tensor with data type float32, float64, complex64, complex128.
         beta (float, optional): The value of :math:`\beta` for softplus. Default is 1
         threshold (float, optional): The value of :math:`\varepsilon` for softplus. Default is 20
         name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
@@ -1302,7 +1302,17 @@ def softplus(x, beta=1, threshold=20, name=None):
         return _C_ops.softplus(x, beta, threshold)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'softplus'
+            x,
+            'x',
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
+            'softplus',
         )
         helper = LayerHelper('softplus', **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index f4f56d049524a3..12fd5ae8f09a5a 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -3877,6 +3877,11 @@ def setUp(self):
 
         np.random.seed(1024)
         x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
         out = ref_softplus(x, beta, threshold)
         self.inputs = {'X': x}
         self.attrs = {'beta': beta, "threshold": threshold}
@@ -3891,6 +3896,19 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
+class TestSoftplus_Complex64(TestSoftplus):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.06)
+
+
+class TestSoftplus_Complex128(TestSoftplus):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestSoftplus_ZeroDim(TestSoftplus):
     def init_shape(self):
         self.shape = []
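
Reviewer note (not part of the patch): the short NumPy sketch below is only a reading aid for the gradient formula the new complex functors implement. The attribute values, shape, and variable names are illustrative assumptions; the complex input construction loosely mirrors the new branch in TestSoftplus.setUp (uniform real and imaginary parts in [-1, 1]), for which the threshold branch is never taken.

import numpy as np

beta, threshold = 2.0, 15.0  # illustrative attribute values, not taken from the patch

# Complex input built the same way as in the updated setUp():
# real and imaginary parts drawn from uniform(-1, 1).
np.random.seed(1024)
shape = (12, 17)
x = (
    np.random.uniform(-1, 1, shape) + 1j * np.random.uniform(-1, 1, shape)
).astype(np.complex128)
dout = np.ones_like(x)

# Forward: softplus(x) = log(1 + exp(beta * x)) / beta.  For these small
# inputs |beta * x| stays far below the threshold, so the linear fall-back
# branch (softplus(x) == x) is omitted here.
out = np.log(1 + np.exp(beta * x)) / beta

# Backward, matching SoftplusGradFunctor<ComplexType<T>> and
# CudaSoftplusGradFunctor<ComplexType<T>>:
#     dx = dout / conj(1 + exp(-beta * x))
dx = dout / np.conj(1 + np.exp(-beta * x))

print(out.dtype, dx.dtype)  # complex128 complex128

The conjugation follows the same convention the other complex activation gradients in activation_functor.h already use: Conj<T>() as the unaryExpr on the Eigen path and conj(...) on the CUDA path.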