diff --git a/ivy/data_classes/array/experimental/linear_algebra.py b/ivy/data_classes/array/experimental/linear_algebra.py index 865c4a1296942..a7f1c744ccea0 100644 --- a/ivy/data_classes/array/experimental/linear_algebra.py +++ b/ivy/data_classes/array/experimental/linear_algebra.py @@ -110,6 +110,29 @@ def diagflat( out=out, ) + def lu( + self: ivy.Array, + /, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ivy.Array instance method variant of ivy.lu. This method simply + wraps the function, and so the docstring for ivy.lu also applies to + this method with minimal changes. + + Examples + -------- + >>> x = ivy.array([[1.0,2.0],[3.0,4.0]]) + >>> ivy.lu(x) + ivy.array([[0., 1.], + [1., 0.]]), + ivy.array([[1. , 0. ], + [0.33333334, 1. ]]), + ivy.array([[3. , 4. ], + [0. , 0.66666663]]) + """ + return ivy.lu(self._data, out=out) + def kron( self: ivy.Array, b: ivy.Array, diff --git a/ivy/data_classes/container/experimental/linear_algebra.py b/ivy/data_classes/container/experimental/linear_algebra.py index 697d0ef749e52..1381211c911a1 100644 --- a/ivy/data_classes/container/experimental/linear_algebra.py +++ b/ivy/data_classes/container/experimental/linear_algebra.py @@ -349,6 +349,54 @@ def kron( out=out, ) + @staticmethod + def static_lu( + x: Union[ivy.Array, ivy.NativeArray, ivy.Container], + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + return ContainerBase.cont_multi_map_in_function( + "lu", + x, + out=out, + key_chains=key_chains, + to_apply=to_apply, + ) + + def lu( + self: ivy.Container, + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ivy.Container instance method variant of ivy.diagflat. This method + simply wraps the function, and so the docstring for ivy.diagflat also + applies to this method with minimal changes. + + Examples + -------- + >>> x = ivy.array([[[1., 0.], + [0., 1.]], + [[2., 0.], + [0., 2.]]]) + >>> ivy.matrix_exp(x) + ivy.array([[[2.7183, 1.0000], + [1.0000, 2.7183]], + [[7.3891, 1.0000], + [1.0000, 7.3891]]]) + """ + return self.static_lu( + self, + key_chains=key_chains, + to_apply=to_apply, + out=out, + ) + @staticmethod def static_matrix_exp( x: Union[ivy.Array, ivy.NativeArray, ivy.Container], diff --git a/ivy/functional/backends/jax/experimental/linear_algebra.py b/ivy/functional/backends/jax/experimental/linear_algebra.py index 0f1841e41f1a6..3389e8c5ff7c5 100644 --- a/ivy/functional/backends/jax/experimental/linear_algebra.py +++ b/ivy/functional/backends/jax/experimental/linear_algebra.py @@ -13,6 +13,15 @@ from . import backend_version +def lu( + x: JaxArray, + /, + *, + out: Optional[JaxArray] = None, +) -> JaxArray: + return jla.lu(x) + + def diagflat( x: JaxArray, /, diff --git a/ivy/functional/backends/mxnet/experimental/linear_algebra.py b/ivy/functional/backends/mxnet/experimental/linear_algebra.py index dd31f5eeb070d..7d2a97dd5081b 100644 --- a/ivy/functional/backends/mxnet/experimental/linear_algebra.py +++ b/ivy/functional/backends/mxnet/experimental/linear_algebra.py @@ -49,6 +49,15 @@ def kron( raise IvyNotImplementedException() +def lu( + x: Union[(None, mx.ndarray.NDArray)], + /, + *, + out: Optional[Union[(None, mx.ndarray.NDArray)]] = None, +) -> Union[(None, mx.ndarray.NDArray)]: + raise IvyNotImplementedException() + + def matrix_exp( x: Union[(None, mx.ndarray.NDArray)], /, diff --git a/ivy/functional/backends/numpy/experimental/linear_algebra.py b/ivy/functional/backends/numpy/experimental/linear_algebra.py index 98e87e7efafa3..4936ccf8f124d 100644 --- a/ivy/functional/backends/numpy/experimental/linear_algebra.py +++ b/ivy/functional/backends/numpy/experimental/linear_algebra.py @@ -1,6 +1,7 @@ import math from typing import Optional, Tuple, Sequence, Union, Any import numpy as np +import scipy.linalg as sla import ivy from ivy.func_wrapper import with_supported_dtypes, with_unsupported_dtypes @@ -10,6 +11,15 @@ from ivy.functional.ivy.experimental.linear_algebra import _check_valid_dimension_size +def lu( + x: np.ndarray, + /, + *, + out: Optional[np.ndarray] = None, +) -> np.ndarray: + return sla.lu(x) + + def diagflat( x: np.ndarray, /, diff --git a/ivy/functional/backends/paddle/experimental/linear_algebra.py b/ivy/functional/backends/paddle/experimental/linear_algebra.py index eacf1acf4278b..d0334bf39e5e8 100644 --- a/ivy/functional/backends/paddle/experimental/linear_algebra.py +++ b/ivy/functional/backends/paddle/experimental/linear_algebra.py @@ -59,6 +59,16 @@ def kron( return paddle.kron(a, b) +def lu( + x: paddle.Tensor, + /, + *, + out: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: + # return paddle.lu(x) + raise IvyNotImplementedException() + + def matrix_exp( x: paddle.Tensor, /, diff --git a/ivy/functional/backends/tensorflow/experimental/linear_algebra.py b/ivy/functional/backends/tensorflow/experimental/linear_algebra.py index 4d54923e2e850..022cad415f099 100644 --- a/ivy/functional/backends/tensorflow/experimental/linear_algebra.py +++ b/ivy/functional/backends/tensorflow/experimental/linear_algebra.py @@ -75,6 +75,15 @@ def diagflat( return ret +def lu( + x: Union[tf.Tensor, tf.Variable], + /, + *, + out: Optional[Union[tf.Tensor, tf.Variable]] = None, +) -> Union[tf.Tensor, tf.Variable]: + return tf.linalg.lu(x) + + def kron( a: Union[tf.Tensor, tf.Variable], b: Union[tf.Tensor, tf.Variable], diff --git a/ivy/functional/backends/torch/experimental/linear_algebra.py b/ivy/functional/backends/torch/experimental/linear_algebra.py index ed3dee38a717d..ada63491551cc 100644 --- a/ivy/functional/backends/torch/experimental/linear_algebra.py +++ b/ivy/functional/backends/torch/experimental/linear_algebra.py @@ -100,6 +100,15 @@ def diagflat( diagflat.support_native_out = False +def lu( + x: torch.Tensor, + /, + *, + out: Optional[torch.Tensor] = None, +) -> torch.tensor: + return torch.linalg.lu(x) + + def kron( a: torch.Tensor, b: torch.Tensor, diff --git a/ivy/functional/ivy/experimental/linear_algebra.py b/ivy/functional/ivy/experimental/linear_algebra.py index f98ec8a896e14..e559691e48590 100644 --- a/ivy/functional/ivy/experimental/linear_algebra.py +++ b/ivy/functional/ivy/experimental/linear_algebra.py @@ -228,6 +228,33 @@ def diagflat( ) +@handle_exceptions +@handle_backend_invalid +@handle_nestable +@handle_array_like_without_promotion +@handle_out_argument +@to_native_arrays_and_back +@handle_device +def lu( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + """Perform LU decomposition of a square matrix using Doolittle's method. + + Args: + ---- + - x: a square numpy array representing the input matrix + + Returns: + ------- + - L: Lower triangular matrix + - U: Upper triangular matrix + """ + return current_backend(x).lu(x, out=out) + + @handle_exceptions @handle_backend_invalid @handle_nestable diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py index 655b8149eb478..c140f86fd90be 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py @@ -1380,6 +1380,33 @@ def test_kronecker(*, data, test_flags, backend_fw, fn_name, on_device): ) +@handle_test( + fn_tree="functional.ivy.experimental.lu", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=2, + min_dim_size=2, + max_dim_size=2, + min_value=-100, + max_value=100, + allow_nan=False, + shared_dtype=True, + ), + test_gradients=st.just(False), +) +def test_lu(dtype_x, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], + ) + + @handle_test( fn_tree="functional.ivy.experimental.make_svd_non_negative", data=_make_svd_nn_data(),