diff --git a/docs/matrix.rst b/docs/matrix.rst index a7318211619d1..d31ab273a9666 100644 --- a/docs/matrix.rst +++ b/docs/matrix.rst @@ -13,9 +13,11 @@ Matrices - ``ti.tr(A)`` - ``ti.determinant(A, type)`` - ``ti.cross(a, b)``, where ``a`` and ``b`` are 3D vectors (i.e. ``3x1`` matrices) -- ``A.cast(type)`` +- ``A.cast(type)`` or simply ``int(A)`` and ``float(A)`` - ``R, S = ti.polar_decompose(A, ti.f32)`` - ``U, sigma, V = ti.svd(A, ti.f32)`` (Note that ``sigma`` is a ``3x3`` diagonal matrix) +- ``any(A)`` +- ``all(A)`` TODO: doc here better like Vector. WIP diff --git a/python/taichi/lang/common_ops.py b/python/taichi/lang/common_ops.py index b1cee1595efc1..7cae176b68e2e 100644 --- a/python/taichi/lang/common_ops.py +++ b/python/taichi/lang/common_ops.py @@ -110,3 +110,11 @@ def __invert__(self): # ~a => a.__invert__() def __not__(self): # not a => a.__not__() import taichi as ti return ti.logical_not(self) + + def __ti_int__(self): + import taichi as ti + return ti.cast(self, ti.get_runtime().default_ip) + + def __ti_float__(self): + import taichi as ti + return ti.cast(self, ti.get_runtime().default_fp) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 9792b20b80913..c6acaa895fd1f 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -223,14 +223,6 @@ def fill(self, val): from .meta import fill_tensor fill_tensor(self, val) - def __ti_int__(self): - import taichi as ti - return ti.cast(self, ti.get_runtime().default_ip) - - def __ti_float__(self): - import taichi as ti - return ti.cast(self, ti.get_runtime().default_fp) - def parent(self, n=1): import taichi as ti p = self.ptr.snode() diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 8375ce28d6b08..6db47f9aac901 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -482,6 +482,20 @@ def min(self): ret = impl.min(ret, self.entries[i]) return ret + def any(self): + import taichi as ti + ret = (self.entries[0] != ti.expr_init(0)) + for i in range(1, len(self.entries)): + ret = ret + (self.entries[i] != ti.expr_init(0)) + return -(ret < ti.expr_init(0)) + + def all(self): + import taichi as ti + ret = self.entries[0] != ti.expr_init(0) + for i in range(1, len(self.entries)): + ret = ret + (self.entries[i] != ti.expr_init(0)) + return -(ret == ti.expr_init(-len(self.entries))) + def dot(self, other): assert self.m == 1 and other.m == 1 return (self.transposed(self) @ other).subscript(0, 0) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index b9e047ae603a0..13b12f4ac413a 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -342,6 +342,16 @@ def ti_min(*args): return ti_min(args[0], ti_min(*args[1:])) +def ti_any(a): + assert hasattr(a, 'any') + return a.any() + + +def ti_all(a): + assert hasattr(a, 'all') + return a.all() + + def append(l, indices, val): import taichi as ti a = ti.expr_init( diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 5cf54268d0b42..a63c8d412af82 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -554,6 +554,10 @@ def visit_Call(self, node): node.func = self.parse_expr('ti.ti_int') elif func_name == 'float': node.func = self.parse_expr('ti.ti_float') + elif func_name == 'any': + node.func = self.parse_expr('ti.ti_any') + elif func_name == 'all': + node.func = self.parse_expr('ti.ti_all') else: pass return node diff --git a/tests/python/test_linalg.py b/tests/python/test_linalg.py index 69c1fa0a8eb81..5ca046208b3e8 100644 --- a/tests/python/test_linalg.py +++ b/tests/python/test_linalg.py @@ -173,3 +173,36 @@ def fill(): assert m2[0][j, i] == int(i + 3 * j + 1) assert m3[0][i, j] == int(i + 3 * j + 1) assert m4[0][j, i] == int(i + 3 * j + 1) + + +@ti.all_archs +def test_any_all(): + a = ti.Matrix(2, 2, dt=ti.i32, shape=()) + b = ti.var(dt=ti.i32, shape=()) + + @ti.kernel + def func_any(): + b[None] = any(a[None]) + + @ti.kernel + def func_all(): + b[None] = all(a[None]) + + for i in range(2): + for j in range(2): + a[None][0, 0] = i + a[None][1, 0] = j + a[None][1, 1] = i + a[None][0, 1] = j + + func_any() + if i == 1 or j == 1: + assert b[None] == 1 + else: + assert b[None] == 0 + + func_all() + if i == 1 and j == 1: + assert b[None] == 1 + else: + assert b[None] == 0