Skip to content

Commit

Permalink
[Lang] MatrixType refactor: Support inverse() (#6542)
Browse files Browse the repository at this point in the history
Issue: #5819

### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Nov 11, 2022
1 parent 51472bd commit 75acd5d
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 89 deletions.
59 changes: 11 additions & 48 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,16 @@ def prop_setter(instance, value):


def make_matrix(arr, dt=None):
assert len(arr) > 0, "Cannot create empty matrix"
is_matrix = isinstance(arr[0], Iterable)
if dt is None:
dt = _make_entries_initializer(is_matrix).infer_dt(arr)
if len(arr) == 0:
# the only usage of an empty vector is to serve as field indices
is_matrix = False
dt = primitive_types.i32
else:
dt = cook_dtype(dt)
is_matrix = isinstance(arr[0], Iterable)
if dt is None:
dt = _make_entries_initializer(is_matrix).infer_dt(arr)
else:
dt = cook_dtype(dt)
if not is_matrix:
return impl.Expr(
impl.make_matrix_expr([len(arr)], dt,
Expand Down Expand Up @@ -727,7 +731,6 @@ def trace(self):
from taichi.lang import matrix_ops
return matrix_ops.trace(self)

@taichi_scope
def inverse(self):
"""Returns the inverse of this matrix.
Expand All @@ -740,48 +743,8 @@ def inverse(self):
Raises:
Exception: Inversions of matrices with sizes >= 5 are not supported.
"""
assert self.n == self.m, 'Only square matrices are invertible'
if self.n == 1:
return Matrix([1 / self(0, 0)])
if self.n == 2:
inv_determinant = impl.expr_init(1.0 / self.determinant())
return inv_determinant * Matrix([[self(
1, 1), -self(0, 1)], [-self(1, 0), self(0, 0)]])
if self.n == 3:
n = 3
inv_determinant = impl.expr_init(1.0 / self.determinant())
entries = [[0] * n for _ in range(n)]

def E(x, y):
return self(x % n, y % n)

for i in range(n):
for j in range(n):
entries[j][i] = inv_determinant * (
E(i + 1, j + 1) * E(i + 2, j + 2) -
E(i + 2, j + 1) * E(i + 1, j + 2))
return Matrix(entries)
if self.n == 4:
n = 4
inv_determinant = impl.expr_init(1.0 / self.determinant())
entries = [[0] * n for _ in range(n)]

def E(x, y):
return self(x % n, y % n)

for i in range(n):
for j in range(n):
entries[j][i] = inv_determinant * (-1)**(i + j) * ((
E(i + 1, j + 1) *
(E(i + 2, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
(E(i + 1, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
(E(i + 1, j + 2) * E(i + 2, j + 3) -
E(i + 2, j + 2) * E(i + 1, j + 3))))
return Matrix(entries)
raise Exception(
"Inversions of matrices with sizes >= 5 are not supported")
from taichi.lang import matrix_ops # pylint: disable=C0415
return matrix_ops.inverse(self)

def normalized(self, eps=0):
"""Normalize a vector, i.e. matrices with the second dimension being
Expand Down
93 changes: 61 additions & 32 deletions python/taichi/lang/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,46 +112,75 @@ def cols(cols): # pylint: disable=W0621
return rows(cols).transpose()


def E(m, x, y, n):
@func
def _E():
return m[x % n, y % n]

return _E()
@pyfunc
def E(mat: template(), x: template(), y: template(), n: template()):
return mat[x % n, y % n]


@preconditions(square_matrix,
dim_lt(0, 5,
'Determinant of dimension >= 5 is not supported: {}'))
@func
def determinant(x):
shape = static(x.get_shape())
if static(shape[0] == 1 and shape[1] == 1):
return x[0, 0]
if static(shape[0] == 2 and shape[1] == 2):
return x[0, 0] * x[1, 1] - x[0, 1] * x[1, 0]
if static(shape[0] == 3 and shape[1] == 3):
return x[0, 0] * (x[1, 1] * x[2, 2] - x[2, 1] * x[1, 2]) - x[1, 0] * (
x[0, 1] * x[2, 2] - x[2, 1] * x[0, 2]) + x[2, 0] * (
x[0, 1] * x[1, 2] - x[1, 1] * x[0, 2])
if static(shape[0] == 4 and shape[1] == 4):

det = 0.0
@preconditions(square_matrix, dim_lt(0, 5))
@pyfunc
def determinant(mat):
shape = static(mat.get_shape())
if static(shape[0] == 1):
return mat[0, 0]
if static(shape[0] == 2):
return mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]
if static(shape[0] == 3):
return mat[0, 0] * (
mat[1, 1] * mat[2, 2] - mat[2, 1] * mat[1, 2]) - mat[1, 0] * (
mat[0, 1] * mat[2, 2] - mat[2, 1] * mat[0, 2]) + mat[2, 0] * (
mat[0, 1] * mat[1, 2] - mat[1, 1] * mat[0, 2])
if static(shape[0] == 4):
det = mat[0, 0] * 0 # keep type
for i in static(range(4)):
det += (-1.0)**i * (
x[i, 0] *
(E(x, i + 1, 1, 4) *
(E(x, i + 2, 2, 4) * E(x, i + 3, 3, 4) -
E(x, i + 3, 2, 4) * E(x, i + 2, 3, 4)) - E(x, i + 2, 1, 4) *
(E(x, i + 1, 2, 4) * E(x, i + 3, 3, 4) -
E(x, i + 3, 2, 4) * E(x, i + 1, 3, 4)) + E(x, i + 3, 1, 4) *
(E(x, i + 1, 2, 4) * E(x, i + 2, 3, 4) -
E(x, i + 2, 2, 4) * E(x, i + 1, 3, 4))))
det += (-1)**i * (mat[i, 0] *
(E(mat, i + 1, 1, 4) *
(E(mat, i + 2, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 2, 3, 4)) -
E(mat, i + 2, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 1, 3, 4)) +
E(mat, i + 3, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 2, 3, 4) -
E(mat, i + 2, 2, 4) * E(mat, i + 1, 3, 4))))
return det
# unreachable
return None


@preconditions(square_matrix, dim_lt(0, 5))
@pyfunc
def inverse(mat):
shape = static(mat.get_shape())
if static(shape[0] == 1):
return Matrix([[1.0 / mat[0, 0]]])
inv_determinant = 1.0 / determinant(mat)
if static(shape[0] == 2):
return inv_determinant * Matrix([[mat[1, 1], -mat[0, 1]],
[-mat[1, 0], mat[0, 0]]])
if static(shape[0] == 3):
return inv_determinant * Matrix([[
E(mat, i + 1, j + 1, 3) * E(mat, i + 2, j + 2, 3) -
E(mat, i + 2, j + 1, 3) * E(mat, i + 1, j + 2, 3)
for i in static(range(3))
] for j in static(range(3))])
if static(shape[0] == 4):
return inv_determinant * Matrix([[(-1)**(i + j) * (
(E(mat, i + 1, j + 1, 4) *
(E(mat, i + 2, j + 2, 4) * E(mat, i + 3, j + 3, 4) -
E(mat, i + 3, j + 2, 4) * E(mat, i + 2, j + 3, 4)) -
E(mat, i + 2, j + 1, 4) *
(E(mat, i + 1, j + 2, 4) * E(mat, i + 3, j + 3, 4) -
E(mat, i + 3, j + 2, 4) * E(mat, i + 1, j + 3, 4)) +
E(mat, i + 3, j + 1, 4) *
(E(mat, i + 1, j + 2, 4) * E(mat, i + 2, j + 3, 4) -
E(mat, i + 2, j + 2, 4) * E(mat, i + 1, j + 3, 4))))
for i in static(range(4))]
for j in static(range(4))])
# unreachable
return None


@preconditions(assert_tensor)
@pyfunc
def transpose(mat):
Expand Down
7 changes: 3 additions & 4 deletions python/taichi/lang/matrix_ops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,16 @@ def square_matrix(x):
assert_tensor(x)
shape = x.get_shape()
if len(shape) != 2 or shape[0] != shape[1]:
return False, f'not a square matrix: {shape}'
return False, f'expected a square matrix, got shape {shape}'
return True, None


def dim_lt(dim, limit, msg=None):
def dim_lt(dim, limit):
def check(x):
assert_tensor(x)
shape = x.get_shape()
return shape[dim] < limit, (
f'Dimension >= {limit} is not supported: {shape}'
if not msg else msg.format(shape))
f'only dimension < {limit} is supported, got shape {shape}')

return check

Expand Down
16 changes: 13 additions & 3 deletions tests/python/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,7 @@ def inc():
assert x[i][1, 1] == 1 + i


@pytest.mark.parametrize("n", range(1, 5))
@test_utils.test()
def test_mat_inverse_size(n):
def _test_mat_inverse_size(n):
m = ti.Matrix.field(n, n, dtype=ti.f32, shape=())
M = np.empty(shape=(n, n), dtype=np.float32)
for i in range(n):
Expand All @@ -269,6 +267,18 @@ def invert():
np.testing.assert_almost_equal(m_np, np.linalg.inv(M))


@pytest.mark.parametrize("n", range(1, 5))
@test_utils.test()
def test_mat_inverse_size(n):
_test_mat_inverse_size(n)


@pytest.mark.parametrize("n", range(1, 5))
@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_mat_inverse_size_real_matrix_scalarize(n):
_test_mat_inverse_size(n)


def _test_matrix_factories():
a = ti.Vector.field(3, dtype=ti.i32, shape=3)
b = ti.Matrix.field(2, 2, dtype=ti.f32, shape=2)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ def test_fun() -> ti.f32:
assert np.abs(x.trace() - 7.1) < 1e-6

with pytest.raises(TaichiCompilationError,
match=r"not a square matrix: \(3, 2\)"):
match=r"expected a square matrix, got shape \(3, 2\)"):
x = ti.Matrix([[.1, 3.], [5., 7.], [1., 2.]])
print(x.trace())

Expand All @@ -1102,7 +1102,7 @@ def failed_func():
print(x.trace())

with pytest.raises(TaichiCompilationError,
match=r"not a square matrix: \(3, 2\)"):
match=r"expected a square matrix, got shape \(3, 2\)"):
failed_func()


Expand Down

0 comments on commit 75acd5d

Please sign in to comment.