Skip to content

Commit

Permalink
[lang] MatrixType refactor: Support svd(), polar_decompose() (#6636)
Browse files Browse the repository at this point in the history
Issue: #5819

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Nov 18, 2022
1 parent 559097a commit db6ab0d
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 74 deletions.
107 changes: 35 additions & 72 deletions python/taichi/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math

from taichi.lang import impl, ops
from taichi.lang.impl import expr_init, get_runtime, grouped, static
from taichi.lang.impl import get_runtime, grouped, static
from taichi.lang.kernel_impl import func
from taichi.lang.matrix import Matrix, Vector
from taichi.types import f32, f64
Expand Down Expand Up @@ -50,7 +50,7 @@ def randn(dt=None):


@func
def polar_decompose2d(A, dt):
def _polar_decompose2d(A, dt):
"""Perform polar decomposition (A=UP) for 2x2 matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.
Expand Down Expand Up @@ -88,7 +88,7 @@ def polar_decompose2d(A, dt):


@func
def polar_decompose3d(A, dt):
def _polar_decompose3d(A, dt):
"""Perform polar decomposition (A=UP) for 3x3 matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.
Expand All @@ -100,13 +100,13 @@ def polar_decompose3d(A, dt):
Returns:
Decomposed 3x3 matrices `U` and `P`.
"""
U, sig, V = svd(A, dt)
U, sig, V = _svd3d(A, dt)
return U @ V.transpose(), V @ sig @ V.transpose()


# https://www.seas.upenn.edu/~cffjiang/research/svd/svd.pdf
@func
def svd2d(A, dt):
def _svd2d(A, dt):
"""Perform singular value decomposition (A=USV^T) for 2x2 matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition.
Expand All @@ -118,7 +118,7 @@ def svd2d(A, dt):
Returns:
Decomposed 2x2 matrices `U`, 'S' and `V`.
"""
R, S = polar_decompose2d(A, dt)
R, S = _polar_decompose2d(A, dt)
c, s = ops.cast(0.0, dt), ops.cast(0.0, dt)
s1, s2 = ops.cast(0.0, dt), ops.cast(0.0, dt)
if abs(S[0, 1]) < 1e-5:
Expand Down Expand Up @@ -148,7 +148,7 @@ def svd2d(A, dt):
return U, Matrix([[s1, ops.cast(0, dt)], [ops.cast(0, dt), s2]], dt=dt), V


def svd3d(A, dt, iters=None):
def _svd3d(A, dt, iters=None):
"""Perform singular value decomposition (A=USV^T) for 3x3 matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition.
Expand All @@ -162,7 +162,10 @@ def svd3d(A, dt, iters=None):
Decomposed 3x3 matrices `U`, 'S' and `V`.
"""
assert A.n == 3 and A.m == 3
inputs = tuple([e.ptr for e in A.entries])
if impl.current_cfg().real_matrix:
inputs = get_runtime().prog.current_ast_builder().expand_expr([A.ptr])
else:
inputs = tuple([e.ptr for e in A.entries])
assert dt in [f32, f64]
if iters is None:
if dt == f32:
Expand All @@ -179,15 +182,20 @@ def svd3d(A, dt, iters=None):
U_entries = rets[:9]
V_entries = rets[9:18]
sig_entries = rets[18:]
U = expr_init(Matrix.zero(dt, 3, 3))
V = expr_init(Matrix.zero(dt, 3, 3))
sigma = expr_init(Matrix.zero(dt, 3, 3))
for i in range(3):
for j in range(3):
U(i, j)._assign(U_entries[i * 3 + j])
V(i, j)._assign(V_entries[i * 3 + j])
sigma(i, i)._assign(sig_entries[i])
return U, sigma, V

@func
def get_result():
U = Matrix.zero(dt, 3, 3)
V = Matrix.zero(dt, 3, 3)
sigma = Matrix.zero(dt, 3, 3)
for i in static(range(3)):
for j in static(range(3)):
U[i, j] = U_entries[i * 3 + j]
V[i, j] = V_entries[i * 3 + j]
sigma[i, i] = sig_entries[i]
return U, sigma, V

return get_result()


@func
Expand Down Expand Up @@ -356,55 +364,10 @@ def _sym_eig3x3(A, dt):
return eigenvalues, Q


@func
def _svd(A, dt):
"""Perform singular value decomposition (A=USV^T) for arbitrary size matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition.
2D implementation refers to :func:`taichi.svd2d`.
3D implementation refers to :func:`taichi.svd3d`.
Args:
A (ti.Matrix(n, n)): input nxn matrix `A`.
dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.
Returns:
Decomposed nxn matrices `U`, 'S' and `V`.
"""
if static(A.n == 2): # pylint: disable=R1705
ret = svd2d(A, dt)
return ret
else:
return svd3d(A, dt)


@func
def _polar_decompose(A, dt):
"""Perform polar decomposition (A=UP) for arbitrary size matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.
2D implementation refers to :func:`taichi.polar_decompose2d`.
3D implementation refers to :func:`taichi.polar_decompose3d`.
Args:
A (ti.Matrix(n, n)): input nxn matrix `A`.
dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.
Returns:
Decomposed nxn matrices `U` and `P`.
"""
if static(A.n == 2): # pylint: disable=R1705
ret = polar_decompose2d(A, dt)
return ret
else:
return polar_decompose3d(A, dt)


def polar_decompose(A, dt=None):
"""Perform polar decomposition (A=UP) for arbitrary size matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.
This is only a wrapper for :func:`taichi.polar_decompose`.
Args:
A (ti.Matrix(n, n)): input nxn matrix `A`.
Expand All @@ -415,17 +378,17 @@ def polar_decompose(A, dt=None):
"""
if dt is None:
dt = impl.get_runtime().default_fp
if A.n != 2 and A.n != 3:
raise Exception(
"Polar decomposition only supports 2D and 3D matrices.")
return _polar_decompose(A, dt)
if A.n == 2:
return _polar_decompose2d(A, dt)
if A.n == 3:
return _polar_decompose3d(A, dt)
raise Exception("Polar decomposition only supports 2D and 3D matrices.")


def svd(A, dt=None):
"""Perform singular value decomposition (A=USV^T) for arbitrary size matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition.
This is only a wrappers for :func:`taichi.svd`.
Args:
A (ti.Matrix(n, n)): input nxn matrix `A`.
Expand All @@ -436,16 +399,17 @@ def svd(A, dt=None):
"""
if dt is None:
dt = impl.get_runtime().default_fp
if A.n != 2 and A.n != 3:
raise Exception("SVD only supports 2D and 3D matrices.")
return _svd(A, dt)
if A.n == 2:
return _svd2d(A, dt)
if A.n == 3:
return _svd3d(A, dt)
raise Exception("SVD only supports 2D and 3D matrices.")


def eig(A, dt=None):
"""Compute the eigenvalues and right eigenvectors of a real matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.
2D implementation refers to :func:`taichi.eig2x2`.
Args:
A (ti.Matrix(n, n)): 2D Matrix for which the eigenvalues and right eigenvectors will be computed.
Expand All @@ -466,7 +430,6 @@ def sym_eig(A, dt=None):
"""Compute the eigenvalues and right eigenvectors of a real symmetric matrix.
Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.
2D implementation refers to :func:`taichi.sym_eig2x2`.
Args:
A (ti.Matrix(n, n)): Symmetric Matrix for which the eigenvalues and right eigenvectors will be computed.
Expand Down
18 changes: 18 additions & 0 deletions tests/python/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,24 @@ def test_polar_decomp_f64(dim):
_test_polar_decomp(dim, ti.f64)


@pytest.mark.parametrize("dim", [2, 3])
@test_utils.test(default_fp=ti.f32,
exclude=ti.opengl,
real_matrix=True,
real_matrix_scalarize=True)
def test_polar_decomp_f32_real_matrix_scalarize(dim):
_test_polar_decomp(dim, ti.f32)


@pytest.mark.parametrize("dim", [2, 3])
@test_utils.test(require=ti.extension.data64,
default_fp=ti.f64,
real_matrix=True,
real_matrix_scalarize=True)
def test_polar_decomp_f64_real_matrix_scalarize(dim):
_test_polar_decomp(dim, ti.f64)


@test_utils.test()
def test_matrix():
x = ti.Matrix.field(2, 2, dtype=ti.i32)
Expand Down
32 changes: 30 additions & 2 deletions tests/python/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,26 @@ def test_svd_f64(dim):
_test_svd(ti.f64, dim)


@test_utils.test()
def test_transpose_no_loop():
@pytest.mark.parametrize("dim", [2, 3])
@test_utils.test(default_fp=ti.f32,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_svd_f32_real_matrix_scalarize(dim):
_test_svd(ti.f32, dim)


@pytest.mark.parametrize("dim", [2, 3])
@test_utils.test(require=ti.extension.data64,
default_fp=ti.f64,
fast_math=False,
real_matrix=True,
real_matrix_scalarize=True)
def test_svd_f64_real_matrix_scalarize(dim):
_test_svd(ti.f64, dim)


def _test_transpose_no_loop():
A = ti.Matrix.field(3, 3, dtype=ti.f32, shape=())
U = ti.Matrix.field(3, 3, dtype=ti.f32, shape=())
sigma = ti.Matrix.field(3, 3, dtype=ti.f32, shape=())
Expand All @@ -89,3 +107,13 @@ def run():

run()
# As long as it passes compilation we are good


@test_utils.test()
def test_transpose_no_loop():
_test_transpose_no_loop()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_transpose_no_loop_real_matrix_scalarize():
_test_transpose_no_loop()

0 comments on commit db6ab0d

Please sign in to comment.