Skip to content

Commit

Permalink
[lang] MatrixType refactor: Support eig(), sym_eig(), solve() (taichi…
Browse files Browse the repository at this point in the history
…-dev#6627)

Issue: taichi-dev#5819

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 5fb0358 commit 30dcc9a
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 16 deletions.
37 changes: 21 additions & 16 deletions python/taichi/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def svd3d(A, dt, iters=None):


@func
def eig2x2(A, dt):
def _eig2x2(A, dt):
"""Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.
Expand Down Expand Up @@ -241,7 +241,7 @@ def eig2x2(A, dt):


@func
def sym_eig2x2(A, dt):
def _sym_eig2x2(A, dt):
"""Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real symmetric matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.
Expand All @@ -254,6 +254,7 @@ def sym_eig2x2(A, dt):
eigenvalues (ti.Vector(2)): The eigenvalues. Each entry store one eigen value.
eigenvectors (ti.Matrix(2, 2)): The eigenvectors. Each column stores one eigenvector.
"""
assert all(A == A.transpose()), "A needs to be symmetric"
tr = A.trace()
det = A.determinant()
gap = tr**2 - 4 * det
Expand All @@ -276,7 +277,7 @@ def sym_eig2x2(A, dt):


@func
def sym_eig3x3(A, dt):
def _sym_eig3x3(A, dt):
"""Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 3x3 real symmetric matrix using Cardano's method.
Mathematical concept refers to https://www.mpi-hd.mpg.de/personalhomes/globes/3x3/.
Expand All @@ -289,6 +290,7 @@ def sym_eig3x3(A, dt):
eigenvalues (ti.Vector(3)): The eigenvalues. Each entry store one eigen value.
eigenvectors (ti.Matrix(3, 3)): The eigenvectors. Each column stores one eigenvector.
"""
assert all(A == A.transpose()), "A needs to be symmetric"
M_SQRT3 = 1.73205080756887729352744634151
m = A.trace()
dd = A[0, 1] * A[0, 1]
Expand Down Expand Up @@ -456,7 +458,7 @@ def eig(A, dt=None):
if dt is None:
dt = impl.get_runtime().default_fp
if A.n == 2:
return eig2x2(A, dt)
return _eig2x2(A, dt)
raise Exception("Eigen solver only supports 2D matrices.")


Expand All @@ -474,13 +476,12 @@ def sym_eig(A, dt=None):
eigenvalues (ti.Vector(n)): The eigenvalues. Each entry store one eigen value.
eigenvectors (ti.Matrix(n, n)): The eigenvectors. Each column stores one eigenvector.
"""
assert all(A == A.transpose()), "A needs to be symmetric"
if dt is None:
dt = impl.get_runtime().default_fp
if A.n == 2:
return sym_eig2x2(A, dt)
return _sym_eig2x2(A, dt)
if A.n == 3:
return sym_eig3x3(A, dt)
return _sym_eig3x3(A, dt)
raise Exception("Symmetric eigen solver only supports 2D and 3D matrices.")


Expand Down Expand Up @@ -535,6 +536,18 @@ def _gauss_elimination_3x3(Ab, dt):
return x


@func
def _combine(A, b, dt):
n = static(A.n)
Ab = Matrix.zero(dt, n, n + 1)
for i in static(range(n)):
for j in static(range(n)):
Ab[i, j] = A[i, j]
for i in static(range(n)):
Ab[i, n] = b[i]
return Ab


def solve(A, b, dt=None):
"""Solve a matrix using Gauss elimination method.
Expand All @@ -551,15 +564,7 @@ def solve(A, b, dt=None):
assert A.m == b.n, "Matrix and Vector dimension dismatch"
if dt is None:
dt = impl.get_runtime().default_fp
nrow, ncol = static(A.n, A.n + 1)
Ab = expr_init(Matrix.zero(dt, nrow, ncol))
lhs = tuple([e.ptr for e in A.entries])
rhs = tuple([e.ptr for e in b.entries])
for i in range(nrow):
for j in range(nrow):
Ab(i, j)._assign(lhs[nrow * i + j])
for i in range(nrow):
Ab(i, nrow)._assign(rhs[i])
Ab = _combine(A, b, dt)
if A.n == 2:
return _gauss_elimination_2x2(Ab, dt)
if A.n == 3:
Expand Down
55 changes: 55 additions & 0 deletions tests/python/test_eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,58 @@ def test_sym_eig3x3_f32(a00):
fast_math=False)
def test_sym_eig3x3_f64(a00):
_test_sym_eig3x3(ti.f64, a00)


@pytest.mark.parametrize("func", [_test_eig2x2_real, _test_eig2x2_complex])
@test_utils.test(default_fp=ti.f32,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_eig2x2_f32_real_matrix_scalarize(func):
func(ti.f32)


@pytest.mark.parametrize("func", [_test_eig2x2_real, _test_eig2x2_complex])
@test_utils.test(require=ti.extension.data64,
default_fp=ti.f64,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_eig2x2_f64_real_matrix_scalarize(func):
func(ti.f64)


@test_utils.test(default_fp=ti.f32,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_sym_eig2x2_f32_real_matrix_scalarize():
_test_sym_eig2x2(ti.f32)


@test_utils.test(require=ti.extension.data64,
default_fp=ti.f64,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_sym_eig2x2_f64_real_matrix_scalarize():
_test_sym_eig2x2(ti.f64)


@pytest.mark.parametrize('a00', [i for i in range(10)])
@test_utils.test(default_fp=ti.f32,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_sym_eig3x3_f32_real_matrix_scalarize(a00):
_test_sym_eig3x3(ti.f32, a00)


@pytest.mark.parametrize('a00', [i for i in range(10)])
@test_utils.test(require=ti.extension.data64,
default_fp=ti.f64,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_sym_eig3x3_f64_real_matrix_scalarize(a00):
_test_sym_eig3x3(ti.f64, a00)
38 changes: 38 additions & 0 deletions tests/python/test_matrix_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,41 @@ def test_solve_3x3_f32(a00):
fast_math=False)
def test_solve_3x3_f64(a00):
_test_solve_3x3(ti.f64, a00)


@pytest.mark.parametrize('a00', [float(i) for i in range(10)])
@test_utils.test(default_fp=ti.f32,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_solve_2x2_f32_real_matrix_scalarize(a00):
_test_solve_2x2(ti.f32, a00)


@pytest.mark.parametrize('a00', [float(i) for i in range(10)])
@test_utils.test(require=ti.extension.data64,
default_fp=ti.f64,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_solve_2x2_f64_real_matrix_scalarize(a00):
_test_solve_2x2(ti.f64, a00)


@pytest.mark.parametrize('a00', [float(i) for i in range(10)])
@test_utils.test(default_fp=ti.f32,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_solve_3x3_f32_real_matrix_scalarize(a00):
_test_solve_3x3(ti.f32, a00)


@pytest.mark.parametrize('a00', [float(i) for i in range(10)])
@test_utils.test(require=ti.extension.data64,
default_fp=ti.f64,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_solve_3x3_f64_real_matrix_scalarize(a00):
_test_solve_3x3(ti.f64, a00)

0 comments on commit 30dcc9a

Please sign in to comment.