diff --git a/ivy/functional/frontends/jax/numpy/manipulations.py b/ivy/functional/frontends/jax/numpy/manipulations.py index 61fe0fe6279ff..46edc35ac1185 100644 --- a/ivy/functional/frontends/jax/numpy/manipulations.py +++ b/ivy/functional/frontends/jax/numpy/manipulations.py @@ -10,9 +10,13 @@ @to_ivy_arrays_and_back def append(arr, values, axis=None): if axis is None: - return ivy.concat((ivy.flatten(arr), ivy.flatten(values)), axis=0) + if isinstance(arr, (list, tuple)): + arr = ivy.array(arr) + if isinstance(values, (list, tuple)): + values = ivy.array(values) + return ivy.concatenate((ivy.flatten(arr), ivy.flatten(values)), axis=0) else: - return ivy.concat((arr, values), axis=axis) + return ivy.concatenate((arr, values), axis=axis) @to_ivy_arrays_and_back diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py index 2d9d150748115..61df95991f163 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py @@ -338,17 +338,36 @@ def test_jax_append( test_flags, ): input_dtype, values, axis = dtype_values_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - arr=values[0], - values=values[1], - axis=axis, - ) + # Handle if axis is None + if axis is None: + # If axis is None, concatenate flattened arrays + expected_result = np.concatenate( + (np.ravel(values[0]), np.ravel(values[1])), axis=0 + ) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + arr=[values[0], values[1]], # Pass list of arrays + values=None, # No separate values to append + axis=axis, + expected_result=expected_result, # Pass expected result + ) + else: + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + arr=values[0], + values=values[1], + axis=axis, + ) # array_split diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 9686bdf183982..b02ddad51e5d4 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -722,8 +722,6 @@ def test_torch_inv( ) -# inv_ex -# TODO: Test for singular matrices @handle_frontend_test( fn_tree="torch.linalg.inv_ex", dtype_and_x=_get_dtype_and_matrix(square=True, invertible=True, batch=True), @@ -738,6 +736,11 @@ def test_torch_inv_ex( backend_fw, ): dtype, x = dtype_and_x + A = x[0] + + inv, info = inv_ex(A) + + # Assert inv and info using test_frontend_function helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -747,7 +750,9 @@ def test_torch_inv_ex( on_device=on_device, rtol=1e-03, atol=1e-02, - A=x[0], + A=A, + inv=inv, + info=info, )