From 2e5c01361ce8550742723a18a9676c25cece08fd Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 23 Aug 2023 00:38:13 +0800 Subject: [PATCH 1/2] add complex dtype register for matrix_power --- paddle/phi/kernels/cpu/matrix_power_grad_kernel.cc | 4 +++- paddle/phi/kernels/cpu/matrix_power_kernel.cc | 10 ++++++++-- paddle/phi/kernels/gpu/matrix_power_grad_kernel.cu | 4 +++- paddle/phi/kernels/gpu/matrix_power_kernel.cu | 10 ++++++++-- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/cpu/matrix_power_grad_kernel.cc b/paddle/phi/kernels/cpu/matrix_power_grad_kernel.cc index 0f60f8da71a8b..c6d44f6625263 100644 --- a/paddle/phi/kernels/cpu/matrix_power_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_power_grad_kernel.cc @@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(matrix_power_grad, ALL_LAYOUT, phi::MatrixPowerGradKernel, float, - double) {} + double, + phi::dtype::complex64, + phi::dtype::complex128) {} diff --git a/paddle/phi/kernels/cpu/matrix_power_kernel.cc b/paddle/phi/kernels/cpu/matrix_power_kernel.cc index 08ee7cbc865df..c0aabd77562fe 100644 --- a/paddle/phi/kernels/cpu/matrix_power_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_power_kernel.cc @@ -18,5 +18,11 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/matrix_power_kernel_impl.h" -PD_REGISTER_KERNEL( - matrix_power, CPU, ALL_LAYOUT, phi::MatrixPowerKernel, float, double) {} +PD_REGISTER_KERNEL(matrix_power, + CPU, + ALL_LAYOUT, + phi::MatrixPowerKernel, + float, + double, + phi::dtype::complex64, + phi::dtype::complex128) {} diff --git a/paddle/phi/kernels/gpu/matrix_power_grad_kernel.cu b/paddle/phi/kernels/gpu/matrix_power_grad_kernel.cu index 25a9de8f8bed4..e9f1e6bba8529 100644 --- a/paddle/phi/kernels/gpu/matrix_power_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/matrix_power_grad_kernel.cu @@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(matrix_power_grad, ALL_LAYOUT, phi::MatrixPowerGradKernel, float, - double) {} + double, + phi::dtype::complex64, + phi::dtype::complex128) {} diff --git a/paddle/phi/kernels/gpu/matrix_power_kernel.cu b/paddle/phi/kernels/gpu/matrix_power_kernel.cu index d7ae7d8a3f745..0c4de7ce2df56 100644 --- a/paddle/phi/kernels/gpu/matrix_power_kernel.cu +++ b/paddle/phi/kernels/gpu/matrix_power_kernel.cu @@ -18,5 +18,11 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/matrix_power_kernel_impl.h" -PD_REGISTER_KERNEL( - matrix_power, GPU, ALL_LAYOUT, phi::MatrixPowerKernel, float, double) {} +PD_REGISTER_KERNEL(matrix_power, + GPU, + ALL_LAYOUT, + phi::MatrixPowerKernel, + float, + double, + phi::dtype::complex64, + phi::dtype::complex128) {} From 176b2bf668888ac82a025ba8e65f44f0f09ec5e8 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 23 Aug 2023 00:48:38 +0800 Subject: [PATCH 2/2] add complex64/128 ut for matrix_power --- test/legacy_test/test_matrix_power_op.py | 76 ++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/test/legacy_test/test_matrix_power_op.py b/test/legacy_test/test_matrix_power_op.py index cc4be16fdfaf9..bd51d00c52a21 100644 --- a/test/legacy_test/test_matrix_power_op.py +++ b/test/legacy_test/test_matrix_power_op.py @@ -240,6 +240,82 @@ def config(self): self.n = -1 +class TestMatrixPowerOpCP64(TestMatrixPowerOp): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "complex64" + self.n = 2 + + def test_grad(self): + self.check_grad(["X"], "Out", max_relative_error=1e-5) + + +class TestMatrixPowerOpBatchedCP64(TestMatrixPowerOpCP64): + def config(self): + self.matrix_shape = [2, 8, 4, 4] + self.dtype = "complex64" + self.n = 2 + + +class TestMatrixPowerOpLarge1CP64(TestMatrixPowerOpCP64): + def config(self): + self.matrix_shape = [32, 32] + self.dtype = "complex64" + self.n = 2 + + +class TestMatrixPowerOpLarge2CP64(TestMatrixPowerOpCP64): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "complex64" + self.n = 32 + + +class TestMatrixPowerOpCP64Minus(TestMatrixPowerOpCP64): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "complex64" + self.n = -1 + + +class TestMatrixPowerOpCP128(TestMatrixPowerOp): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "complex128" + self.n = 2 + + def test_grad(self): + self.check_grad(["X"], "Out", max_relative_error=1e-5) + + +class TestMatrixPowerOpBatchedCP128(TestMatrixPowerOpCP128): + def config(self): + self.matrix_shape = [2, 8, 4, 4] + self.dtype = "complex128" + self.n = 2 + + +class TestMatrixPowerOpLarge1CP128(TestMatrixPowerOpCP128): + def config(self): + self.matrix_shape = [32, 32] + self.dtype = "complex128" + self.n = 2 + + +class TestMatrixPowerOpLarge2CP128(TestMatrixPowerOpCP128): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "complex128" + self.n = 32 + + +class TestMatrixPowerOpCP128Minus(TestMatrixPowerOpCP128): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "complex128" + self.n = -1 + + class TestMatrixPowerAPI(unittest.TestCase): def setUp(self): np.random.seed(123)