diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index d3cf1cbcb34c1..d32cd9d78c8cb 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -434,7 +434,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
+                                                HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 66480018a5273..4c562095dfc59 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -278,7 +278,7 @@ PD_REGISTER_KERNEL(log1p,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
 
-PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 6b77c31d38d4a..6071ec95a1878 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -2574,7 +2574,42 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct HardSwishGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float threshold;
+  float scale;
+  float offset;
+
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
+  }
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    auto offset_t = static_cast<ComplexType<T>>(offset);
+    auto threshold_t = static_cast<ComplexType<T>>(threshold);
+    auto one = static_cast<ComplexType<T>>(1);
+    auto zero = static_cast<ComplexType<T>>(0);
+    auto two = static_cast<ComplexType<T>>(2);
+    auto scale_t = static_cast<ComplexType<T>>(scale);
+    auto tmp1 = ((x + offset_t) < threshold_t)  // NOLINT
+                    .template cast<ComplexType<T>>();
+    auto tmp2 = ((x + offset_t) > zero).template cast<ComplexType<T>>();
+    // dx = 0, when x <= -offset
+    // dout, when x >= threshold - offset
+    // dout * (2 * x / scale + offset / scale), otherwise
+    // threshold = scale = 6, offset = 3 by default
+    dx.device(d) = dout * (tmp2 * (two * x + offset_t) / scale_t * tmp1 +
+                           one * (one - tmp1))
+                              .unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct SwishFunctor : public BaseActivationFunctor<T> {
   float beta;
@@ -4540,8 +4575,10 @@ struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
   // threshold = scale = 6, offset = 3 by default
   __device__ __forceinline__ T operator()(const T x) const {
     const MPType x_t = static_cast<MPType>(x);
-    const MPType temp_max = std::max(x_t + static_cast<MPType>(offset), zero);
-    const MPType temp_min = std::min(temp_max, static_cast<MPType>(threshold));
+    const MPType x_offset_t = x_t + static_cast<MPType>(offset);
+    const MPType temp_max = (x_offset_t >= zero) ? x_offset_t : zero;
+    const MPType threshold_t = static_cast<MPType>(threshold);
+    const MPType temp_min = (temp_max < threshold_t) ? temp_max : threshold_t;
     return static_cast<T>(temp_min * x_t / static_cast<MPType>(scale));
   }
 };
@@ -4581,6 +4618,43 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaHardSwishGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  const ComplexType<T> zero = static_cast<ComplexType<T>>(0.0f);
+  const ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
+  const ComplexType<T> two = static_cast<ComplexType<T>>(2.0f);
+  float threshold;
+  float scale;
+  float offset;
+
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
+  }
+
+  // dx = 0, when x <= -offset
+  // dout, when x >= threshold - offset
+  // dout * (2 * x / scale + offset / scale), otherwise
+  // threshold = scale = 6, offset = 3 by default
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> x) const {
+    const ComplexType<T> dout_t = static_cast<ComplexType<T>>(dout);
+    const ComplexType<T> x_t = static_cast<ComplexType<T>>(x);
+    const ComplexType<T> offset_t = static_cast<ComplexType<T>>(offset);
+    const ComplexType<T> scale_t = static_cast<ComplexType<T>>(scale);
+    const ComplexType<T> temp1 =
+        static_cast<ComplexType<T>>(x_t + offset_t > zero);
+    const ComplexType<T> temp2 = static_cast<ComplexType<T>>(
+        x_t + offset_t < static_cast<ComplexType<T>>(threshold));
+
+    return static_cast<ComplexType<T>>(
+        dout_t *
+        conj(temp1 * temp2 * (two * x_t + offset_t) / scale_t + one - temp2));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaCeilFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index d592dfad0a52d..f7ac36ccfe213 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -513,7 +513,8 @@ PD_REGISTER_KERNEL(log_double_grad,
                    double,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
+                                                HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 000428268bbb1..0c23d8c6989f1 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -296,7 +296,7 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index e02a47d7bf8dd..f8d97e9503306 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -422,7 +422,17 @@ def hardswish(x, name=None):
         return _C_ops.hardswish(x)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'hardswish'
+            x,
+            'x',
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
+            'hardswish',
         )
 
         threshold = 6.0
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 8b74a2900502b..b83e8c6b1316a 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -2734,7 +2734,13 @@ def setUp(self):
         self.public_python_api = paddle.nn.functional.hardswish
 
         np.random.seed(1024)
-        x = np.random.uniform(-6, 6, self.shape).astype(self.dtype)
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            x = (
+                np.random.uniform(-6, 6, self.shape)
+                + 1j * np.random.uniform(-6, 6, self.shape)
+            ).astype(self.dtype)
+        else:
+            x = np.random.uniform(-6, 6, self.shape).astype(self.dtype)
         threshold = 6.0
         scale = 6.0
         offset = 3.0
@@ -2758,12 +2764,18 @@ def test_check_grad(self):
         self.check_grad(
             ['X'],
             'Out',
-            check_prim=True,
+            check_prim=True
+            if self.dtype not in [np.complex64, np.complex128]
+            else False,
             only_check_prim=self.if_only_check_prim(),
         )
 
     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output(
+            check_prim=True
+            if self.dtype not in [np.complex64, np.complex128]
+            else False
+        )
 
 
 class TestHardSwish_ZeroDim(TestHardSwish):
@@ -2771,6 +2783,16 @@ def init_shape(self):
         self.shape = []
 
 
+class TestHardSwishComplex64(TestHardSwish):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestHardSwishComplex128(TestHardSwish):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestHardswishAPI(unittest.TestCase):
     # test paddle.nn.Hardswish, paddle.nn.functional.hardswish
     def setUp(self):
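
A minimal usage sketch of the Python-facing path this patch enables, assuming a Paddle build that includes these kernels. It only exercises the widened dtype check in paddle.nn.functional.hardswish and the complex forward kernels; the tensor shape [4] and variable names are illustrative, not part of the patch.

import numpy as np

import paddle
import paddle.nn.functional as F

# Complex input generated the same way as in the new TestHardSwishComplex64
# case: independent uniform real and imaginary parts drawn from [-6, 6).
x_np = (
    np.random.uniform(-6, 6, [4]) + 1j * np.random.uniform(-6, 6, [4])
).astype(np.complex64)

x = paddle.to_tensor(x_np)
# For real inputs hardswish(x) = x * min(max(x + 3, 0), 6) / 6 with
# threshold = scale = 6 and offset = 3 by default; complex64/complex128
# inputs now pass the dtype check added in activation.py above.
y = F.hardswish(x)
print(y)

Gradient behavior for complex inputs is handled by the HardSwishGradFunctor<ComplexType<T>> and CudaHardSwishGradFunctor<ComplexType<T>> specializations above, which conjugate the local derivative, and is exercised by the new TestHardSwishComplex64 / TestHardSwishComplex128 cases rather than by this sketch.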