From 5d6ef2a3a17db3d64024934562bfb5009f2fccdb Mon Sep 17 00:00:00 2001 From: Kenneth Lozes Date: Tue, 26 May 2020 09:44:04 -0700 Subject: [PATCH 1/6] [skip ci] add any and all operators to Matrix --- python/taichi/lang/matrix.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index c351a4ab57c2a..52b7977fa9bf2 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -482,6 +482,20 @@ def min(self): ret = impl.min(ret, self.entries[i]) return ret + def any(self): + import taichi as ti + ret = (self.entries[0] != ti.expr_init(0)) + for i in range(1, len(self.entries)): + ret = ret + (self.entries[i] != ti.expr_init(0)) + return -(ret < ti.expr_init(0)) + + def all(self): + import taichi as ti + ret = self.entries[0] != ti.expr_init(0) + for i in range(1, len(self.entries)): + ret = ret + (self.entries[i] != ti.expr_init(0)) + return -(ret == ti.expr_init(-len(self.entries))) + def dot(self, other): assert self.m == 1 and other.m == 1 return (self.transposed(self) @ other).subscript(0, 0) From 599a5dacb2293cf978dba2b6b60549dbcf0bd781 Mon Sep 17 00:00:00 2001 From: Kenneth Lozes Date: Tue, 26 May 2020 10:22:01 -0700 Subject: [PATCH 2/6] [skip ci] add test --- tests/python/test_linalg.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/python/test_linalg.py b/tests/python/test_linalg.py index 69c1fa0a8eb81..6838f3757d172 100644 --- a/tests/python/test_linalg.py +++ b/tests/python/test_linalg.py @@ -173,3 +173,37 @@ def fill(): assert m2[0][j, i] == int(i + 3 * j + 1) assert m3[0][i, j] == int(i + 3 * j + 1) assert m4[0][j, i] == int(i + 3 * j + 1) + + +@ti.all_archs +def test_any_all(): + a = ti.Matrix(2, 2, dt=ti.i32, shape=()) + b = ti.var(dt=ti.i32, shape=()) + + @ti.kernel + def func_any(): + b[None] = a[None].any() + + @ti.kernel + def func_all(): + b[None] = a[None].all() + + for i in range(2): + for j in range(2): + a[None][0,0] = i + a[None][1,0] = j + a[None][1,1] = i + a[None][0,1] = j + + func_any() + if i==1 or j==1: + assert b[None] == 1 + else: + assert b[None] == 0 + + + func_all() + if i==1 and j==1: + assert b[None] == 1 + else: + assert b[None] == 0 From afe45be58d2a64bed450c3e4bed7241930d58066 Mon Sep 17 00:00:00 2001 From: Kenneth Lozes Date: Tue, 26 May 2020 11:02:15 -0700 Subject: [PATCH 3/6] [skip ci] move __ti_int__ and __ti_float__ to common_ops --- python/taichi/lang/common_ops.py | 8 ++++++++ python/taichi/lang/expr.py | 8 -------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/taichi/lang/common_ops.py b/python/taichi/lang/common_ops.py index b1cee1595efc1..7cae176b68e2e 100644 --- a/python/taichi/lang/common_ops.py +++ b/python/taichi/lang/common_ops.py @@ -110,3 +110,11 @@ def __invert__(self): # ~a => a.__invert__() def __not__(self): # not a => a.__not__() import taichi as ti return ti.logical_not(self) + + def __ti_int__(self): + import taichi as ti + return ti.cast(self, ti.get_runtime().default_ip) + + def __ti_float__(self): + import taichi as ti + return ti.cast(self, ti.get_runtime().default_fp) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 9792b20b80913..c6acaa895fd1f 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -223,14 +223,6 @@ def fill(self, val): from .meta import fill_tensor fill_tensor(self, val) - def __ti_int__(self): - import taichi as ti - return ti.cast(self, ti.get_runtime().default_ip) - - def __ti_float__(self): - import taichi as ti - return ti.cast(self, ti.get_runtime().default_fp) - def parent(self, n=1): import taichi as ti p = self.ptr.snode() From a9da0f198c6950ca2200aa411fe3c41c429ec191 Mon Sep 17 00:00:00 2001 From: Kenneth Lozes Date: Tue, 26 May 2020 11:46:26 -0700 Subject: [PATCH 4/6] [skip ci] overload builtin any and all ops --- python/taichi/lang/ops.py | 10 ++++++++++ python/taichi/lang/transformer.py | 4 ++++ tests/python/test_linalg.py | 4 ++-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index b9e047ae603a0..d98c5fc7f1fe8 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -342,6 +342,16 @@ def ti_min(*args): return ti_min(args[0], ti_min(*args[1:])) +def ti_any(a): + assert hasattr(a, 'any') + return a.any() + + +def ti_all(a): + assert hasattr(a, 'all') + return a.all() + + def append(l, indices, val): import taichi as ti a = ti.expr_init( diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 5cf54268d0b42..a63c8d412af82 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -554,6 +554,10 @@ def visit_Call(self, node): node.func = self.parse_expr('ti.ti_int') elif func_name == 'float': node.func = self.parse_expr('ti.ti_float') + elif func_name == 'any': + node.func = self.parse_expr('ti.ti_any') + elif func_name == 'all': + node.func = self.parse_expr('ti.ti_all') else: pass return node diff --git a/tests/python/test_linalg.py b/tests/python/test_linalg.py index 6838f3757d172..c9b4fadbdde91 100644 --- a/tests/python/test_linalg.py +++ b/tests/python/test_linalg.py @@ -182,11 +182,11 @@ def test_any_all(): @ti.kernel def func_any(): - b[None] = a[None].any() + b[None] = any(a[None]) @ti.kernel def func_all(): - b[None] = a[None].all() + b[None] = all(a[None]) for i in range(2): for j in range(2): From 4110965c98c6c5cb539abf21a035c2e18b589b71 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Tue, 26 May 2020 15:21:22 -0400 Subject: [PATCH 5/6] [skip ci] enforce code format --- python/taichi/lang/ops.py | 4 ++-- tests/python/test_linalg.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index d98c5fc7f1fe8..13b12f4ac413a 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -343,12 +343,12 @@ def ti_min(*args): def ti_any(a): - assert hasattr(a, 'any') + assert hasattr(a, 'any') return a.any() def ti_all(a): - assert hasattr(a, 'all') + assert hasattr(a, 'all') return a.all() diff --git a/tests/python/test_linalg.py b/tests/python/test_linalg.py index c9b4fadbdde91..5ca046208b3e8 100644 --- a/tests/python/test_linalg.py +++ b/tests/python/test_linalg.py @@ -190,20 +190,19 @@ def func_all(): for i in range(2): for j in range(2): - a[None][0,0] = i - a[None][1,0] = j - a[None][1,1] = i - a[None][0,1] = j + a[None][0, 0] = i + a[None][1, 0] = j + a[None][1, 1] = i + a[None][0, 1] = j func_any() - if i==1 or j==1: + if i == 1 or j == 1: assert b[None] == 1 else: assert b[None] == 0 - func_all() - if i==1 and j==1: + if i == 1 and j == 1: assert b[None] == 1 else: assert b[None] == 0 From c82c8792cef715ab8b8cb6f9c05932a2fed8ed60 Mon Sep 17 00:00:00 2001 From: Kenneth Lozes Date: Tue, 26 May 2020 15:30:06 -0700 Subject: [PATCH 6/6] [skip ci] update doc --- docs/matrix.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/matrix.rst b/docs/matrix.rst index a7318211619d1..d31ab273a9666 100644 --- a/docs/matrix.rst +++ b/docs/matrix.rst @@ -13,9 +13,11 @@ Matrices - ``ti.tr(A)`` - ``ti.determinant(A, type)`` - ``ti.cross(a, b)``, where ``a`` and ``b`` are 3D vectors (i.e. ``3x1`` matrices) -- ``A.cast(type)`` +- ``A.cast(type)`` or simply ``int(A)`` and ``float(A)`` - ``R, S = ti.polar_decompose(A, ti.f32)`` - ``U, sigma, V = ti.svd(A, ti.f32)`` (Note that ``sigma`` is a ``3x3`` diagonal matrix) +- ``any(A)`` +- ``all(A)`` TODO: doc here better like Vector. WIP