Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lu #28406

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open

lu #28406

Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 156 additions & 137 deletions ivy/data_classes/array/experimental/linear_algebra.py

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions ivy/data_classes/container/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,55 @@ def kron(
out=out,
)

@staticmethod
def static_lu(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
return ContainerBase.cont_multi_map_in_function(
"lu",
x,
out=out,
key_chains=key_chains,
to_apply=to_apply,
)

def lu(
self: ivy.Container,
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""ivy.Container instance method variant of ivy.diagflat. This method
simply wraps the function, and so the docstring for ivy.diagflat also
applies to this method with minimal changes.

Examples
--------
>>> x = ivy.array([[[1., 0.],
[0., 1.]],
[[2., 0.],
[0., 2.]]])
>>> ivy.matrix_exp(x)
ivy.array([[[2.7183, 1.0000],
[1.0000, 2.7183]],
[[7.3891, 1.0000],
[1.0000, 7.3891]]])
"""
return self.static_lu(
self,
key_chains=key_chains,
to_apply=to_apply,
out=out,
)


@staticmethod
def static_matrix_exp(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
Expand Down
8 changes: 8 additions & 0 deletions ivy/functional/backends/jax/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
from . import backend_version


def lu(
x: JaxArray,
/,
*,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jla.lu(x)

def diagflat(
x: JaxArray,
/,
Expand Down
133 changes: 71 additions & 62 deletions ivy/functional/backends/mxnet/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,92 +5,101 @@


def eigh_tridiagonal(
alpha: Union[(None, mx.ndarray.NDArray)],
beta: Union[(None, mx.ndarray.NDArray)],
/,
*,
eigvals_only: bool = True,
select: str = "a",
select_range: Optional[
Union[(Tuple[(int, int)], List[int], None, mx.ndarray.NDArray)]
] = None,
tol: Optional[float] = None,
alpha: Union[(None, mx.ndarray.NDArray)],
beta: Union[(None, mx.ndarray.NDArray)],
/,
*,
eigvals_only: bool = True,
select: str = "a",
select_range: Optional[
Union[(Tuple[(int, int)], List[int], None, mx.ndarray.NDArray)]
] = None,
tol: Optional[float] = None,
) -> Union[
(
None,
mx.ndarray.NDArray,
Tuple[(Union[(None, mx.ndarray.NDArray)], Union[(None, mx.ndarray.NDArray)])],
None,
mx.ndarray.NDArray,
Tuple[(Union[(None, mx.ndarray.NDArray)], Union[(None, mx.ndarray.NDArray)])],
)
]:
raise IvyNotImplementedException()


def diagflat(
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
offset: int = 0,
padding_value: float = 0,
align: str = "RIGHT_LEFT",
num_rows: Optional[int] = None,
num_cols: Optional[int] = None,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
offset: int = 0,
padding_value: float = 0,
align: str = "RIGHT_LEFT",
num_rows: Optional[int] = None,
num_cols: Optional[int] = None,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
):
raise IvyNotImplementedException()


def kron(
a: Union[(None, mx.ndarray.NDArray)],
b: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
a: Union[(None, mx.ndarray.NDArray)],
b: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
raise IvyNotImplementedException()


def lu(
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
raise IvyNotImplementedException()


def matrix_exp(
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
raise IvyNotImplementedException()


def eig(
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Tuple[None]:
raise IvyNotImplementedException()


def eigvals(
x: Union[(None, mx.ndarray.NDArray)], /
x: Union[(None, mx.ndarray.NDArray)], /
) -> Union[(None, mx.ndarray.NDArray)]:
raise IvyNotImplementedException()


def adjoint(
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
raise IvyNotImplementedException()


def solve_triangular(
x1: Union[(None, mx.ndarray.NDArray)],
x2: Union[(None, mx.ndarray.NDArray)],
/,
*,
upper: bool = True,
adjoint: bool = False,
unit_diagonal: bool = False,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
x1: Union[(None, mx.ndarray.NDArray)],
x2: Union[(None, mx.ndarray.NDArray)],
/,
*,
upper: bool = True,
adjoint: bool = False,
unit_diagonal: bool = False,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
# Multiplying with a mask matrix can stop gradients on the diagonal.
if unit_diagonal:
Expand All @@ -102,30 +111,30 @@ def solve_triangular(


def multi_dot(
x: Sequence[Union[(None, mx.ndarray.NDArray)]],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
x: Sequence[Union[(None, mx.ndarray.NDArray)]],
/,
*,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> None:
raise IvyNotImplementedException()


def cond(
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
p: Optional[Union[(None, int, str)]] = None,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
p: Optional[Union[(None, int, str)]] = None,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
raise IvyNotImplementedException()


def dot(
a: mx.ndarray.NDArray,
b: mx.ndarray.NDArray,
/,
*,
out: Optional[mx.ndarray.NDArray] = None,
a: mx.ndarray.NDArray,
b: mx.ndarray.NDArray,
/,
*,
out: Optional[mx.ndarray.NDArray] = None,
) -> mx.ndarray.NDArray:
return mx.symbol.dot(a, b, out=out)

Expand Down
9 changes: 9 additions & 0 deletions ivy/functional/backends/numpy/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from typing import Optional, Tuple, Sequence, Union, Any
import numpy as np
import scipy.linalg as sla

import ivy
from ivy.func_wrapper import with_supported_dtypes, with_unsupported_dtypes
Expand All @@ -10,6 +11,14 @@
from ivy.functional.ivy.experimental.linear_algebra import _check_valid_dimension_size


def lu(
x: np.ndarray,
/,
*,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return sla.lu(x)

def diagflat(
x: np.ndarray,
/,
Expand Down
Loading
Loading