From 1be84c710961d2507464ab5dc05e4aef308513f6 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Wed, 29 Nov 2023 13:53:30 +0000
Subject: [PATCH 01/11] triangular_solve complex support

---
 .../cpu/triangular_solve_grad_kernel.cc       |  2 +-
 .../kernels/cpu/triangular_solve_kernel.cc    |  2 +-
 paddle/phi/kernels/funcs/matrix_reduce.cc     |  2 +
 paddle/phi/kernels/funcs/matrix_reduce.cu     |  2 +
 .../gpu/triangular_solve_grad_kernel.cu       |  2 +-
 .../kernels/gpu/triangular_solve_kernel.cu    |  2 +-
 python/paddle/tensor/linalg.py                |  8 +--
 test/legacy_test/test_triangular_solve_op.py  | 52 +++++++++++++++++--
 8 files changed, 60 insertions(+), 12 deletions(-)

diff --git a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
index 80b2015f7318a..d8b78f34565db 100644
--- a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
@@ -20,4 +20,4 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
                    ALL_LAYOUT,
                    phi::TriangularSolveGradKernel,
                    float,
-                   double) {}
+                   double, phi::dtype::complex<float>, phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
index 06c897b219984..a768dd5562c19 100644
--- a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
+++ b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
@@ -82,4 +82,4 @@ PD_REGISTER_KERNEL(triangular_solve,
                    ALL_LAYOUT,
                    phi::TriangularSolveKernel,
                    float,
-                   double) {}
+                   double, phi::dtype::complex<float>, phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cc b/paddle/phi/kernels/funcs/matrix_reduce.cc
index 34d84070497fc..db2588c50f650 100644
--- a/paddle/phi/kernels/funcs/matrix_reduce.cc
+++ b/paddle/phi/kernels/funcs/matrix_reduce.cc
@@ -54,6 +54,8 @@ class MatrixReduceSumFunctor {
 
 template class MatrixReduceSumFunctor<float, CPUContext>;
 template class MatrixReduceSumFunctor<double, CPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<float>, CPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<double>, CPUContext>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cu b/paddle/phi/kernels/funcs/matrix_reduce.cu
index 5c3ebd6bb0167..83f3aa0c154d9 100644
--- a/paddle/phi/kernels/funcs/matrix_reduce.cu
+++ b/paddle/phi/kernels/funcs/matrix_reduce.cu
@@ -52,6 +52,8 @@ class MatrixReduceSumFunctor {
 
 template class MatrixReduceSumFunctor<float, GPUContext>;
 template class MatrixReduceSumFunctor<double, GPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<float>, GPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<double>, GPUContext>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
index f7eaa48579794..397b39b7e37be 100644
--- a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
@@ -20,4 +20,4 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
                    ALL_LAYOUT,
                    phi::TriangularSolveGradKernel,
                    float,
-                   double) {}
+                   double, phi::dtype::complex<float>, phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
index 889c421eb0bb9..d8b8ca2a92f2c 100644
--- a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
+++ b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
@@ -128,4 +128,4 @@ PD_REGISTER_KERNEL(triangular_solve,
                    ALL_LAYOUT,
                    phi::TriangularSolveKernel,
                    float,
-                   double) {}
+                   double, phi::dtype::complex<float>, phi::dtype::complex<double>) {}
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 5088cea790fd2..c8ae80532fc44 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -3186,9 +3186,9 @@ def triangular_solve(
     Args:
         x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or
-            more batch dimensions. Its data type should be float32 or float64.
+            more batch dimensions. Its data type should be float32, float64, complex64, complex128.
         y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is
-            zero or more batch dimensions. Its data type should be float32 or float64.
+            zero or more batch dimensions. Its data type should be float32, float64, complex64, complex128.
         upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the
             lower-triangular system of equations. Default: True.
         transpose (bool, optional): whether `x` should be transposed before calculation. Default: False.
@@ -3227,10 +3227,10 @@ def triangular_solve(
     inputs = {"X": [x], "Y": [y]}
     helper = LayerHelper("triangular_solve", **locals())
     check_variable_and_dtype(
-        x, 'x', ['float32', 'float64'], 'triangular_solve'
+        x, 'x', ['float32', 'float64', 'complex64', 'complex128'], 'triangular_solve'
     )
     check_variable_and_dtype(
-        y, 'y', ['float32', 'float64'], 'triangular_solve'
+        y, 'y', ['float32', 'float64', 'complex64', 'complex128'], 'triangular_solve'
     )
 
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py
index f3624b5332817..3efe8deffca6f 100644
--- a/test/legacy_test/test_triangular_solve_op.py
+++ b/test/legacy_test/test_triangular_solve_op.py
@@ -51,10 +51,17 @@ def setUp(self):
         self.python_api = paddle.tensor.linalg.triangular_solve
         self.config()
 
-        self.inputs = {
-            'X': np.random.random(self.x_shape).astype(self.dtype),
-            'Y': np.random.random(self.y_shape).astype(self.dtype),
-        }
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            self.inputs = {
+                'X': (np.random.random(self.x_shape) + 1j * np.random.random(self.x_shape)).astype(self.dtype),
+                'Y': (np.random.random(self.y_shape) + 1j * np.random.random(self.y_shape)).astype(self.dtype),
+            }
+        else:
+            self.inputs = {
+                'X': np.random.random(self.x_shape).astype(self.dtype),
+                'Y': np.random.random(self.y_shape).astype(self.dtype),
+            }
+
         self.attrs = {
             'upper': self.upper,
             'transpose': self.transpose,
@@ -247,7 +254,44 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.matmul(np.linalg.inv(x), y)
 
+# 3D(broadcast) + 3D complex64
+class TestTriangularSolveOpCp64(TestTriangularSolveOp):
+    """
+    case complex64
+    """
+
+    def config(self):
+        self.x_shape = [1, 10, 10]
+        self.y_shape = [6, 10, 12]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = "complex64"
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+# 3D(broadcast) + 3D complex128
+class TestTriangularSolveCp128(TestTriangularSolveOp):
+    """
+    case complex128
+    """
+    def config(self):
+        self.x_shape = [1, 10, 10]
+        self.y_shape = [6, 10, 12]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = "complex128"
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
 class TestTriangularSolveAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(2021)

From ea76c11a2718cbca4471b3bbb780d12e75d9ec43 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Wed, 29 Nov 2023 14:04:35 +0000
Subject: [PATCH 02/11] triangular_solve complex support

---
 .../kernels/cpu/triangular_solve_grad_kernel.cc   |  4 +++-
 paddle/phi/kernels/cpu/triangular_solve_kernel.cc |  4 +++-
 .../kernels/gpu/triangular_solve_grad_kernel.cu   |  4 +++-
 paddle/phi/kernels/gpu/triangular_solve_kernel.cu |  4 +++-
 python/paddle/tensor/linalg.py                    | 10 ++++++++--
 test/legacy_test/test_triangular_solve_op.py      | 15 ++++++++++++---
 6 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
index d8b78f34565db..95e96b6d7918c 100644
--- a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
@@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
                    ALL_LAYOUT,
                    phi::TriangularSolveGradKernel,
                    float,
-                   double, phi::dtype::complex<float>, phi::dtype::complex<double>) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
index a768dd5562c19..23ae786bd4c68 100644
--- a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
+++ b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
@@ -82,4 +82,6 @@ PD_REGISTER_KERNEL(triangular_solve,
                    ALL_LAYOUT,
                    phi::TriangularSolveKernel,
                    float,
-                   double, phi::dtype::complex<float>, phi::dtype::complex<double>) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
index 397b39b7e37be..67861b282529b 100644
--- a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
@@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
                    ALL_LAYOUT,
                    phi::TriangularSolveGradKernel,
                    float,
-                   double, phi::dtype::complex<float>, phi::dtype::complex<double>) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
index d8b8ca2a92f2c..9b4b8f1a68d5e 100644
--- a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
+++ b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
@@ -128,4 +128,6 @@ PD_REGISTER_KERNEL(triangular_solve,
                    ALL_LAYOUT,
                    phi::TriangularSolveKernel,
                    float,
-                   double, phi::dtype::complex<float>, phi::dtype::complex<double>) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index ad16ec36f2d2f..faf1925f48ef3 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -3227,10 +3227,16 @@ def triangular_solve(
     inputs = {"X": [x], "Y": [y]}
     helper = LayerHelper("triangular_solve", **locals())
     check_variable_and_dtype(
-        x, 'x', ['float32', 'float64', 'complex64', 'complex128'], 'triangular_solve'
+        x,
+        'x',
+        ['float32', 'float64', 'complex64', 'complex128'],
+        'triangular_solve',
     )
     check_variable_and_dtype(
-        y, 'y', ['float32', 'float64', 'complex64', 'complex128'], 'triangular_solve'
+        y,
+        'y',
+        ['float32', 'float64', 'complex64', 'complex128'],
+        'triangular_solve',
     )
 
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py
index 3efe8deffca6f..03f252f1f44c3 100644
--- a/test/legacy_test/test_triangular_solve_op.py
+++ b/test/legacy_test/test_triangular_solve_op.py
@@ -53,8 +53,14 @@ def setUp(self):
 
         if self.dtype is np.complex64 or self.dtype is np.complex128:
             self.inputs = {
-                'X': (np.random.random(self.x_shape) + 1j * np.random.random(self.x_shape)).astype(self.dtype),
-                'Y': (np.random.random(self.y_shape) + 1j * np.random.random(self.y_shape)).astype(self.dtype),
+                'X': (
+                    np.random.random(self.x_shape)
+                    + 1j * np.random.random(self.x_shape)
+                ).astype(self.dtype),
+                'Y': (
+                    np.random.random(self.y_shape)
+                    + 1j * np.random.random(self.y_shape)
+                ).astype(self.dtype),
             }
         else:
             self.inputs = {
@@ -254,6 +260,7 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.matmul(np.linalg.inv(x), y)
 
+
 # 3D(broadcast) + 3D complex64
 class TestTriangularSolveOpCp64(TestTriangularSolveOp):
     """
@@ -273,6 +280,7 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+
 # 3D(broadcast) + 3D complex128
 class TestTriangularSolveCp128(TestTriangularSolveOp):
     """
@@ -291,7 +299,8 @@ def set_output(self):
         x = np.tril(self.inputs['X'])
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
-
+
+
 class TestTriangularSolveAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(2021)

From d6d3b55a7e24455cce62787c68b2a018431ffe24 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Thu, 30 Nov 2023 15:03:59 +0000
Subject: [PATCH 03/11] fix bug

---
 paddle/phi/kernels/funcs/blas/blas_impl.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.h b/paddle/phi/kernels/funcs/blas/blas_impl.h
index ffafe15b8fcf2..b4ee437011f66 100644
--- a/paddle/phi/kernels/funcs/blas/blas_impl.h
+++ b/paddle/phi/kernels/funcs/blas/blas_impl.h
@@ -877,7 +877,7 @@ struct CBlas<phi::dtype::complex<float>> {
                    const phi::dtype::complex<float> alpha,
                    const phi::dtype::complex<float> *A,
                    const int lda,
-                   phi::dtype::complex<double> *B,
+                   phi::dtype::complex<float> *B,
                    const int ldb) {
     cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
   }

From dab410901301389b8647f21d8c28756f32ef0806 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Mon, 4 Dec 2023 12:26:18 +0000
Subject: [PATCH 04/11] add complex dtype case

---
 test/legacy_test/test_triangular_solve_op.py | 218 ++++++++++++++++++-
 1 file changed, 214 insertions(+), 4 deletions(-)

diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py
index 03f252f1f44c3..769ac32799cfc 100644
--- a/test/legacy_test/test_triangular_solve_op.py
+++ b/test/legacy_test/test_triangular_solve_op.py
@@ -262,9 +262,9 @@ def set_output(self):
 
 
 # 3D(broadcast) + 3D complex64
-class TestTriangularSolveOpCp64(TestTriangularSolveOp):
+class TestTriangularSolveOpCp643b3(TestTriangularSolveOp):
     """
-    case complex64
+    case 10
     """
 
     def config(self):
@@ -273,13 +273,223 @@ def config(self):
         self.upper = False
         self.transpose = False
         self.unitriangular = False
-        self.dtype = "complex64"
+        self.dtype = np.complex64
 
     def set_output(self):
         x = np.tril(self.inputs['X'])
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+            max_relative_error=0.05,
+        )
+
+
+# 2D + 2D upper complex64
+class TestTriangularSolveOpCp6422Up(TestTriangularSolveOp):
+    """
+    case 11
+    """
+
+    def config(self):
+        self.x_shape = [12, 12]
+        self.y_shape = [12, 10]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+            max_relative_error=0.05,
+        )
+
+
+# 2D(broadcast) + 3D, test 'transpose' complex64
+class TestTriangularSolveOpCp6423T(TestTriangularSolveOp):
+    """
+    case 12
+    """
+
+    def config(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [3, 10, 8]
+        self.upper = False
+        self.transpose = True
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X']).transpose(1, 0)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+            max_relative_error=0.05,
+        )
+
+
+# 2D + 2D , test 'unitriangular' complex64
+class TestTriangularSolveOpCp6422Un(TestTriangularSolveOp):
+    """
+    case 13
+    """
+
+    def config(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [10, 10]
+        self.upper = True
+        self.transpose = False
+        self.unitriangular = True
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.triu(self.inputs['X'])
+        np.fill_diagonal(x, 1.0 + 0j)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+        )
+
+
+# 4D(broadcast) + 4D(broadcast) complex64
+class TestTriangularSolveOpCp644b4b(TestTriangularSolveOp):
+    """
+    case 14
+    """
+
+    def config(self):
+        self.x_shape = [1, 3, 10, 10]
+        self.y_shape = [2, 3, 10, 5]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+            max_relative_error=0.05,
+        )
+
+
+# 3D(broadcast) + 4D(broadcast), test 'upper' complex64
+class TestTriangularSolveOpCp643b4bUp(TestTriangularSolveOp):
+    """
+    case 15
+    """
+
+    def config(self):
+        self.x_shape = [2, 10, 10]
+        self.y_shape = [5, 1, 10, 2]
+        self.upper = True
+        self.transpose = True
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.triu(self.inputs['X']).transpose(0, 2, 1)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+            max_relative_error=0.05,
+        )
+
+
+# 3D(broadcast) + 5D complex64
+class TestTriangularSolveOpCp643b5(TestTriangularSolveOp):
+    """
+    case 16
+    """
+
+    def config(self):
+        self.x_shape = [12, 3, 3]
+        self.y_shape = [2, 3, 12, 3, 2]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+            max_relative_error=0.05,
+        )
+
+
+# 5D + 4D(broadcast) complex64
+class TestTriangularSolveOpCp6454b(TestTriangularSolveOp):
+    """
+    case 17
+    """
+
+    def config(self):
+        self.x_shape = [2, 4, 2, 3, 3]
+        self.y_shape = [4, 1, 3, 10]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.matmul(np.linalg.inv(x), y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+            max_relative_error=0.05,
+        )
+
 
 # 3D(broadcast) + 3D complex128
 class TestTriangularSolveCp128(TestTriangularSolveOp):
     """
@@ -293,7 +503,7 @@ def config(self):
         self.upper = False
         self.transpose = False
         self.unitriangular = False
-        self.dtype = "complex128"
+        self.dtype = np.complex128
 
     def set_output(self):
         x = np.tril(self.inputs['X'])
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)

From 7ad9317cf6a7fbaac67dc003e52e7a430211ea8d Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Wed, 6 Dec 2023 03:19:52 +0000
Subject: [PATCH 05/11] add complex128 case & fix max_relative_error

---
 test/legacy_test/test_triangular_solve_op.py | 216 ++++++++++++++++++-
 1 file changed, 207 insertions(+), 9 deletions(-)

diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py
index 769ac32799cfc..5c88e087e584b 100644
--- a/test/legacy_test/test_triangular_solve_op.py
+++ b/test/legacy_test/test_triangular_solve_op.py
@@ -286,7 +286,6 @@ def test_check_grad_normal(self):
             'Out',
             check_cinn=True,
             check_pir=True,
-            max_relative_error=0.05,
         )
 
 
@@ -315,7 +314,7 @@ def test_check_grad_normal(self):
             'Out',
             check_cinn=True,
             check_pir=True,
-            max_relative_error=0.05,
+            max_relative_error=0.02,
         )
 
 
@@ -344,7 +343,6 @@ def test_check_grad_normal(self):
             'Out',
             check_cinn=True,
            check_pir=True,
-            max_relative_error=0.05,
         )
 
 
@@ -400,7 +398,7 @@ def test_check_grad_normal(self):
             'Out',
             check_cinn=True,
             check_pir=True,
-            max_relative_error=0.05,
+            max_relative_error=0.008,
         )
 
 
@@ -429,7 +427,6 @@ def test_check_grad_normal(self):
             'Out',
             check_cinn=True,
             check_pir=True,
-            max_relative_error=0.05,
         )
 
 
@@ -458,7 +455,6 @@ def test_check_grad_normal(self):
             'Out',
             check_cinn=True,
             check_pir=True,
-            max_relative_error=0.05,
         )
 
 
@@ -487,14 +483,13 @@ def test_check_grad_normal(self):
             'Out',
             check_cinn=True,
             check_pir=True,
-            max_relative_error=0.05,
         )
 
 
 # 3D(broadcast) + 3D complex128
-class TestTriangularSolveCp128(TestTriangularSolveOp):
+class TestTriangularSolveOpCp1283b3(TestTriangularSolveOp):
     """
-    case complex128
+    case 18
     """
 
     def config(self):
@@ -510,6 +505,209 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+        )
+
+
+# 2D + 2D upper complex128
+class TestTriangularSolveOpCp12822Up(TestTriangularSolveOp):
+    """
+    case 19
+    """
+
+    def config(self):
+        self.x_shape = [12, 12]
+        self.y_shape = [12, 10]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+        )
+
+
+# 2D(broadcast) + 3D, test 'transpose' complex128
+class TestTriangularSolveOpCp12823T(TestTriangularSolveOp):
+    """
+    case 20
+    """
+
+    def config(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [3, 10, 8]
+        self.upper = False
+        self.transpose = True
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X']).transpose(1, 0)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+        )
+
+
+# 2D + 2D , test 'unitriangular' complex128
+class TestTriangularSolveOpCp12822Un(TestTriangularSolveOp):
+    """
+    case 21
+    """
+
+    def config(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [10, 10]
+        self.upper = True
+        self.transpose = False
+        self.unitriangular = True
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.triu(self.inputs['X'])
+        np.fill_diagonal(x, 1.0 + 0j)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+        )
+
+
+# 4D(broadcast) + 4D(broadcast) complex128
+class TestTriangularSolveOpCp1284b4b(TestTriangularSolveOp):
+    """
+    case 22
+    """
+
+    def config(self):
+        self.x_shape = [1, 3, 10, 10]
+        self.y_shape = [2, 3, 10, 5]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+        )
+
+
+# 3D(broadcast) + 4D(broadcast), test 'upper' complex128
+class TestTriangularSolveOpCp1283b4bUp(TestTriangularSolveOp):
+    """
+    case 23
+    """
+
+    def config(self):
+        self.x_shape = [2, 10, 10]
+        self.y_shape = [5, 1, 10, 2]
+        self.upper = True
+        self.transpose = True
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.triu(self.inputs['X']).transpose(0, 2, 1)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+        )
+
+
+# 3D(broadcast) + 5D complex128
+class TestTriangularSolveOpCp1283b5(TestTriangularSolveOp):
+    """
+    case 24
+    """
+
+    def config(self):
+        self.x_shape = [12, 3, 3]
+        self.y_shape = [2, 3, 12, 3, 2]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+        )
+
+
+# 5D + 4D(broadcast) complex128
+class TestTriangularSolveOpCp12854b(TestTriangularSolveOp):
+    """
+    case 25
+    """
+
+    def config(self):
+        self.x_shape = [2, 4, 2, 3, 3]
+        self.y_shape = [4, 1, 3, 10]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.matmul(np.linalg.inv(x), y)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_cinn=True,
+            check_pir=True,
+        )
+
 
 class TestTriangularSolveAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(2021)

From bda2de6b1b87fe89a90fcdf69cffc2f9da825845 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Wed, 6 Dec 2023 13:32:20 +0000
Subject: [PATCH 06/11] fix cinn bugs

---
 paddle/fluid/framework/paddle2cinn/transform_desc.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc.cc b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
index b61a7be544fd2..7a869021ec9da 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_desc.cc
+++ b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
@@ -92,6 +92,8 @@ ::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn(
     SET_DATA_TYPE_CASE_ITEM(FP16);
     SET_DATA_TYPE_CASE_ITEM(FP32);
     SET_DATA_TYPE_CASE_ITEM(FP64);
+    SET_DATA_TYPE_CASE_ITEM(COMPLEX64);
+    SET_DATA_TYPE_CASE_ITEM(COMPLEX128);
     default:
       PADDLE_THROW(platform::errors::NotFound("Cannot found var data type"));
   }

From c94bf35f7fd0956c413541d12291a1155d6c1cd5 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Thu, 7 Dec 2023 02:40:55 +0000
Subject: [PATCH 07/11] fix cinn bug

---
 paddle/fluid/framework/paddle2cinn/transform_desc.cc | 2 --
 1 file changed, 2 deletions(-)

diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc.cc b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
index 7a869021ec9da..b61a7be544fd2 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_desc.cc
+++ b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
@@ -92,8 +92,6 @@ ::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn(
     SET_DATA_TYPE_CASE_ITEM(FP16);
     SET_DATA_TYPE_CASE_ITEM(FP32);
     SET_DATA_TYPE_CASE_ITEM(FP64);
-    SET_DATA_TYPE_CASE_ITEM(COMPLEX64);
-    SET_DATA_TYPE_CASE_ITEM(COMPLEX128);
     default:
       PADDLE_THROW(platform::errors::NotFound("Cannot found var data type"));
   }

From dcc9c1b69db8ae006d44688738f79f1cbda80713 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Thu, 7 Dec 2023 15:36:50 +0000
Subject: [PATCH 08/11] fix bug

---
 paddle/fluid/framework/paddle2cinn/transform_desc.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc.cc b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
index b61a7be544fd2..7a869021ec9da 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_desc.cc
+++ b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
@@ -92,6 +92,8 @@ ::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn(
     SET_DATA_TYPE_CASE_ITEM(FP16);
     SET_DATA_TYPE_CASE_ITEM(FP32);
     SET_DATA_TYPE_CASE_ITEM(FP64);
+    SET_DATA_TYPE_CASE_ITEM(COMPLEX64);
+    SET_DATA_TYPE_CASE_ITEM(COMPLEX128);
     default:
       PADDLE_THROW(platform::errors::NotFound("Cannot found var data type"));
   }

From e28882c5ee9c561193bb17f6315ccb79c25787ac Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Thu, 21 Dec 2023 15:33:50 +0000
Subject: [PATCH 09/11] fix cinn

---
 paddle/fluid/framework/paddle2cinn/transform_desc.cc | 2 --
 1 file changed, 2 deletions(-)

diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc.cc b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
index 7a869021ec9da..b61a7be544fd2 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_desc.cc
+++ b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
@@ -92,8 +92,6 @@ ::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn(
     SET_DATA_TYPE_CASE_ITEM(FP16);
     SET_DATA_TYPE_CASE_ITEM(FP32);
     SET_DATA_TYPE_CASE_ITEM(FP64);
-    SET_DATA_TYPE_CASE_ITEM(COMPLEX64);
-    SET_DATA_TYPE_CASE_ITEM(COMPLEX128);
     default:
       PADDLE_THROW(platform::errors::NotFound("Cannot found var data type"));
   }

From bc14276dc9af8e7047baf4a18e1df76f2f326ca1 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Fri, 22 Dec 2023 04:01:04 +0000
Subject: [PATCH 10/11] remove check_cinn=True for complex dtype

---
 test/legacy_test/test_triangular_solve_op.py | 62 +++++++++++++++-----
 1 file changed, 48 insertions(+), 14 deletions(-)

diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py
index 5c88e087e584b..8a5ea1cd4fca1 100644
--- a/test/legacy_test/test_triangular_solve_op.py
+++ b/test/legacy_test/test_triangular_solve_op.py
@@ -280,11 +280,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
 
@@ -308,11 +310,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
             max_relative_error=0.02,
         )
@@ -337,11 +341,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -366,6 +372,9 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
@@ -392,11 +401,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
             max_relative_error=0.008,
         )
@@ -421,11 +432,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -449,11 +462,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -477,11 +492,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.matmul(np.linalg.inv(x), y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -505,11 +522,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -533,11 +552,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -561,11 +582,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -590,6 +613,9 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
@@ -616,11 +642,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -644,11 +672,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -672,11 +702,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.linalg.solve(x, y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )
@@ -700,11 +732,13 @@ def set_output(self):
         y = self.inputs['Y']
         self.output = np.matmul(np.linalg.inv(x), y)
 
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
     def test_check_grad_normal(self):
         self.check_grad(
             ['X', 'Y'],
             'Out',
-            check_cinn=True,
             check_pir=True,
         )

From bf834564ca944ccd96cd48983a1696e68c5ca8f4 Mon Sep 17 00:00:00 2001
From: zbt78 <1095497213@qq.com>
Date: Sat, 23 Dec 2023 05:54:07 +0000
Subject: [PATCH 11/11] fix

---
 test/legacy_test/test_triangular_solve_op.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py
index 8a5ea1cd4fca1..d4aecda8780ce 100644
--- a/test/legacy_test/test_triangular_solve_op.py
+++ b/test/legacy_test/test_triangular_solve_op.py
@@ -376,10 +376,7 @@ def test_check_output(self):
         self.check_output(check_pir=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(
-            ['X', 'Y'],
-            'Out',
-        )
+        self.check_grad(['X', 'Y'], 'Out')
 
 
 # 4D(broadcast) + 4D(broadcast) complex64
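
A minimal usage sketch of the complex triangular_solve support added by this series, assuming a Paddle build with these patches applied; the shapes, data, and tolerances are illustrative only, not part of the patches.

    import numpy as np
    import paddle

    # Well-conditioned lower-triangular complex64 system a @ out = b,
    # exercising the newly registered complex kernels.
    a = (
        np.tril(np.random.rand(3, 3) + 1j * np.random.rand(3, 3)) + 2 * np.eye(3)
    ).astype(np.complex64)
    b = (np.random.rand(3, 2) + 1j * np.random.rand(3, 2)).astype(np.complex64)

    out = paddle.linalg.triangular_solve(
        paddle.to_tensor(a), paddle.to_tensor(b), upper=False
    )

    # The solution should reproduce the right-hand side up to float32-level error.
    np.testing.assert_allclose(a @ out.numpy(), b, rtol=1e-4, atol=1e-4)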