Skip to content

Commit

Permalink
[Lang] MatrixType refactor part 2: add more ops (#6425)
Browse files Browse the repository at this point in the history
Issue: #5478 #5819 

### Brief Summary

Co-authored-by: Yi Xu <[email protected]>
Co-authored-by: Zhanlue Yang <[email protected]>
  • Loading branch information
3 people authored Nov 4, 2022
1 parent 2b5eff8 commit ecaf650
Show file tree
Hide file tree
Showing 4 changed files with 469 additions and 71 deletions.
4 changes: 3 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,8 @@ def build_Attribute(ctx, node):
def build_BinOp(ctx, node):
build_stmt(ctx, node.left)
build_stmt(ctx, node.right)
# pylint: disable-msg=C0415
from taichi.lang.matrix_ops import matmul
op = {
ast.Add: lambda l, r: l + r,
ast.Sub: lambda l, r: l - r,
Expand All @@ -819,7 +821,7 @@ def build_BinOp(ctx, node):
ast.BitOr: lambda l, r: l | r,
ast.BitXor: lambda l, r: l ^ r,
ast.BitAnd: lambda l, r: l & r,
ast.MatMult: lambda l, r: l @ r,
ast.MatMult: matmul,
}.get(type(node.op))
try:
node.ptr = op(node.left.ptr, node.right.ptr)
Expand Down
101 changes: 45 additions & 56 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,14 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None):
elif isinstance(arr[0], Matrix):
raise Exception('cols/rows required when using list of vectors')
else:
is_matrix = isinstance(arr[0], Iterable) and not is_vector(self)
if ndim is not None:
self.ndim = ndim
is_matrix = ndim == 2
else:
is_matrix = isinstance(arr[0],
Iterable) and not is_vector(self)
self.ndim = 2 if is_matrix else 1
initializer = _make_entries_initializer(is_matrix)
self.ndim = 2 if is_matrix else 1
if not is_matrix and isinstance(arr[0], Iterable):
flattened = []
for row in arr:
Expand Down Expand Up @@ -486,6 +491,8 @@ def get_shape(self):

def element_type(self):
if self._impl.entries:
if in_python_scope():
return type(self._impl.entries[0])
return getattr(self._impl.entries[0], 'element_type',
lambda: None)()
return None
Expand Down Expand Up @@ -784,11 +791,9 @@ def normalized(self, eps=0):
>>> a.normalized()
[0.6, 0.8]
"""
impl.static(
impl.static_assert(self.m == 1,
"normalized() only works on vector"))
invlen = 1 / (self.norm() + eps)
return invlen * self
# pylint: disable-msg=C0415
from taichi.lang import matrix_ops
return matrix_ops.normalized(self, eps)

def transpose(self):
"""Returns the transpose of a matrix.
Expand All @@ -802,8 +807,9 @@ def transpose(self):
>>> A.transpose()
[[0, 2], [1, 3]]
"""
from taichi._funcs import _matrix_transpose # pylint: disable=C0415
return _matrix_transpose(self)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.transpose(self)

@taichi_scope
def determinant(a):
Expand All @@ -818,33 +824,9 @@ def determinant(a):
Raises:
Exception: Determinants of matrices with sizes >= 5 are not supported.
"""
if a.n == 1 and a.m == 1:
return a(0, 0)
if a.n == 2 and a.m == 2:
return a(0, 0) * a(1, 1) - a(0, 1) * a(1, 0)
if a.n == 3 and a.m == 3:
return a(0, 0) * (a(1, 1) * a(2, 2) - a(2, 1) * a(1, 2)) - a(
1, 0) * (a(0, 1) * a(2, 2) - a(2, 1) * a(0, 2)) + a(
2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2))
if a.n == 4 and a.m == 4:
n = 4

def E(x, y):
return a(x % n, y % n)

det = impl.expr_init(0.0)
for i in range(4):
det = det + (-1.0)**i * (
a(i, 0) *
(E(i + 1, 1) *
(E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) -
E(i + 2, 1) *
(E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) +
E(i + 3, 1) *
(E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3))))
return det
raise Exception(
"Determinants of matrices with sizes >= 5 are not supported")
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.determinant(a)

@staticmethod
def diag(dim, val):
Expand All @@ -865,9 +847,9 @@ def diag(dim, val):
[0, 1, 0],
[0, 0, 1]]
"""
# TODO: need a more systematic way to create a "0" with the right type
return Matrix([[val if i == j else 0 * val for j in range(dim)]
for i in range(dim)])
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.diag(dim, val)

def sum(self):
"""Return the sum of all elements.
Expand All @@ -878,10 +860,9 @@ def sum(self):
>>> m.sum()
10
"""
ret = self.entries[0]
for i in range(1, len(self.entries)):
ret = ret + self.entries[i]
return ret
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.sum(self)

def norm(self, eps=0):
"""Returns the square root of the sum of the absolute squares
Expand All @@ -899,7 +880,9 @@ def norm(self, eps=0):
Returns:
The square root of the sum of the absolute squares of its elements.
"""
return ops_mod.sqrt(self.norm_sqr() + eps)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.norm(self, eps=eps)

def norm_inv(self, eps=0):
"""The inverse of the matrix :func:`~taichi.lang.matrix.Matrix.norm`.
Expand All @@ -910,19 +893,27 @@ def norm_inv(self, eps=0):
Returns:
The inverse of the matrix/vector `norm`.
"""
return ops_mod.rsqrt(self.norm_sqr() + eps)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.norm_inv(self, eps=eps)

def norm_sqr(self):
"""Returns the sum of the absolute squares of its elements."""
return (self * self).sum()
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.norm_sqr(self)

def max(self):
"""Returns the maximum element value."""
return ops_mod.max(*self.entries)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.max(self)

def min(self):
"""Returns the minimum element value."""
return ops_mod.min(*self.entries)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.min(self)

def any(self):
"""Test whether any element not equal zero.
Expand All @@ -936,10 +927,9 @@ def any(self):
>>> v.any()
True
"""
ret = False
for entry in self.entries:
ret = ret | ops_mod.cmp_ne(entry, 0)
return ret & True
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.any(self)

def all(self):
"""Test whether all element not equal zero.
Expand All @@ -953,10 +943,9 @@ def all(self):
>>> v.all()
False
"""
ret = True
for entry in self.entries:
ret = ret & ops_mod.cmp_ne(entry, 0)
return ret
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.all(self)

@taichi_scope
def fill(self, val):
Expand Down
Loading

0 comments on commit ecaf650

Please sign in to comment.