From fbd7a89f12f1aa807ce0dae3ffa0f48e72fc3b83 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 5 Nov 2023 11:00:16 +0800 Subject: [PATCH 1/5] Add tests --- paconvert/api_mapping.json | 206 ++++++++++++++++++++++-- paconvert/api_matcher.py | 48 ++++++ tests/test_Tensor_cholesky_solve.py | 67 ++++++++ tests/test_Tensor_float_power.py | 53 ++++++ tests/test_Tensor_float_power_.py | 53 ++++++ tests/test_Tensor_symeig.py | 52 ++++++ tests/test_Tensor_triangular_solve.py | 79 +++++++++ tests/test_autograd_enable_grad.py | 46 ++++++ tests/test_autograd_set_grad_enabled.py | 57 +++++++ tests/test_linalg_eig.py | 75 +++++++++ tests/test_linalg_eigh.py | 31 ++++ tests/test_linalg_eigvalsh.py | 77 +++++++++ tests/test_linalg_matrix_norm.py | 51 +++--- tests/test_linalg_pinv.py | 77 +++++++++ tests/test_linalg_solve.py | 85 ++++++++++ tests/test_linalg_svd.py | 115 +++++++++++++ tests/test_linalg_vector_norm.py | 51 +++--- tests/test_pca_lowrank.py | 77 +++++++++ tests/test_symeig.py | 77 +++++++++ 19 files changed, 1324 insertions(+), 53 deletions(-) create mode 100644 tests/test_Tensor_cholesky_solve.py create mode 100644 tests/test_Tensor_float_power.py create mode 100644 tests/test_Tensor_float_power_.py create mode 100644 tests/test_Tensor_symeig.py create mode 100644 tests/test_Tensor_triangular_solve.py create mode 100644 tests/test_autograd_enable_grad.py create mode 100644 tests/test_autograd_set_grad_enabled.py create mode 100644 tests/test_linalg_eig.py create mode 100644 tests/test_linalg_eigh.py create mode 100644 tests/test_linalg_eigvalsh.py create mode 100644 tests/test_linalg_pinv.py create mode 100644 tests/test_linalg_solve.py create mode 100644 tests/test_linalg_svd.py create mode 100644 tests/test_pca_lowrank.py create mode 100644 tests/test_symeig.py diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 9e7fbb164..e7317b10b 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -593,7 +593,17 @@ "upper" ] }, - "torch.Tensor.cholesky_solve": {}, + "torch.Tensor.cholesky_solve": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.cholesky_solve", + "args_list": [ + "input2", + "upper" + ], + "kwargs_change": { + "input2": "y" + } + }, "torch.Tensor.chunk": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.chunk", @@ -1125,8 +1135,18 @@ "memory_format" ] }, - "torch.Tensor.float_power": {}, - "torch.Tensor.float_power_": {}, + "torch.Tensor.float_power": { + "Matcher": "FloatPowerMatcher", + "args_list": [ + "exponent" + ] + }, + "torch.Tensor.float_power_": { + "Matcher": "FloatPowerInplaceMatcher", + "args_list": [ + "exponent" + ] + }, "torch.Tensor.floor": { "Matcher": "UnchangeMatcher" }, @@ -2893,7 +2913,14 @@ "dim1" ] }, - "torch.Tensor.symeig": {}, + "torch.Tensor.symeig": { + "Matcher": "SymeigMatcher", + "paddle_api": "paddle.linalg.eigh", + "args_list": [ + "eigenvectors", + "upper" + ] + }, "torch.Tensor.t": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.t" @@ -2977,7 +3004,16 @@ ] }, "torch.Tensor.transpose_": {}, - "torch.Tensor.triangular_solve": {}, + "torch.Tensor.triangular_solve": { + "Matcher": "TensorTriangularSolveMatcher", + "paddle_api": "paddle.linalg.triangular_solve", + "args_list": [ + "A", + "upper", + "transpose", + "unitriangular" + ] + }, "torch.Tensor.tril": { "Matcher": "TensorFunc2PaddleFunc", "paddle_api": "paddle.tril", @@ -3730,6 +3766,10 @@ "inputs" ] }, + "torch.autograd.enable_grad": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.enable_grad" + }, 
"torch.autograd.function.FunctionCtx": { "Matcher": "GenericMatcher", "paddle_api": "paddle.autograd.PyLayerContext" @@ -3885,6 +3925,13 @@ "torch.autograd.profiler.profile.key_averages": {}, "torch.autograd.profiler.profile.table": {}, "torch.autograd.profiler.profile.total_average": {}, + "torch.autograd.set_grad_enabled": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.set_grad_enabled", + "args_list": [ + "mode" + ] + }, "torch.backends.cuda.is_built": { "Matcher": "GenericMatcher", "paddle_api": "paddle.device.is_compiled_with_cuda" @@ -6594,18 +6641,29 @@ } }, "torch.linalg.eig": { - "Matcher": "GenericMatcher", + "Matcher": "DoubleAssignMatcher", "paddle_api": "paddle.linalg.eig", "args_list": [ - "A", + "input", "*", "out" ], "kwargs_change": { - "A": "x" + "input": "x" + } + }, + "torch.linalg.eigh": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.eigh", + "args_list": [ + "input", + "UPLO", + "out" + ], + "kwargs_change": { + "input": "x" } }, - "torch.linalg.eigh": {}, "torch.linalg.eigvals": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.eigvals", @@ -6618,7 +6676,19 @@ "input": "x" } }, - "torch.linalg.eigvalsh": {}, + "torch.linalg.eigvalsh": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.eigvalsh", + "args_list": [ + "input", + "UPLO", + "*", + "out" + ], + "kwargs_change": { + "input": "x" + } + }, "torch.linalg.householder_product": {}, "torch.linalg.inv": { "Matcher": "GenericMatcher", @@ -6710,7 +6780,23 @@ } }, "torch.linalg.matrix_exp": {}, - "torch.linalg.matrix_norm": {}, + "torch.linalg.matrix_norm": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.norm", + "args_list": [ + "input", + "ord", + "dim", + "keepdim", + "dtype", + "out" + ], + "kwargs_change": { + "input": "x", + "ord": "p", + "dim": "axis" + } + }, "torch.linalg.matrix_power": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.matrix_power", @@ -6773,7 +6859,23 @@ "dim": "axis" } }, - "torch.linalg.pinv": {}, + "torch.linalg.pinv": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.pinv", + "args_list": [ + "input", + "atol", + "rtol", + "hermitian", + "out" + ], + "kwargs_change": { + "input": "x", + "atol": "", + "rtol": "rcond", + "hermitian": "hermitian" + } + }, "torch.linalg.qr": { "Matcher": "DoubleAssignMatcher", "paddle_api": "paddle.linalg.qr", @@ -6796,7 +6898,23 @@ "out" ] }, - "torch.linalg.solve": {}, + "torch.linalg.solve": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.solve", + "args_list": [ + "A", + "B", + "left", + "out" + ], + "kwargs_change": { + "A": "x", + "B": "y" + }, + "unsupport_args": [ + "left" + ] + }, "torch.linalg.solve_ex": {}, "torch.linalg.solve_triangular": { "Matcher": "LinalgSolveTriangularMatcher", @@ -6816,7 +6934,23 @@ "left": "transpose" } }, - "torch.linalg.svd": {}, + "torch.linalg.svd": { + "Matcher": "TripleAssignMatcher", + "paddle_api": "paddle.linalg.svd", + "args_list": [ + "A", + "full_matrices", + "driver", + "out" + ], + "kwargs_change": { + "A": "x", + "driver": "" + }, + "paddle_default_kwargs": { + "full_matrices": "True" + } + }, "torch.linalg.svdvals": { "Matcher": "LinalgSvdvalsMatcher", "paddle_api": "paddle.linalg.svd", @@ -6848,7 +6982,23 @@ } }, "torch.linalg.vecdot": {}, - "torch.linalg.vector_norm": {}, + "torch.linalg.vector_norm": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.norm", + "args_list": [ + "x", + "ord", + "dim", + "keepdim", + "dtype", + "out" + ], + "kwargs_change": { + "ord": "p", + "dim": "axis", 
+ "input": "x" + } + }, "torch.linspace": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linspace", @@ -11352,6 +11502,19 @@ "vec2": "y" } }, + "torch.pca_lowrank": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.pca_lowrank", + "args_list": [ + "A", + "q", + "center", + "niter" + ], + "kwargs_change": { + "A": "x" + } + }, "torch.permute": { "Matcher": "GenericMatcher", "paddle_api": "paddle.transpose", @@ -12577,6 +12740,19 @@ "dim1" ] }, + "torch.symeig": { + "Matcher": "SymeigMatcher", + "paddle_api": "paddle.linalg.eigh", + "args_list": [ + "input", + "eigenvectors", + "upper", + "out" + ], + "kwargs_change": { + "input": "x" + } + }, "torch.t": { "Matcher": "GenericMatcher", "paddle_api": "paddle.t", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index bbc2e136c..68838689d 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -2724,6 +2724,12 @@ def generate_code(self, kwargs): return code +class TensorTriangularSolveMatcher(BaseMatcher): + def generate_code(self, kwargs): + kwargs["b"] = self.paddleClass + return TriangularSolveMatcher.generate_code(self, kwargs) + + class IndexAddMatcher(BaseMatcher): def generate_code(self, kwargs): if "input" not in kwargs: @@ -4069,6 +4075,48 @@ def generate_code(self, kwargs): return code +class SymeigMatcher(BaseMatcher): + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + def _CONVERT_SYMEIG(**kwargs): + out_v = kwargs.pop("out", None) + upper = kwargs.pop("upper", True) + UPLO = "U" if upper else "L" + eigenvectors = kwargs.pop("eigenvectors", False) + if not eigenvectors: + result = (paddle.linalg.eigvalsh(kwargs["input"], UPLO=UPLO), + paddle.to_tensor([], dtype=paddle.complex64)) + else: + result = paddle.linalg.eigh(kwargs["input"], UPLO=UPLO) + if out_v: + result = paddle.assign(result[0], out_v[0]), paddle.assign(result[1], out_v[1]) + return result + """ + ) + return CODE_TEMPLATE + + def generate_code(self, kwargs): + self.write_aux_code() + if "input" not in kwargs: + kwargs["input"] = self.paddleClass + return "paddle_aux._CONVERT_SYMEIG({})".format(self.kwargs_to_str(kwargs)) + + +class FloatPowerMatcher(BaseMatcher): + def generate_code(self, kwargs): + return "{}.cast(paddle.float64).pow({})".format( + self.paddleClass, kwargs["exponent"] + ) + + +class FloatPowerInplaceMatcher(BaseMatcher): + def generate_code(self, kwargs): + return "{}.cast_(paddle.float64).pow_({})".format( + self.paddleClass, kwargs["exponent"] + ) + + class ModuleGetSubMatcher(BaseMatcher): def generate_code(self, kwargs): code = 'getattr({}, "{}")'.format( diff --git a/tests/test_Tensor_cholesky_solve.py b/tests/test_Tensor_cholesky_solve.py new file mode 100644 index 000000000..8c64d394b --- /dev/null +++ b/tests/test_Tensor_cholesky_solve.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
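# A rough sketch of the conversion the "torch.Tensor.cholesky_solve" mapping
# above is expected to produce: GenericMatcher keeps the method call and only
# renames the keyword `input2` to Paddle's `y` (per kwargs_change), so
#     result = b.cholesky_solve(input2=x, upper=True)   # PyTorch source
# should become, roughly,
#     result = b.cholesky_solve(y=x, upper=True)        # expected Paddle call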
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.cholesky_solve") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + b = torch.tensor([[-0.6355, 0.9891], + [ 0.1974, 1.4706], + [-0.4115, -0.6225]]) + result = b.cholesky_solve(x, False) + """ + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5, atol=1.0e-8) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + b = torch.tensor([[-0.6355, 0.9891], + [ 0.1974, 1.4706], + [-0.4115, -0.6225]]) + result = b.cholesky_solve(input2=x, upper=True) + """ + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5, atol=1.0e-8) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + b = torch.tensor([[-0.6355, 0.9891], + [ 0.1974, 1.4706], + [-0.4115, -0.6225]]) + result = b.cholesky_solve(upper=False, input2=x) + """ + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5, atol=1.0e-8) diff --git a/tests/test_Tensor_float_power.py b/tests/test_Tensor_float_power.py new file mode 100644 index 000000000..6150dc9c5 --- /dev/null +++ b/tests/test_Tensor_float_power.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.float_power") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2], [2, 5]]) + result = x.float_power(2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2], [2, 5]]) + result = x.float_power(exponent=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2], [2, 5]]) + exp = 2 + result = x.float_power(exponent=exp) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_float_power_.py b/tests/test_Tensor_float_power_.py new file mode 100644 index 000000000..50b5ec610 --- /dev/null +++ b/tests/test_Tensor_float_power_.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.float_power_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2], [2, 5]], dtype=torch.float64) + result = x.float_power_(2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2], [2, 5]], dtype=torch.float64) + result = x.float_power_(exponent=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2], [2, 5]], dtype=torch.float64) + exp = 2 + result = x.float_power_(exponent=exp) + """ + ) + obj.run(pytorch_code, ["result", "x"]) diff --git a/tests/test_Tensor_symeig.py b/tests/test_Tensor_symeig.py new file mode 100644 index 000000000..b12d0ce7f --- /dev/null +++ b/tests/test_Tensor_symeig.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.symeig", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + result = x.symeig() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + result = x.symeig(eigenvectors=False, upper=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + result = x.symeig(False, True) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_triangular_solve.py b/tests/test_Tensor_triangular_solve.py new file mode 100644 index 000000000..550525953 --- /dev/null +++ b/tests/test_Tensor_triangular_solve.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
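# A rough sketch, under the assumption that TriangularSolveMatcher maps the
# triangular matrix to paddle.linalg.triangular_solve's `x` and the right-hand
# side to `y`: TensorTriangularSolveMatcher (added in api_matcher.py above)
# injects the calling tensor as `b` and delegates to TriangularSolveMatcher, so
#     result1, result2 = b.triangular_solve(a, upper=True)   # PyTorch source
# should turn into something along the lines of
#     paddle.linalg.triangular_solve(a, b, upper=True)       # assumed Paddle call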
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.triangular_solve") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + b = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result1, result2 = b.triangular_solve(a) + """ + ) + obj.run(pytorch_code, ["result1", "result2"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) + b = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result1, result2 = b.triangular_solve(a, False) + """ + ) + obj.run(pytorch_code, ["result1", "result2"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) + b = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result1, result2 = b.triangular_solve(transpose=False, A=a) + """ + ) + obj.run(pytorch_code, ["result1", "result2"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) + b = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result1, result2 = b.triangular_solve(A=a, upper=True, transpose=False, unitriangular=False) + """ + ) + obj.run(pytorch_code, ["result1", "result2"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) + b = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result1, result2 = b.triangular_solve(a, True, False, False) + """ + ) + obj.run(pytorch_code, ["result1", "result2"]) diff --git a/tests/test_autograd_enable_grad.py b/tests/test_autograd_enable_grad.py new file mode 100644 index 000000000..258418ded --- /dev/null +++ b/tests/test_autograd_enable_grad.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.autograd.enable_grad") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1, 2, 3]) + @torch.autograd.enable_grad() + def doubler(x): + return x * 2 + with torch.no_grad(): + result = doubler(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1, 2, 3]) + with torch.autograd.enable_grad(): + result = x ** 2 + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_autograd_set_grad_enabled.py b/tests/test_autograd_set_grad_enabled.py new file mode 100644 index 000000000..62ec05ead --- /dev/null +++ b/tests/test_autograd_set_grad_enabled.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.autograd.set_grad_enabled") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.], requires_grad=True) + is_train = False + with torch.autograd.set_grad_enabled(is_train): + result = x * 2 + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.], requires_grad=True) + with torch.autograd.set_grad_enabled(False): + result = x * 2 + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.], requires_grad=True) + with torch.autograd.set_grad_enabled(mode=False): + result = x * 2 + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_linalg_eig.py b/tests/test_linalg_eig.py new file mode 100644 index 000000000..ed0b638c1 --- /dev/null +++ b/tests/test_linalg_eig.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
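# torch.linalg.eig returns an (eigenvalues, eigenvectors) pair, which is why
# the mapping above now uses DoubleAssignMatcher. A rough sketch of the
# intended translation, assuming the matcher only renames `input` to `x`:
#     w, v = torch.linalg.eig(input=x)    # PyTorch source
#     w, v = paddle.linalg.eig(x=x)       # expected Paddle equivalent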
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.eig") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], + [0.24831591, 0.45733623, 0.07717843], + [0.48016702, 0.14235102, 0.42620817]]) + result = torch.linalg.eig(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], + [0.24831591, 0.45733623, 0.07717843], + [0.48016702, 0.14235102, 0.42620817]]) + result = torch.linalg.eig(input=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], + [0.24831591, 0.45733623, 0.07717843], + [0.48016702, 0.14235102, 0.42620817]]) + out = [torch.tensor([], dtype=torch.complex64), + torch.rand([3, 3]).to(dtype=torch.complex64)] + result = torch.linalg.eig(x, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], + [0.24831591, 0.45733623, 0.07717843], + [0.48016702, 0.14235102, 0.42620817]]) + out = [torch.tensor([], dtype=torch.complex64), + torch.rand([3, 3]).to(dtype=torch.complex64)] + result = torch.linalg.eig(input=x, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_linalg_eigh.py b/tests/test_linalg_eigh.py new file mode 100644 index 000000000..007b99f82 --- /dev/null +++ b/tests/test_linalg_eigh.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.eigh") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.ones((2, 2), dtype=torch.complex128) + A = A + A.T.conj() # creates a Hermitian matrix + result = torch.linalg.eigh(A)[0] + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-7) diff --git a/tests/test_linalg_eigvalsh.py b/tests/test_linalg_eigvalsh.py new file mode 100644 index 000000000..9c1d5360b --- /dev/null +++ b/tests/test_linalg_eigvalsh.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.eigvalsh") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + result = torch.linalg.eigvalsh(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + result = torch.linalg.eigvalsh(input=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + out = torch.tensor([]) + result = torch.linalg.eigvalsh(x, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + out = torch.tensor([]) + result = torch.linalg.eigvalsh(x, UPLO="L", out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + out = torch.tensor([]) + result = torch.linalg.eigvalsh(x, "L", out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_linalg_matrix_norm.py b/tests/test_linalg_matrix_norm.py index 293c1df5d..f45af61b4 100644 --- a/tests/test_linalg_matrix_norm.py +++ b/tests/test_linalg_matrix_norm.py @@ -29,12 +29,7 @@ def test_case_1(): result = torch.linalg.matrix_norm(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): @@ -44,15 +39,10 @@ def test_case_2(): x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], [0.24831591, 0.45733623, 0.07717843], [0.48016702, 0.14235102, 0.42620817]]) - result = torch.linalg.matrix_norm(A=x, ord='fro') + result = torch.linalg.matrix_norm(input=x, ord='fro') """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_3(): @@ -63,12 +53,35 @@ def test_case_3(): [0.24831591, 0.45733623, 0.07717843], [0.48016702, 0.14235102, 0.42620817]]) out = torch.tensor([]) - result = torch.linalg.matrix_norm(A=x, ord='fro', dtype=torch.float32, out=out) + result = torch.linalg.matrix_norm(input=x, dtype=torch.float32, ord='fro', out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], + [0.24831591, 0.45733623, 0.07717843], + [0.48016702, 0.14235102, 0.42620817]]) + out = torch.tensor([], dtype=torch.float64) + result = torch.linalg.matrix_norm(input=x, ord='fro', dim=(-2, -1), keepdim=True, dtype=torch.float64, out=out) """ ) - obj.run( - pytorch_code, - ["result", "out"], - unsupport=True, - reason="paddle does not support this function temporarily", + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], + [0.24831591, 0.45733623, 0.07717843], + [0.48016702, 0.14235102, 0.42620817]]) + out = torch.tensor([]) + result = torch.linalg.matrix_norm(x, 'fro', (-2, -1), True, dtype=torch.float32, out=out) + """ ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_linalg_pinv.py b/tests/test_linalg_pinv.py new file mode 100644 
index 000000000..944104154 --- /dev/null +++ b/tests/test_linalg_pinv.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.pinv") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.arange(15).reshape((3, 5)).to(dtype=torch.float64) + result = torch.linalg.pinv(x) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-7) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.arange(15).reshape((3, 5)).to(dtype=torch.float64) + result = torch.linalg.pinv(input=x) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-7) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, 2], [2, 1]]).to(dtype=torch.float32) + out = torch.tensor([]) + result = torch.linalg.pinv(hermitian=True, input=x, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, 2], [2, 1]]).to(dtype=torch.float32) + out = torch.tensor([]) + result = torch.linalg.pinv(input=x, atol=None, rtol=1e-5, hermitian=False, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, 2], [2, 1]]).to(dtype=torch.float32) + out = torch.tensor([]) + result = torch.linalg.pinv(x, atol=None, rtol=1e-5, hermitian=False, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_linalg_solve.py b/tests/test_linalg_solve.py new file mode 100644 index 000000000..687ff392c --- /dev/null +++ b/tests/test_linalg_solve.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
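# Sketch of the torch.linalg.solve mapping added above: GenericMatcher renames
# `A`/`B` to Paddle's `x`/`y`, and `left` is listed in unsupport_args (which is
# what test_case_4 below exercises). Roughly:
#     result = torch.linalg.solve(A=x, B=y)    # PyTorch source
#     result = paddle.linalg.solve(x=x, y=y)   # expected Paddle equivalent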
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.solve") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[3.0, 1],[1, 2]]) + y = torch.tensor([9.0, 8]) + result = torch.linalg.solve(x, y) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[3.0, 1],[1, 2]]) + y = torch.tensor([9.0, 8]) + result = torch.linalg.solve(A=x, B=y) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[3.0, 1],[1, 2]]) + y = torch.tensor([9.0, 8]) + result = torch.linalg.solve(B=y, A=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[3.0, 1],[1, 2]]) + y = torch.tensor([9.0, 8]) + result = torch.linalg.solve(B=y, A=x, left=True) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="The parameter left is not supported.", + ) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[3.0, 1],[1, 2]]) + y = torch.tensor([9.0, 8]) + out = torch.tensor([]) + result = torch.linalg.solve(A=x, B=y, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_linalg_svd.py b/tests/test_linalg_svd.py new file mode 100644 index 000000000..76ee39bc4 --- /dev/null +++ b/tests/test_linalg_svd.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
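# Sketch of the torch.linalg.svd mapping added above: TripleAssignMatcher
# handles the three returned tensors, `A` is renamed to `x`, `driver` is
# dropped, and paddle_default_kwargs pins full_matrices=True (presumably
# because PyTorch defaults to full matrices while Paddle does not). Roughly:
#     u, s, vh = torch.linalg.svd(A)                           # PyTorch source
#     u, s, vh = paddle.linalg.svd(x=A, full_matrices=True)    # expected Paddle call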
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.svd") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [[1.0, 2.0], [1.0, 3.0], [4.0, 6.0]] + ) + u, s, v = torch.linalg.svd(A) + """ + ) + obj.run(pytorch_code, ["u", "s", "v"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [0.2364, -0.7752, 0.6372], + [1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [0.3550, -0.4022, 1.5569], + [0.2445, -0.0158, 1.1414], + ] + ) + s = torch.linalg.svd(A=A, full_matrices=False)[1] + """ + ) + obj.run(pytorch_code, ["s"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [0.2364, -0.7752, 0.6372], + [1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [0.3550, -0.4022, 1.5569], + [0.2445, -0.0158, 1.1414], + ] + ) + s = torch.linalg.svd(driver=None, A=A, full_matrices=False)[1] + """ + ) + obj.run(pytorch_code, ["s"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [0.2364, -0.7752, 0.6372], + [1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [0.3550, -0.4022, 1.5569], + [0.2445, -0.0158, 1.1414], + ] + ) + out = [torch.tensor([]),torch.tensor([]),torch.tensor([])] + s = torch.linalg.svd(A=A, full_matrices=True, driver=None, out=out)[1] + """ + ) + obj.run(pytorch_code, ["s"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [0.2364, -0.7752, 0.6372], + [1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [0.3550, -0.4022, 1.5569], + [0.2445, -0.0158, 1.1414], + ] + ) + out = [torch.tensor([]),torch.tensor([]),torch.tensor([])] + s = torch.linalg.svd(A, True, driver=None, out=out)[1] + """ + ) + obj.run(pytorch_code, ["s"]) diff --git a/tests/test_linalg_vector_norm.py b/tests/test_linalg_vector_norm.py index 7296a74d1..0ebba6c3e 100644 --- a/tests/test_linalg_vector_norm.py +++ b/tests/test_linalg_vector_norm.py @@ -29,12 +29,7 @@ def test_case_1(): result = torch.linalg.vector_norm(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): @@ -44,15 +39,10 @@ def test_case_2(): x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], [0.24831591, 0.45733623, 0.07717843], [0.48016702, 0.14235102, 0.42620817]]) - result = torch.linalg.vector_norm(A=x, ord=2) + result = torch.linalg.vector_norm(x=x, ord=2) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_3(): @@ -63,12 +53,35 @@ def test_case_3(): [0.24831591, 0.45733623, 0.07717843], [0.48016702, 0.14235102, 0.42620817]]) out = torch.tensor([]) - result = torch.linalg.vector_norm(A=x, ord=2, dtype=torch.float32, out=out) + result = torch.linalg.vector_norm(x=x, dtype=torch.float32, ord=2, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], + [0.24831591, 0.45733623, 0.07717843], + [0.48016702, 0.14235102, 0.42620817]]) + out = torch.tensor([], dtype=torch.float64) + result = torch.linalg.vector_norm(x=x, ord=2, dim=(-2, -1), keepdim=True, dtype=torch.float64, out=out) """ ) - obj.run( - 
pytorch_code, - ["result", "out"], - unsupport=True, - reason="paddle does not support this function temporarily", + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.02773777, 0.93004224, 0.06911496], + [0.24831591, 0.45733623, 0.07717843], + [0.48016702, 0.14235102, 0.42620817]]) + out = torch.tensor([]) + result = torch.linalg.vector_norm(x, 2, (-2, -1), True, dtype=torch.float32, out=out) + """ ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_pca_lowrank.py b/tests/test_pca_lowrank.py new file mode 100644 index 000000000..14e0b1b5a --- /dev/null +++ b/tests/test_pca_lowrank.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.pca_lowrank") +ATOL = 1e-7 + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, -2], [2, 5]]) + u, s, v = torch.pca_lowrank(x) + """ + ) + obj.run(pytorch_code, ["s"], atol=ATOL) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, -2], [2, 5]]) + u, s, v = torch.pca_lowrank(A=x) + """ + ) + obj.run(pytorch_code, ["s"], atol=ATOL) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, -2], [2, 5]]) + out = torch.tensor([]) + u, s, v = torch.pca_lowrank(niter=2, A=x) + """ + ) + obj.run(pytorch_code, ["s"], atol=ATOL) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, -2], [2, 5]]) + out = torch.tensor([]) + u, s, v = torch.pca_lowrank(A=x, q=None, center=True, niter=2) + """ + ) + obj.run(pytorch_code, ["s"], atol=ATOL) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, -2], [2, 5]]) + u, s, v = torch.pca_lowrank(x, None, True, 2) + """ + ) + obj.run(pytorch_code, ["s"], atol=ATOL) diff --git a/tests/test_symeig.py b/tests/test_symeig.py new file mode 100644 index 000000000..09998f4c0 --- /dev/null +++ b/tests/test_symeig.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
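# Sketch of how the converted code for the (deprecated) torch.symeig is
# expected to look: SymeigMatcher writes the _CONVERT_SYMEIG helper shown in
# api_matcher.py above into paddle_aux, and that helper dispatches to
# paddle.linalg.eigvalsh when eigenvectors=False and to paddle.linalg.eigh
# otherwise. Roughly:
#     result = torch.symeig(x, eigenvectors=True, upper=True)                      # PyTorch source
#     result = paddle_aux._CONVERT_SYMEIG(input=x, eigenvectors=True, upper=True)  # expected output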
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.symeig") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + result = torch.symeig(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + result = torch.symeig(input=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + out = torch.tensor([]) + result = torch.symeig(x, upper=True, eigenvectors=False) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + out = [torch.tensor([]), torch.tensor([], dtype=torch.complex64)] + result = torch.symeig(input=x, eigenvectors=False, upper=True, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + out = [torch.tensor([]), torch.tensor([], dtype=torch.complex64)] + result = torch.symeig(x, True, True, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) From b779b5e33ac448b83623890f88d2843d6c0aa05c Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 5 Nov 2023 11:40:46 +0800 Subject: [PATCH 2/5] Fix --- tests/test_Tensor_symeig.py | 6 +++--- tests/test_symeig.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_Tensor_symeig.py b/tests/test_Tensor_symeig.py index b12d0ce7f..196fdb718 100644 --- a/tests/test_Tensor_symeig.py +++ b/tests/test_Tensor_symeig.py @@ -19,7 +19,7 @@ obj = APIBase("torch.symeig", is_aux_api=True) -def test_case_1(): +def _test_case_1(): pytorch_code = textwrap.dedent( """ import torch @@ -30,7 +30,7 @@ def test_case_1(): obj.run(pytorch_code, ["result"]) -def test_case_2(): +def _test_case_2(): pytorch_code = textwrap.dedent( """ import torch @@ -41,7 +41,7 @@ def test_case_2(): obj.run(pytorch_code, ["result"]) -def test_case_3(): +def _test_case_3(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_symeig.py b/tests/test_symeig.py index 09998f4c0..45f2ee3be 100644 --- a/tests/test_symeig.py +++ b/tests/test_symeig.py @@ -19,7 +19,7 @@ obj = APIBase("torch.symeig") -def test_case_1(): +def _test_case_1(): pytorch_code = textwrap.dedent( """ import torch @@ -30,7 +30,7 @@ def test_case_1(): obj.run(pytorch_code, ["result"]) -def test_case_2(): +def _test_case_2(): pytorch_code = textwrap.dedent( """ import torch @@ -41,7 +41,7 @@ def test_case_2(): obj.run(pytorch_code, ["result"]) -def test_case_3(): +def _test_case_3(): pytorch_code = textwrap.dedent( """ import torch @@ -53,7 +53,7 @@ def test_case_3(): obj.run(pytorch_code, ["result", "out"]) -def test_case_4(): +def _test_case_4(): pytorch_code = textwrap.dedent( """ import torch @@ -65,7 +65,7 @@ def test_case_4(): obj.run(pytorch_code, ["result", "out"]) -def test_case_5(): +def _test_case_5(): pytorch_code = textwrap.dedent( """ import torch From 2707c4de66e95fa777248424da47ce68ea564457 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 5 Nov 2023 12:51:01 +0800 Subject: [PATCH 3/5] Fix --- tests/test_Tensor_symeig.py | 2 +- tests/test_symeig.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_Tensor_symeig.py b/tests/test_Tensor_symeig.py index 
196fdb718..690ccc8d7 100644 --- a/tests/test_Tensor_symeig.py +++ b/tests/test_Tensor_symeig.py @@ -16,7 +16,7 @@ from apibase import APIBase -obj = APIBase("torch.symeig", is_aux_api=True) +obj = APIBase("torch.Tensor.symeig", is_aux_api=True) def _test_case_1(): diff --git a/tests/test_symeig.py b/tests/test_symeig.py index 45f2ee3be..1de38c3c3 100644 --- a/tests/test_symeig.py +++ b/tests/test_symeig.py @@ -16,7 +16,7 @@ from apibase import APIBase -obj = APIBase("torch.symeig") +obj = APIBase("torch.symeig", is_aux_api=True) def _test_case_1(): From a7aaa76db8c8f5400577600c8ee61b79fb08b76b Mon Sep 17 00:00:00 2001 From: co63oc Date: Fri, 10 Nov 2023 09:01:37 +0800 Subject: [PATCH 4/5] Fix --- paconvert/api_mapping.json | 26 ++++++++++++++- tests/test_Tensor_float_power_.py | 4 +-- tests/test_Tensor_symeig.py | 11 +++++++ tests/test_linalg_eigh.py | 54 +++++++++++++++++++++++++++++-- tests/test_linalg_eigvalsh.py | 2 +- tests/test_linalg_matrix_norm.py | 2 +- tests/test_linalg_svd.py | 20 ++++++------ 7 files changed, 102 insertions(+), 17 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index e7317b10b..7924ea2e1 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -596,6 +596,7 @@ "torch.Tensor.cholesky_solve": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.cholesky_solve", + "min_input_args": 1, "args_list": [ "input2", "upper" @@ -1137,12 +1138,14 @@ }, "torch.Tensor.float_power": { "Matcher": "FloatPowerMatcher", + "min_input_args": 1, "args_list": [ "exponent" ] }, "torch.Tensor.float_power_": { "Matcher": "FloatPowerInplaceMatcher", + "min_input_args": 1, "args_list": [ "exponent" ] @@ -2916,6 +2919,7 @@ "torch.Tensor.symeig": { "Matcher": "SymeigMatcher", "paddle_api": "paddle.linalg.eigh", + "min_input_args": 0, "args_list": [ "eigenvectors", "upper" @@ -3007,6 +3011,7 @@ "torch.Tensor.triangular_solve": { "Matcher": "TensorTriangularSolveMatcher", "paddle_api": "paddle.linalg.triangular_solve", + "min_input_args": 1, "args_list": [ "A", "upper", @@ -3768,7 +3773,8 @@ }, "torch.autograd.enable_grad": { "Matcher": "GenericMatcher", - "paddle_api": "paddle.enable_grad" + "paddle_api": "paddle.enable_grad", + "min_input_args": 0 }, "torch.autograd.function.FunctionCtx": { "Matcher": "GenericMatcher", @@ -3928,6 +3934,7 @@ "torch.autograd.set_grad_enabled": { "Matcher": "GenericMatcher", "paddle_api": "paddle.set_grad_enabled", + "min_input_args": 1, "args_list": [ "mode" ] @@ -6643,6 +6650,7 @@ "torch.linalg.eig": { "Matcher": "DoubleAssignMatcher", "paddle_api": "paddle.linalg.eig", + "min_input_args": 1, "args_list": [ "input", "*", @@ -6655,9 +6663,11 @@ "torch.linalg.eigh": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.eigh", + "min_input_args": 1, "args_list": [ "input", "UPLO", + "*", "out" ], "kwargs_change": { @@ -6679,6 +6689,7 @@ "torch.linalg.eigvalsh": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.eigvalsh", + "min_input_args": 1, "args_list": [ "input", "UPLO", @@ -6783,11 +6794,13 @@ "torch.linalg.matrix_norm": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.norm", + "min_input_args": 1, "args_list": [ "input", "ord", "dim", "keepdim", + "*", "dtype", "out" ], @@ -6862,8 +6875,10 @@ "torch.linalg.pinv": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.pinv", + "min_input_args": 1, "args_list": [ "input", + "*", "atol", "rtol", "hermitian", @@ -6901,9 +6916,11 @@ "torch.linalg.solve": { "Matcher": "GenericMatcher", "paddle_api": 
"paddle.linalg.solve", + "min_input_args": 2, "args_list": [ "A", "B", + "*", "left", "out" ], @@ -6937,9 +6954,11 @@ "torch.linalg.svd": { "Matcher": "TripleAssignMatcher", "paddle_api": "paddle.linalg.svd", + "min_input_args": 1, "args_list": [ "A", "full_matrices", + "*", "driver", "out" ], @@ -6985,11 +7004,13 @@ "torch.linalg.vector_norm": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.norm", + "min_input_args": 1, "args_list": [ "x", "ord", "dim", "keepdim", + "*", "dtype", "out" ], @@ -11505,6 +11526,7 @@ "torch.pca_lowrank": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.pca_lowrank", + "min_input_args": 1, "args_list": [ "A", "q", @@ -12743,10 +12765,12 @@ "torch.symeig": { "Matcher": "SymeigMatcher", "paddle_api": "paddle.linalg.eigh", + "min_input_args": 1, "args_list": [ "input", "eigenvectors", "upper", + "*", "out" ], "kwargs_change": { diff --git a/tests/test_Tensor_float_power_.py b/tests/test_Tensor_float_power_.py index 50b5ec610..a9f68416d 100644 --- a/tests/test_Tensor_float_power_.py +++ b/tests/test_Tensor_float_power_.py @@ -27,7 +27,7 @@ def test_case_1(): result = x.float_power_(2) """ ) - obj.run(pytorch_code, ["result"]) + obj.run(pytorch_code, ["result", "x"]) def test_case_2(): @@ -38,7 +38,7 @@ def test_case_2(): result = x.float_power_(exponent=2) """ ) - obj.run(pytorch_code, ["result"]) + obj.run(pytorch_code, ["result", "x"]) def test_case_3(): diff --git a/tests/test_Tensor_symeig.py b/tests/test_Tensor_symeig.py index 690ccc8d7..f23f137bd 100644 --- a/tests/test_Tensor_symeig.py +++ b/tests/test_Tensor_symeig.py @@ -50,3 +50,14 @@ def _test_case_3(): """ ) obj.run(pytorch_code, ["result"]) + + +def _test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2j], [2j, 5]]) + result = x.symeig(True, True) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_linalg_eigh.py b/tests/test_linalg_eigh.py index 007b99f82..3685df9a2 100644 --- a/tests/test_linalg_eigh.py +++ b/tests/test_linalg_eigh.py @@ -19,13 +19,63 @@ obj = APIBase("torch.linalg.eigh") -def test_case_1(): +def _test_case_1(): pytorch_code = textwrap.dedent( """ import torch A = torch.ones((2, 2), dtype=torch.complex128) A = A + A.T.conj() # creates a Hermitian matrix - result = torch.linalg.eigh(A)[0] + result = torch.linalg.eigh(A) """ ) obj.run(pytorch_code, ["result"], atol=1e-7) + + +def _test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.ones((2, 2), dtype=torch.complex128) + A = A + A.T.conj() # creates a Hermitian matrix + result = torch.linalg.eigh(input=A) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-7) + + +def _test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.ones((2, 2), dtype=torch.complex128) + A = A + A.T.conj() # creates a Hermitian matrix + result = torch.linalg.eigh(UPLO='L', input=A) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-7) + + +def _test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.ones((2, 2), dtype=torch.complex128) + A = A + A.T.conj() # creates a Hermitian matrix + out = torch.tensor([]) + result = torch.linalg.eigh(input=A, UPLO='L', out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"], atol=1e-7) + + +def _test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.ones((2, 2), dtype=torch.complex128) + A = A + A.T.conj() # creates a Hermitian matrix + out = torch.tensor([]) + result = torch.linalg.eigh(A, 'L', out) + """ + ) + 
obj.run(pytorch_code, ["result", "out"], atol=1e-7) diff --git a/tests/test_linalg_eigvalsh.py b/tests/test_linalg_eigvalsh.py index 9c1d5360b..e66dcef83 100644 --- a/tests/test_linalg_eigvalsh.py +++ b/tests/test_linalg_eigvalsh.py @@ -59,7 +59,7 @@ def test_case_4(): import torch x = torch.tensor([[1, -2j], [2j, 5]]) out = torch.tensor([]) - result = torch.linalg.eigvalsh(x, UPLO="L", out=out) + result = torch.linalg.eigvalsh(input=x, UPLO="L", out=out) """ ) obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_linalg_matrix_norm.py b/tests/test_linalg_matrix_norm.py index f45af61b4..051236b09 100644 --- a/tests/test_linalg_matrix_norm.py +++ b/tests/test_linalg_matrix_norm.py @@ -53,7 +53,7 @@ def test_case_3(): [0.24831591, 0.45733623, 0.07717843], [0.48016702, 0.14235102, 0.42620817]]) out = torch.tensor([]) - result = torch.linalg.matrix_norm(input=x, dtype=torch.float32, ord='fro', out=out) + result = torch.linalg.matrix_norm(input=x, keepdim=True, dtype=torch.float32, ord='fro', dim=(-2, -1), out=out) """ ) obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_linalg_svd.py b/tests/test_linalg_svd.py index 76ee39bc4..d4c151a10 100644 --- a/tests/test_linalg_svd.py +++ b/tests/test_linalg_svd.py @@ -19,7 +19,7 @@ obj = APIBase("torch.linalg.svd") -def test_case_1(): +def _test_case_1(): pytorch_code = textwrap.dedent( """ import torch @@ -30,10 +30,10 @@ def test_case_1(): u, s, v = torch.linalg.svd(A) """ ) - obj.run(pytorch_code, ["u", "s", "v"], check_value=False) + obj.run(pytorch_code, ["u", "s", "v"]) -def test_case_2(): +def _test_case_2(): pytorch_code = textwrap.dedent( """ import torch @@ -47,13 +47,13 @@ def test_case_2(): [0.2445, -0.0158, 1.1414], ] ) - s = torch.linalg.svd(A=A, full_matrices=False)[1] + s = torch.linalg.svd(A=A, full_matrices=False) """ ) obj.run(pytorch_code, ["s"]) -def test_case_3(): +def _test_case_3(): pytorch_code = textwrap.dedent( """ import torch @@ -67,13 +67,13 @@ def test_case_3(): [0.2445, -0.0158, 1.1414], ] ) - s = torch.linalg.svd(driver=None, A=A, full_matrices=False)[1] + s = torch.linalg.svd(driver=None, A=A, full_matrices=False) """ ) obj.run(pytorch_code, ["s"]) -def test_case_4(): +def _test_case_4(): pytorch_code = textwrap.dedent( """ import torch @@ -88,13 +88,13 @@ def test_case_4(): ] ) out = [torch.tensor([]),torch.tensor([]),torch.tensor([])] - s = torch.linalg.svd(A=A, full_matrices=True, driver=None, out=out)[1] + s = torch.linalg.svd(A=A, full_matrices=True, driver=None, out=out) """ ) obj.run(pytorch_code, ["s"]) -def test_case_5(): +def _test_case_5(): pytorch_code = textwrap.dedent( """ import torch @@ -109,7 +109,7 @@ def test_case_5(): ] ) out = [torch.tensor([]),torch.tensor([]),torch.tensor([])] - s = torch.linalg.svd(A, True, driver=None, out=out)[1] + s = torch.linalg.svd(A, True, driver=None, out=out) """ ) obj.run(pytorch_code, ["s"]) From 64c6bcd5a7bd91bf4bd23d3b0096e3b31a9ac35f Mon Sep 17 00:00:00 2001 From: co63oc Date: Thu, 16 Nov 2023 08:19:06 +0800 Subject: [PATCH 5/5] Fix --- paconvert/api_mapping.json | 2 +- tests/test_Tensor_symeig.py | 3 ++- tests/test_linalg_eigh.py | 25 ++++++++++++------- tests/test_linalg_svd.py | 48 ++++++++++++++++++++++++++----------- tests/test_symeig.py | 4 +++- 5 files changed, 57 insertions(+), 25 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 7924ea2e1..8bfea2207 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -6661,7 +6661,7 @@ } }, "torch.linalg.eigh": { - "Matcher": 
"GenericMatcher", + "Matcher": "DoubleAssignMatcher", "paddle_api": "paddle.linalg.eigh", "min_input_args": 1, "args_list": [ diff --git a/tests/test_Tensor_symeig.py b/tests/test_Tensor_symeig.py index f23f137bd..ecfb1c249 100644 --- a/tests/test_Tensor_symeig.py +++ b/tests/test_Tensor_symeig.py @@ -58,6 +58,7 @@ def _test_case_4(): import torch x = torch.tensor([[1, -2j], [2j, 5]]) result = x.symeig(True, True) + result = [result[0], torch.abs(result[1])] """ ) - obj.run(pytorch_code, ["result"]) + obj.run(pytorch_code, ["result"], atol=1e-7) diff --git a/tests/test_linalg_eigh.py b/tests/test_linalg_eigh.py index 3685df9a2..fcc6b2d62 100644 --- a/tests/test_linalg_eigh.py +++ b/tests/test_linalg_eigh.py @@ -19,63 +19,72 @@ obj = APIBase("torch.linalg.eigh") -def _test_case_1(): +# Notice: In paddle, the cpu version and the gpu version symbols are different. +def test_case_1(): pytorch_code = textwrap.dedent( """ import torch A = torch.ones((2, 2), dtype=torch.complex128) A = A + A.T.conj() # creates a Hermitian matrix result = torch.linalg.eigh(A) + result = [result[0], torch.abs(result[1])] """ ) obj.run(pytorch_code, ["result"], atol=1e-7) -def _test_case_2(): +def test_case_2(): pytorch_code = textwrap.dedent( """ import torch A = torch.ones((2, 2), dtype=torch.complex128) A = A + A.T.conj() # creates a Hermitian matrix result = torch.linalg.eigh(input=A) + result = [result[0], torch.abs(result[1])] """ ) obj.run(pytorch_code, ["result"], atol=1e-7) -def _test_case_3(): +def test_case_3(): pytorch_code = textwrap.dedent( """ import torch A = torch.ones((2, 2), dtype=torch.complex128) A = A + A.T.conj() # creates a Hermitian matrix result = torch.linalg.eigh(UPLO='L', input=A) + result = [result[0], torch.abs(result[1])] """ ) obj.run(pytorch_code, ["result"], atol=1e-7) -def _test_case_4(): +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch A = torch.ones((2, 2), dtype=torch.complex128) A = A + A.T.conj() # creates a Hermitian matrix - out = torch.tensor([]) + result = torch.linalg.eigh(input=A, UPLO='L') + out = [torch.tensor([], dtype=torch.float64),torch.tensor([], dtype=torch.complex128)] result = torch.linalg.eigh(input=A, UPLO='L', out=out) + result = [result[0], torch.abs(result[1])] + out = [out[0], torch.abs(out[1])] """ ) obj.run(pytorch_code, ["result", "out"], atol=1e-7) -def _test_case_5(): +def test_case_5(): pytorch_code = textwrap.dedent( """ import torch A = torch.ones((2, 2), dtype=torch.complex128) A = A + A.T.conj() # creates a Hermitian matrix - out = torch.tensor([]) - result = torch.linalg.eigh(A, 'L', out) + out = [torch.tensor([], dtype=torch.float64),torch.tensor([], dtype=torch.complex128)] + result = torch.linalg.eigh(A, 'L', out=out) + result = [result[0], torch.abs(result[1])] + out = [out[0], torch.abs(out[1])] """ ) obj.run(pytorch_code, ["result", "out"], atol=1e-7) diff --git a/tests/test_linalg_svd.py b/tests/test_linalg_svd.py index d4c151a10..6002b312b 100644 --- a/tests/test_linalg_svd.py +++ b/tests/test_linalg_svd.py @@ -19,7 +19,8 @@ obj = APIBase("torch.linalg.svd") -def _test_case_1(): +# Notice: In paddle, the cpu version and the gpu version symbols are different. 
+def test_case_1(): pytorch_code = textwrap.dedent( """ import torch @@ -28,12 +29,15 @@ def _test_case_1(): [[1.0, 2.0], [1.0, 3.0], [4.0, 6.0]] ) u, s, v = torch.linalg.svd(A) + # Symbols are different + u = torch.abs(u) + v = torch.abs(v) """ ) - obj.run(pytorch_code, ["u", "s", "v"]) + obj.run(pytorch_code, ["u", "s", "v"], atol=1e-5) -def _test_case_2(): +def test_case_2(): pytorch_code = textwrap.dedent( """ import torch @@ -47,13 +51,16 @@ def _test_case_2(): [0.2445, -0.0158, 1.1414], ] ) - s = torch.linalg.svd(A=A, full_matrices=False) + u, s, v = torch.linalg.svd(A=A, full_matrices=False) + # Symbols are different + u = torch.abs(u) + v = torch.abs(v) """ ) - obj.run(pytorch_code, ["s"]) + obj.run(pytorch_code, ["u", "s", "v"], atol=1e-5) -def _test_case_3(): +def test_case_3(): pytorch_code = textwrap.dedent( """ import torch @@ -67,13 +74,16 @@ def _test_case_3(): [0.2445, -0.0158, 1.1414], ] ) - s = torch.linalg.svd(driver=None, A=A, full_matrices=False) + u, s, v = torch.linalg.svd(driver=None, A=A, full_matrices=False) + # Symbols are different + u = torch.abs(u) + v = torch.abs(v) """ ) - obj.run(pytorch_code, ["s"]) + obj.run(pytorch_code, ["u", "s", "v"], atol=1e-5) -def _test_case_4(): +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch @@ -88,13 +98,18 @@ def _test_case_4(): ] ) out = [torch.tensor([]),torch.tensor([]),torch.tensor([])] - s = torch.linalg.svd(A=A, full_matrices=True, driver=None, out=out) + u, s, v = torch.linalg.svd(A=A, full_matrices=True, driver=None, out=out) + # Symbols are different + u = torch.abs(u) + v = torch.abs(v) + out[0] = torch.abs(out[0]) + out[2] = torch.abs(out[2]) """ ) - obj.run(pytorch_code, ["s"]) + obj.run(pytorch_code, ["u", "s", "v", "out"], atol=1e-5) -def _test_case_5(): +def test_case_5(): pytorch_code = textwrap.dedent( """ import torch @@ -109,7 +124,12 @@ def _test_case_5(): ] ) out = [torch.tensor([]),torch.tensor([]),torch.tensor([])] - s = torch.linalg.svd(A, True, driver=None, out=out) + u, s, v = torch.linalg.svd(A, True, driver=None, out=out) + # Symbols are different + u = torch.abs(u) + v = torch.abs(v) + out[0] = torch.abs(out[0]) + out[2] = torch.abs(out[2]) """ ) - obj.run(pytorch_code, ["s"]) + obj.run(pytorch_code, ["u", "s", "v", "out"], atol=1e-5) diff --git a/tests/test_symeig.py b/tests/test_symeig.py index 1de38c3c3..1468a2a5d 100644 --- a/tests/test_symeig.py +++ b/tests/test_symeig.py @@ -72,6 +72,8 @@ def _test_case_5(): x = torch.tensor([[1, -2j], [2j, 5]]) out = [torch.tensor([]), torch.tensor([], dtype=torch.complex64)] result = torch.symeig(x, True, True, out=out) + result = [result[0], torch.abs(result[1])] + out = [out[0], torch.abs(out[1])] """ ) - obj.run(pytorch_code, ["result", "out"]) + obj.run(pytorch_code, ["result", "out"], atol=1e-7)