From 5b86abdcf0c45bbd9f6590c92145adc574078290 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Mon, 1 Apr 2024 19:14:51 +0530 Subject: [PATCH 1/6] Updated inv_ex function --- ivy/functional/frontends/torch/linalg.py | 2 -- .../test_frontends/test_torch/test_linalg.py | 12 +++++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 3df44b0ae3fa7..3b1b0c06a9d01 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -107,7 +107,6 @@ def eigvalsh(input, UPLO="L", *, out=None): def inv(A, *, out=None): return ivy.inv(A, out=out) - @to_ivy_arrays_and_back @with_supported_dtypes( {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch" @@ -126,7 +125,6 @@ def inv_ex(A, *, check_errors=False, out=None): info = ivy.zeros(A.shape[:-2], dtype=ivy.int32) return inv, info - @to_ivy_arrays_and_back @with_supported_dtypes( {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch" diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 9686bdf183982..7a5c02ebcbb74 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -722,8 +722,6 @@ def test_torch_inv( ) -# inv_ex -# TODO: Test for singular matrices @handle_frontend_test( fn_tree="torch.linalg.inv_ex", dtype_and_x=_get_dtype_and_matrix(square=True, invertible=True, batch=True), @@ -738,6 +736,11 @@ def test_torch_inv_ex( backend_fw, ): dtype, x = dtype_and_x + A = x[0] + + inv, info = inv_ex(A) + + # Assert inv and info using test_frontend_function helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -747,10 +750,13 @@ def test_torch_inv_ex( on_device=on_device, rtol=1e-03, atol=1e-02, - A=x[0], + A=A, + inv=inv, + info=info ) + # lu_factor @handle_frontend_test( fn_tree="torch.linalg.lu_factor", From c4c9ed3d94ce0809972e3e1c7e21a6812d73bb94 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Mon, 1 Apr 2024 13:47:50 +0000 Subject: [PATCH 2/6] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/torch/linalg.py | 2 ++ ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 3b1b0c06a9d01..3df44b0ae3fa7 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -107,6 +107,7 @@ def eigvalsh(input, UPLO="L", *, out=None): def inv(A, *, out=None): return ivy.inv(A, out=out) + @to_ivy_arrays_and_back @with_supported_dtypes( {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch" @@ -125,6 +126,7 @@ def inv_ex(A, *, check_errors=False, out=None): info = ivy.zeros(A.shape[:-2], dtype=ivy.int32) return inv, info + @to_ivy_arrays_and_back @with_supported_dtypes( {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch" diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 7a5c02ebcbb74..b02ddad51e5d4 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -752,11 +752,10 @@ def test_torch_inv_ex( atol=1e-02, A=A, inv=inv, - info=info + info=info, ) - # lu_factor @handle_frontend_test( fn_tree="torch.linalg.lu_factor", From 5976572dc1a8c516c626bc2843843d54b49bddb5 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Tue, 2 Apr 2024 18:55:11 +0530 Subject: [PATCH 3/6] fix : test of jax_append --- .../frontends/jax/numpy/manipulations.py | 8 +++- .../test_jax/test_numpy/test_manipulations.py | 42 +++++++++++++------ 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/manipulations.py b/ivy/functional/frontends/jax/numpy/manipulations.py index 61fe0fe6279ff..46edc35ac1185 100644 --- a/ivy/functional/frontends/jax/numpy/manipulations.py +++ b/ivy/functional/frontends/jax/numpy/manipulations.py @@ -10,9 +10,13 @@ @to_ivy_arrays_and_back def append(arr, values, axis=None): if axis is None: - return ivy.concat((ivy.flatten(arr), ivy.flatten(values)), axis=0) + if isinstance(arr, (list, tuple)): + arr = ivy.array(arr) + if isinstance(values, (list, tuple)): + values = ivy.array(values) + return ivy.concatenate((ivy.flatten(arr), ivy.flatten(values)), axis=0) else: - return ivy.concat((arr, values), axis=axis) + return ivy.concatenate((arr, values), axis=axis) @to_ivy_arrays_and_back diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py index 2d9d150748115..7c2f3844f54b8 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py @@ -338,18 +338,36 @@ def test_jax_append( test_flags, ): input_dtype, values, axis = dtype_values_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - arr=values[0], - values=values[1], - axis=axis, - ) - + # Handle if axis is None + if axis is None: + # If axis is None, concatenate flattened arrays + expected_result = np.concatenate( + (np.ravel(values[0]), np.ravel(values[1])), axis=0 + ) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + arr=[values[0], values[1]], # Pass list of arrays + values=None, # No separate values to append + axis=axis, + expected_result=expected_result, # Pass expected result + ) + else: + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + arr=values[0], + values=values[1], + axis=axis, + ) # array_split @handle_frontend_test( From 0db65f07847d124d908f2e7def9c343a043af786 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Tue, 2 Apr 2024 13:27:30 +0000 Subject: [PATCH 4/6] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_frontends/test_jax/test_numpy/test_manipulations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py index 7c2f3844f54b8..61df95991f163 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py @@ -369,6 +369,7 @@ def test_jax_append( axis=axis, ) + # array_split @handle_frontend_test( fn_tree="jax.numpy.array_split", From 819c12436fc20fa5dee4de0dbc21c00f42b2b2b5 Mon Sep 17 00:00:00 2001 From: rohitsalla Date: Tue, 2 Apr 2024 19:04:04 +0530 Subject: [PATCH 5/6] Update test_manipulations.py --- .../test_frontends/test_jax/test_numpy/test_manipulations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py index 61df95991f163..81baf6617188b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py @@ -310,6 +310,9 @@ def _squeeze_helper(draw): # ------------ # + + + # append @handle_frontend_test( fn_tree="jax.numpy.append", From d91ed865e55e656595bf316aa4e9fa52885ca03f Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Tue, 2 Apr 2024 13:34:51 +0000 Subject: [PATCH 6/6] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_frontends/test_jax/test_numpy/test_manipulations.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py index 81baf6617188b..61df95991f163 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py @@ -310,9 +310,6 @@ def _squeeze_helper(draw): # ------------ # - - - # append @handle_frontend_test( fn_tree="jax.numpy.append",