From bcd1b16c65c0c10aab3c94d423d72240436239c9 Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Fri, 22 Sep 2023 18:05:09 +0800
Subject: [PATCH 1/3] complex sqrt op

---
 .../phi/kernels/cpu/activation_grad_kernel.cc |  6 +-
 paddle/phi/kernels/cpu/activation_kernel.cc   |  2 +-
 paddle/phi/kernels/funcs/activation_functor.h | 72 +++++++++++++++++++
 .../phi/kernels/gpu/activation_grad_kernel.cu |  5 +-
 paddle/phi/kernels/gpu/activation_kernel.cu   |  2 +-
 test/legacy_test/test_activation_op.py        | 26 ++++++-
 6 files changed, 104 insertions(+), 9 deletions(-)

diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index d3cf1cbcb34c1..b23aebff6550a 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -305,7 +305,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
 
@@ -316,8 +316,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(tanh_double_grad,
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
                                           LeakyReluDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
-PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
-                                          SqrtDoubleGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(sqrt_double_grad,
+                                                       SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
                                           RsqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad,
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 66480018a5273..041a18b30947d 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -199,7 +199,7 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
-PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
 
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 6b77c31d38d4a..3aea976c5610f 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -695,6 +695,24 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct SqrtGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
+    dx.device(d) =
+        dout * (static_cast<ComplexType<T>>(0.5) / out).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 // rsqrt(x) = x^(-1/2)
 template <typename T>
 struct RsqrtFunctor : public BaseActivationFunctor<T> {
@@ -2728,6 +2746,44 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct SqrtGradGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* Out,
+                  const DenseTensor* dX,
+                  const DenseTensor* ddX,
+                  DenseTensor* dOut,
+                  DenseTensor* ddOut) const {
+    auto* d = dev.eigen_device();
+    auto ddx = EigenVector<ComplexType<T>>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
+    auto out = EigenVector<ComplexType<T>>::Flatten(
+        GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
+    // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
+    // calculate dy first, so ddy can inplace ddx
+    if (dOut) {
+      auto dx = EigenVector<ComplexType<T>>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
+      auto dout = EigenVector<ComplexType<T>>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
+      dout.device(*d) =
+          dx * ddx *
+          (static_cast<ComplexType<T>>(-1) / out).unaryExpr(Conj<T>());
+    }
+    if (ddOut) {
+      auto ddout = EigenVector<ComplexType<T>>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
+      ddout.device(*d) =
+          ddx * (static_cast<ComplexType<T>>(0.5) / out).unaryExpr(Conj<T>());
+    }
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 template <typename T>
 struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device>
@@ -3659,6 +3715,22 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaSqrtGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> one_half = static_cast<ComplexType<T>>(0.5f);
+
+  // dx = dout * 0.5 / out
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> out) const {
+    return one_half * dout / out;
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 template <typename T>
 struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index d592dfad0a52d..66f26988adbb0 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -384,8 +384,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
                                    SoftplusDoubleGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_grad, SqrtGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_double_grad,
+                                                SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
 
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 000428268bbb1..b92aecff359f6 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -248,7 +248,7 @@ PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
-PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
 
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 8b74a2900502b..1ea601dd6f789 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -1391,7 +1391,13 @@ def setUp(self):
         self.if_enable_cinn()
 
         np.random.seed(1023)
-        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            x = (
+                np.random.uniform(0.1, 1, self.shape)
+                + 1j * np.random.uniform(0.1, 1, self.shape)
+            ).astype(self.dtype)
+        else:
+            x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
         out = np.sqrt(x)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -1404,12 +1410,28 @@ def if_enable_cinn(self):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(
+            ['X'],
+            'Out',
+            check_prim=True
+            if self.dtype not in [np.complex64, np.complex128]
+            else False,
+        )
 
     def test_check_output(self):
         self.check_output()
 
 
+class TestSqrtComplex64(TestSqrt):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestSqrtComplex128(TestSqrt):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestSqrtPrimFp32(TestActivation):
     def setUp(self):
         self.op_type = "sqrt"

From 07d71827dc468d964b57631e7b22daa12adf8e55 Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Thu, 28 Sep 2023 14:15:36 +0800
Subject: [PATCH 2/3] fix

---
 .../phi/kernels/cpu/activation_grad_kernel.cc |  4 +-
 paddle/phi/kernels/funcs/activation_functor.h | 41 +------------------
 .../phi/kernels/gpu/activation_grad_kernel.cu |  3 +-
 3 files changed, 4 insertions(+), 44 deletions(-)

diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index b23aebff6550a..de604336a33f8 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -316,8 +316,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(tanh_double_grad,
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
                                           LeakyReluDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
-PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(sqrt_double_grad,
-                                                       SqrtDoubleGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
+                                          SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
                                           RsqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad,
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 3aea976c5610f..975efc7304cc2 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -2628,7 +2628,6 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
-// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
 template <typename T>
 struct PowFunctor : public BaseActivationFunctor<T> {
   float factor;
@@ -2746,44 +2745,6 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
-template <typename T>
-struct SqrtGradGradFunctor<ComplexType<T>>
-    : public BaseActivationFunctor<ComplexType<T>> {
-  template <typename Device>
-  void operator()(const Device& dev,
-                  const DenseTensor* Out,
-                  const DenseTensor* dX,
-                  const DenseTensor* ddX,
-                  DenseTensor* dOut,
-                  DenseTensor* ddOut) const {
-    auto* d = dev.eigen_device();
-    auto ddx = EigenVector<ComplexType<T>>::Flatten(
-        GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
-    auto out = EigenVector<ComplexType<T>>::Flatten(
-        GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
-    // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
-    // calculate dy first, so ddy can inplace ddx
-    if (dOut) {
-      auto dx = EigenVector<ComplexType<T>>::Flatten(
-          GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
-      auto dout = EigenVector<ComplexType<T>>::Flatten(
-          GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
-      dout.device(*d) =
-          dx * ddx *
-          (static_cast<ComplexType<T>>(-1) / out).unaryExpr(Conj<T>());
-    }
-    if (ddOut) {
-      auto ddout = EigenVector<ComplexType<T>>::Flatten(
-          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
-      ddout.device(*d) =
-          ddx * (static_cast<ComplexType<T>>(0.5) / out).unaryExpr(Conj<T>());
-    }
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() {
-    return ActBwdOpFwdDeps::kDepOut;
-  }
-};
-
 template <typename T>
 struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device>
@@ -3723,7 +3684,7 @@ struct CudaSqrtGradFunctor<ComplexType<T>>
   // dx = dout * 0.5 / out
   __device__ __forceinline__ ComplexType<T> operator()(
       const ComplexType<T> dout, const ComplexType<T> out) const {
-    return one_half * dout / out;
+    return dout * conj(one_half / out);
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() {
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 66f26988adbb0..0706411ab5d6b 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -385,8 +385,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
                                    SoftplusDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_grad, SqrtGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_double_grad,
-                                                SqrtDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
 
From 2e9afecde6f78230b2dff99b5c92c93992bd888e Mon Sep 17 00:00:00 2001
From: yangguohao <1901212980@pku.edu.cn>
Date: Wed, 8 Nov 2023 14:44:34 +0800
Subject: [PATCH 3/3] fix codestyle

---
 python/paddle/tensor/ops.py            |  9 +++++++-
 test/legacy_test/op_test.py            |  4 +++-
 test/legacy_test/test_activation_op.py | 29 +++++++++++++++-----------
 3 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py
index e3fec0ee5fc84..140f53f038c72 100644
--- a/python/paddle/tensor/ops.py
+++ b/python/paddle/tensor/ops.py
@@ -1063,7 +1063,14 @@ def sqrt(x, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['float16', 'uint16', 'float32', 'float64'],
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
             'sqrt',
         )
         helper = LayerHelper('sqrt', **locals())
diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py
index b181e549eed7c..c89851de13b38 100644
--- a/test/legacy_test/op_test.py
+++ b/test/legacy_test/op_test.py
@@ -1732,7 +1732,9 @@ def _dfs_grad_op(op_desc, fwd_op_desc=None):
             has_infer_inplace = base.core.has_infer_inplace(op_desc.type())
             has_grad_op_maker = base.core.has_grad_op_maker(op_desc.type())
             has_infer_inplace_in_grad_descendants = False
-            if not has_grad_op_maker:
+            # the OP test doesn't support higher order grad
+            is_grad_op_desc = op_desc.type().endswith('_grad')
+            if not has_grad_op_maker or is_grad_op_desc:
                 has_infer_inplace_in_descendants = False
             else:
                 # get grad_op_desc
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index b1d77a535ceb8..9de59e4f3c4ef 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -1523,21 +1523,26 @@ def test_check_grad(self):
         if self.dtype == np.float16:
             return
         if self.dtype not in [np.complex64, np.complex128]:
-            self.check_grad(
-                ['X'],
-                'Out',
-                check_prim=True,
-                check_pir=True,
-                check_prim_pir=True,
-            )
+            self.check_grad(
+                ['X'],
+                'Out',
+                check_prim=True,
+                check_pir=True,
+                check_prim_pir=True,
+            )
         else:
-            self.check_grad(
-                ['X'],
-                'Out',
-            )
+            self.check_grad(
+                ['X'],
+                'Out',
+            )
 
     def test_check_output(self):
-        self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)
+        if self.dtype not in [np.complex64, np.complex128]:
+            self.check_output(
+                check_prim=True, check_pir=True, check_prim_pir=True
+            )
+        else:
+            self.check_output()
 
 
 class TestSqrtComplex64(TestSqrt):