Skip to content

Commit

Permalink
fix: least_squares implementation (#1550)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Oct 29, 2024
1 parent 5eb444e commit 9d73de2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion exla/test/exla/nx_linalg_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
invert: 1,
matrix_power: 2
]
@rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 2]
@rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3]

@excluded_doctests @function_clause_error_doctests ++
@rounding_error_doctests ++
Expand Down
24 changes: 11 additions & 13 deletions nx/lib/nx/lin_alg.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2152,12 +2152,16 @@ defmodule Nx.LinAlg do
@doc """
Return the least-squares solution to a linear matrix equation Ax = b.
## Options
* `:eps` - Rounding error threshold used to assume values as 0. Defaults to `1.0e-15`
## Examples
iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2]))
#Nx.Tensor<
f32[2]
[1.0000004768371582, -2.665601925855299e-7]
[0.9977624416351318, 0.0011188983917236328]
>
iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1]))
Expand Down Expand Up @@ -2187,7 +2191,9 @@ defmodule Nx.LinAlg do
** (ArgumentError) the number of rows of the matrix as the 1st argument and the number of columns of the vector as the 2nd argument must be the same, got 1st argument shape {2, 2} and 2nd argument shape {3}
"""
@doc from_backend: false
defn least_squares(a, b) do
defn least_squares(a, b, opts \\ []) do
opts = keyword!(opts, eps: 1.0e-15)

%T{type: a_type, shape: a_shape} = Nx.to_tensor(a)
a_size = Nx.rank(a_shape)
%T{type: b_type, shape: b_shape} = Nx.to_tensor(b)
Expand Down Expand Up @@ -2235,17 +2241,9 @@ defmodule Nx.LinAlg do
)
end

case a_shape do
{m, n} when m == n ->
Nx.LinAlg.solve(a, b)

{m, n} when m != n ->
Nx.LinAlg.pinv(a, eps: 1.0e-15)
|> Nx.dot(b)

_ ->
nil
end
a
|> Nx.LinAlg.pinv(eps: opts[:eps])
|> Nx.dot(b)
end

defp apply_vectorized(tensor, fun) when is_function(fun, 1) do
Expand Down

0 comments on commit 9d73de2

Please sign in to comment.