Skip to content

Commit

Permalink
Blocked Jacobi method for eigen decomposition (#1510)
Browse files Browse the repository at this point in the history
Co-authored-by: Paulo Valente <[email protected]>
  • Loading branch information
christianjgreen and polvalente authored Jan 13, 2025
1 parent 7ef59e2 commit 8dc7b29
Show file tree
Hide file tree
Showing 8 changed files with 396 additions and 377 deletions.
19 changes: 0 additions & 19 deletions nx/lib/nx/binary_backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1240,25 +1240,6 @@ defmodule Nx.BinaryBackend do
output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end)
end

@impl true
def eigh(
      {%{type: output_type} = eigenvals_holder, eigenvecs_holder},
      %{type: input_type, shape: input_shape} = tensor,
      opts
    ) do
  # Backend callback for the eigen decomposition: every chunk of n * n
  # elements handed out by bin_batch_reduce/5 is decomposed by
  # B.Matrix.eigh/5, and the resulting binaries are concatenated in order.
  bin = to_binary(tensor)

  # The last dimension of the input shape is the side of each square matrix.
  n = elem(input_shape, tuple_size(input_shape) - 1)

  decompose_one = fn matrix, {vals_acc, vecs_acc} ->
    {vals, vecs} = B.Matrix.eigh(matrix, input_type, {n, n}, output_type, opts)
    {vals_acc <> vals, vecs_acc <> vecs}
  end

  {eigenvals, eigenvecs} =
    bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>}, decompose_one)

  {from_binary(eigenvals_holder, eigenvals), from_binary(eigenvecs_holder, eigenvecs)}
end

@impl true
def lu(
{%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder},
Expand Down
272 changes: 0 additions & 272 deletions nx/lib/nx/binary_backend/matrix.ex
Original file line number Diff line number Diff line change
Expand Up @@ -116,150 +116,6 @@ defmodule Nx.BinaryBackend.Matrix do

defp do_ts([], [], _idx, acc), do: acc

# Degenerate sizes (n of 0 or 1): Q is the 1x1 matrix [[1.0]] and R is the
# input itself.
defp qr_decomposition(matrix, n, _eps) when n in 0..1, do: {[[1.0]], matrix}

# QR decomposition of a square n x n matrix through Householder transforms.
#
# This function originally supported generic QR, but it is now only used by
# eigh, so the signature was narrowed to square matrices.
#
# Returns `{q, r}`, both passed through approximate_zeros/2 with `eps`.
defp qr_decomposition(matrix, n, eps) when n >= 2 do
  {q_matrix, r_matrix} =
    Enum.reduce(0..(n - 2)//1, {nil, matrix}, fn i, {q_acc, r_acc} ->
      # Reflector built from the column slice that starts on the diagonal.
      reflector =
        householder_reflector(slice_matrix(r_acc, [i, i], [n - i, 1]), n, eps)

      # Q has not been allocated on the first step, so the first reflector
      # seeds it; afterwards reflectors are accumulated on the right.
      # TODO: Resolve the inconsistency with the Householder reflector.
      # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
      next_q =
        case q_acc do
          nil -> reflector
          _ -> dot_matrix_real(q_acc, reflector)
        end

      {next_q, dot_matrix_real(reflector, r_acc)}
    end)

  {approximate_zeros(q_matrix, eps), approximate_zeros(r_matrix, eps)}
end

# Raises the ArgumentError reported when the input matrix is not Hermitian.
defp raise_not_hermitian do
  raise ArgumentError,
    message: "matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)"
end

# Eigen decomposition of one Hermitian n x n matrix stored in `input_data`.
#
# Options read from `opts`:
#
#   * `:eps` - tolerance used by the Hermitian check, by the convergence
#     test and by the final approximate_zeros/2 clean-up
#   * `:max_iter` - upper bound on the number of QR iterations
#
# Returns `{eigenvalues, eigenvectors}`, each converted with
# matrix_to_binary/2 using `output_type`. Raises ArgumentError (through
# raise_not_hermitian/0) when the matrix and its conjugate transpose differ
# by more than `eps`.
def eigh(input_data, input_type, {n, n} = input_shape, output_type, opts) do
  eps = opts[:eps]
  max_iter = opts[:max_iter]

  a = binary_to_matrix(input_data, input_type, input_shape)

  # A Hermitian matrix satisfies A^* = A, so compare the conjugate transpose
  # of the input against the input itself.
  conjugate_transpose =
    for row <- transpose_matrix(a), do: Enum.map(row, &Complex.conjugate/1)

  is_hermitian = is_approximately_same?(conjugate_transpose, a, eps)

  if not is_hermitian do
    raise_not_hermitian()
  end

  # Hessenberg decomposition
  {h, q_h} = hessenberg_decomposition(a, n, eps)

  # QR iteration for eigenvalues and eigenvectors
  {eigenvals_diag, eigenvecs} =
    Enum.reduce_while(1..max_iter//1, {h, q_h}, fn _iteration, {a_old, q_old} ->
      {q_now, r_now} = qr_decomposition(a_old, n, eps)

      a_new = dot_matrix_real(r_now, q_now)
      q_new = dot_matrix_real(q_old, q_now)

      # Halt once the accumulated Q no longer changes within eps; the
      # freshly computed pair is returned in both cases.
      signal = if is_approximately_same?(q_old, q_new, eps), do: :halt, else: :cont
      {signal, {a_new, q_new}}
    end)

  # The eigenvalues are the diagonal elements; keep only their real parts,
  # since the eigenvalues of a Hermitian matrix are real numbers.
  indices_diag = for idx <- 0..(n - 1), do: [idx, idx]

  eigenvals_real =
    eigenvals_diag
    |> get_matrix_elements(indices_diag)
    |> Enum.map(&Complex.real/1)

  # Reduce the elements smaller than eps to zero before encoding.
  eigenvals_bin = eigenvals_real |> approximate_zeros(eps) |> matrix_to_binary(output_type)
  eigenvecs_bin = eigenvecs |> approximate_zeros(eps) |> matrix_to_binary(output_type)

  {eigenvals_bin, eigenvecs_bin}
end

# Degenerate sizes (n of 0 or 1): the matrix is returned untouched together
# with the 1x1 matrix [[1.0]].
defp hessenberg_decomposition(matrix, n, _eps) when n in 0..1, do: {matrix, [[1.0]]}

# Hessenberg decomposition through Householder transforms.
#
# Returns `{hessenberg, q}`, both passed through approximate_zeros/2 with
# `eps`.
defp hessenberg_decomposition(matrix, n, eps) do
  {hess_matrix, q_matrix} =
    Enum.reduce(0..(n - 2)//1, {matrix, nil}, fn i, {hess_acc, q_acc} ->
      # Reflector built from the part of column i that lies below the diagonal.
      below_diagonal = slice_matrix(hess_acc, [i + 1, i], [n - i - 1, 1])
      reflector = householder_reflector(below_diagonal, n, eps)

      # Q has not been allocated on the first step, so the first reflector
      # seeds it.
      # TODO: Resolve the inconsistency with the Householder reflector.
      # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
      next_q =
        if is_nil(q_acc) do
          reflector
        else
          dot_matrix_real(q_acc, reflector)
        end

      # Hessenberg update: reflector * hess * adjoint(reflector).
      reflector_adj = adjoint_matrix(reflector)
      next_hess = dot_matrix_real(dot_matrix_real(reflector, hess_acc), reflector_adj)

      {next_hess, next_q}
    end)

  {approximate_zeros(hess_matrix, eps), approximate_zeros(q_matrix, eps)}
end

# Tells whether matrices `a` and `b` are element-wise equal within `eps`.
#
# A pair of elements is accepted when the absolute value of its difference
# is `:nan` or is at most `eps`. Rows and elements are paired with
# Enum.zip/2, so anything beyond the shorter side is not compared.
defp is_approximately_same?(a, b, eps) do
  close_enough? = fn {a_elem, b_elem} ->
    abs_diff = Complex.abs(a_elem - b_elem)
    abs_diff == :nan or abs_diff <= eps
  end

  Enum.all?(Enum.zip(a, b), fn {a_row, b_row} ->
    Enum.all?(Enum.zip(a_row, b_row), close_enough?)
  end)
end

def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do
a = binary_to_matrix(input_data, input_type, input_shape)
eps = opts[:eps]
Expand Down Expand Up @@ -361,116 +217,6 @@ defmodule Nx.BinaryBackend.Matrix do
end)
end

## Householder helpers

defp householder_reflector(a, target_k, eps)

# Empty input: the reflector is the target_k x target_k identity.
defp householder_reflector([], target_k, _eps) do
  identity_entries =
    for col <- 0..(target_k - 1), row <- 0..(target_k - 1) do
      if col == row, do: 1, else: 0
    end

  Enum.chunk_every(identity_entries, target_k)
end

# Builds the target_k x target_k reflector I - scale * outer(v, v) from the
# pivot data returned by householder_reflector_pivot/2.
#
# `v` is left-padded with zeros up to target_k entries, and every position
# whose row or column falls inside that padding keeps the plain identity
# value.
defp householder_reflector(a, target_k, eps) do
  {v, scale, is_complex} = householder_reflector_pivot(a, eps)

  prefix_threshold = target_k - length(v)

  # After padding the vector has exactly target_k entries, so the index of
  # each entry doubles as its row/column position in the reflector.
  indexed_v = Enum.with_index(List.duplicate(0, prefix_threshold) ++ v)

  for {col_factor, row} <- indexed_v do
    for {raw_row_factor, col} <- indexed_v do
      # In the complex case the second factor of outer(v, v) is conjugated.
      row_factor =
        if is_complex, do: Complex.conjugate(raw_row_factor), else: raw_row_factor

      # The identity contributes 1 on the diagonal and 0 elsewhere.
      identity_element = if row == col, do: 1, else: 0

      if row >= prefix_threshold and col >= prefix_threshold do
        identity_element - scale * col_factor * row_factor
      else
        identity_element
      end
    end
  end
end

# Pivot data for a column whose head is a plain number.
#
# Returns `{v, scale, false}`, where `v` has a leading 1 and `scale` is the
# factor applied to outer(v, v) by householder_reflector/3. When the squared
# norm of the tail is below `eps`, `scale` is 0, which makes the resulting
# reflector the identity.
defp householder_reflector_pivot([a_0 | tail] = a, eps) when is_number(a_0) do
  norm_a_squared =
    for x <- a, reduce: 0 do
      acc -> x * Complex.conjugate(x) + acc
    end

  # Squared norm of the column without its head element.
  norm_a_sq_1on = norm_a_squared - a_0 * a_0

  if norm_a_sq_1on < eps do
    {[1 | tail], 0, false}
  else
    norm_a = Complex.sqrt(norm_a_squared)

    # Two different formulas are used for v_0 depending on the sign of a_0.
    v_0 =
      if a_0 <= 0 do
        a_0 - norm_a
      else
        -norm_a_sq_1on / (a_0 + norm_a)
      end

    v_0_sq = v_0 * v_0
    scale = 2 * v_0_sq / (norm_a_sq_1on + v_0_sq)

    # Normalize by v_0 so that the head of v is exactly 1.
    {[1 | for(x <- tail, do: x / v_0)], scale, false}
  end
end

# Pivot data for the complex case.
#
# Returns `{v, 2, true}`, where `v` is the shifted column divided by its
# norm.
defp householder_reflector_pivot([a_0 | tail], _eps) do
  norm_a_sq_1on = Enum.reduce(tail, 0, fn x, acc -> Complex.abs_squared(x) + acc end)
  norm_a_sq = norm_a_sq_1on + Complex.abs_squared(a_0)
  norm_a = Complex.sqrt(norm_a_sq)

  # alfa has the phase of a_0 and the magnitude of the whole column.
  alfa = Complex.exp(Complex.new(0, Complex.phase(a_0))) * norm_a

  # The head is shifted by +alfa.
  # NOTE(review): the previous comment described this as u = x - alfa * e1,
  # while the code adds alfa - confirm the intended sign convention.
  u_0 = a_0 + alfa
  norm_u = Complex.sqrt(norm_a_sq_1on + Complex.abs_squared(u_0))

  v = for x <- [u_0 | tail], do: x / norm_u
  {v, 2, true}
end

## Matrix (2-D array) manipulation

defp dot_matrix([], _), do: 0
Expand All @@ -491,24 +237,6 @@ defmodule Nx.BinaryBackend.Matrix do
end)
end

# Matrix product of `m1` and `m2` (lists of rows) without conjugating either
# operand: entry [i][j] is the sum of products of row i of `m1` and column j
# of `m2`.
#
# With no rows in `m1` there is nothing to multiply, so `m2` is never
# transposed.
defp dot_matrix_real([], _m2), do: []

defp dot_matrix_real(m1, m2) do
  # The columns of m2 do not depend on the current row of m1, so the
  # transpose is computed once instead of once per row.
  columns = transpose_matrix(m2)

  Enum.map(m1, fn row ->
    Enum.map(columns, fn col ->
      Enum.zip_reduce(row, col, 0, fn x, y, acc -> acc + x * y end)
    end)
  end)
end

# Flat list of scalars: each element is conjugated and wrapped in its own
# single-element row.
defp adjoint_matrix([x | _] = m) when not is_list(x) do
  for entry <- m, do: [Complex.conjugate(entry)]
end

# List of rows: Enum.zip_with/2 gathers the elements column by column, and
# each gathered column is conjugated to form a row of the result.
defp adjoint_matrix(m) do
  Enum.zip_with(m, fn column ->
    for entry <- column, do: Complex.conjugate(entry)
  end)
end

defp transpose_matrix([x | _] = m) when not is_list(x) do
Enum.map(m, &[&1])
end
Expand Down
Loading

0 comments on commit 8dc7b29

Please sign in to comment.