Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] MatrixType refactor part 2: add more ops #6425

Merged
merged 26 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,8 @@ def build_Attribute(ctx, node):
def build_BinOp(ctx, node):
build_stmt(ctx, node.left)
build_stmt(ctx, node.right)
# pylint: disable-msg=C0415
from taichi.lang.matrix_ops import matmul
op = {
ast.Add: lambda l, r: l + r,
ast.Sub: lambda l, r: l - r,
Expand All @@ -801,7 +803,7 @@ def build_BinOp(ctx, node):
ast.BitOr: lambda l, r: l | r,
ast.BitXor: lambda l, r: l ^ r,
ast.BitAnd: lambda l, r: l & r,
ast.MatMult: lambda l, r: l @ r,
ast.MatMult: matmul,
}.get(type(node.op))
try:
node.ptr = op(node.left.ptr, node.right.ptr)
Expand Down
101 changes: 45 additions & 56 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,14 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None):
elif isinstance(arr[0], Matrix):
raise Exception('cols/rows required when using list of vectors')
else:
is_matrix = isinstance(arr[0], Iterable) and not is_vector(self)
if ndim is not None:
self.ndim = ndim
is_matrix = ndim == 2
else:
is_matrix = isinstance(arr[0],
Iterable) and not is_vector(self)
self.ndim = 2 if is_matrix else 1
initializer = _make_entries_initializer(is_matrix)
self.ndim = 2 if is_matrix else 1
if not is_matrix and isinstance(arr[0], Iterable):
flattened = []
for row in arr:
Expand Down Expand Up @@ -483,6 +488,8 @@ def get_shape(self):

def element_type(self):
if self._impl.entries:
if in_python_scope():
return type(self._impl.entries[0])
return getattr(self._impl.entries[0], 'element_type',
lambda: None)()
return None
Expand Down Expand Up @@ -781,11 +788,9 @@ def normalized(self, eps=0):
>>> a.normalized()
[0.6, 0.8]
"""
impl.static(
impl.static_assert(self.m == 1,
"normalized() only works on vector"))
invlen = 1 / (self.norm() + eps)
return invlen * self
# pylint: disable-msg=C0415
from taichi.lang import matrix_ops
return matrix_ops.normalized(self, eps)

def transpose(self):
"""Returns the transpose of a matrix.
Expand All @@ -799,8 +804,9 @@ def transpose(self):
>>> A.transpose()
[[0, 2], [1, 3]]
"""
from taichi._funcs import _matrix_transpose # pylint: disable=C0415
return _matrix_transpose(self)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.transpose(self)

@taichi_scope
def determinant(a):
Expand All @@ -815,33 +821,9 @@ def determinant(a):
Raises:
Exception: Determinants of matrices with sizes >= 5 are not supported.
"""
if a.n == 1 and a.m == 1:
return a(0, 0)
if a.n == 2 and a.m == 2:
return a(0, 0) * a(1, 1) - a(0, 1) * a(1, 0)
if a.n == 3 and a.m == 3:
return a(0, 0) * (a(1, 1) * a(2, 2) - a(2, 1) * a(1, 2)) - a(
1, 0) * (a(0, 1) * a(2, 2) - a(2, 1) * a(0, 2)) + a(
2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2))
if a.n == 4 and a.m == 4:
n = 4

def E(x, y):
return a(x % n, y % n)

det = impl.expr_init(0.0)
for i in range(4):
det = det + (-1.0)**i * (
a(i, 0) *
(E(i + 1, 1) *
(E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) -
E(i + 2, 1) *
(E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) +
E(i + 3, 1) *
(E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3))))
return det
raise Exception(
"Determinants of matrices with sizes >= 5 are not supported")
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.determinant(a)

@staticmethod
def diag(dim, val):
Expand All @@ -862,9 +844,9 @@ def diag(dim, val):
[0, 1, 0],
[0, 0, 1]]
"""
# TODO: need a more systematic way to create a "0" with the right type
return Matrix([[val if i == j else 0 * val for j in range(dim)]
for i in range(dim)])
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.diag(dim, val)

def sum(self):
"""Return the sum of all elements.
Expand All @@ -875,10 +857,9 @@ def sum(self):
>>> m.sum()
10
"""
ret = self.entries[0]
for i in range(1, len(self.entries)):
ret = ret + self.entries[i]
return ret
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.sum(self)

def norm(self, eps=0):
"""Returns the square root of the sum of the absolute squares
Expand All @@ -896,7 +877,9 @@ def norm(self, eps=0):
Returns:
The square root of the sum of the absolute squares of its elements.
"""
return ops_mod.sqrt(self.norm_sqr() + eps)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.norm(self, eps=eps)

def norm_inv(self, eps=0):
"""The inverse of the matrix :func:`~taichi.lang.matrix.Matrix.norm`.
Expand All @@ -907,19 +890,27 @@ def norm_inv(self, eps=0):
Returns:
The inverse of the matrix/vector `norm`.
"""
return ops_mod.rsqrt(self.norm_sqr() + eps)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.norm_inv(self, eps=eps)

def norm_sqr(self):
"""Returns the sum of the absolute squares of its elements."""
return (self * self).sum()
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.norm_sqr(self)

def max(self):
"""Returns the maximum element value."""
return ops_mod.max(*self.entries)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.max(self)

def min(self):
"""Returns the minimum element value."""
return ops_mod.min(*self.entries)
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.min(self)

def any(self):
"""Test whether any element not equal zero.
Expand All @@ -933,10 +924,9 @@ def any(self):
>>> v.any()
True
"""
ret = False
for entry in self.entries:
ret = ret | ops_mod.cmp_ne(entry, 0)
return ret & True
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.any(self)

def all(self):
"""Test whether all element not equal zero.
Expand All @@ -950,10 +940,9 @@ def all(self):
>>> v.all()
False
"""
ret = True
for entry in self.entries:
ret = ret & ops_mod.cmp_ne(entry, 0)
return ret
# pylint: disable=C0415
from taichi.lang import matrix_ops
return matrix_ops.all(self)

@taichi_scope
def fill(self, val):
Expand Down
Loading