Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Add any()/all() functions for Matrix, move casting ops to common_ops #1064

Merged
merged 8 commits into from
May 26, 2020
8 changes: 8 additions & 0 deletions python/taichi/lang/common_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,11 @@ def __invert__(self): # ~a => a.__invert__()
def __not__(self): # not a => a.__not__()
import taichi as ti
return ti.logical_not(self)

def __ti_int__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_ip)

def __ti_float__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_fp)
8 changes: 0 additions & 8 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,6 @@ def fill(self, val):
from .meta import fill_tensor
fill_tensor(self, val)

def __ti_int__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_ip)

def __ti_float__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_fp)

def parent(self, n=1):
import taichi as ti
p = self.ptr.snode()
Expand Down
14 changes: 14 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,20 @@ def min(self):
ret = impl.min(ret, self.entries[i])
return ret

def any(self):
import taichi as ti
ret = (self.entries[0] != ti.expr_init(0))
for i in range(1, len(self.entries)):
ret = ret + (self.entries[i] != ti.expr_init(0))
return -(ret < ti.expr_init(0))

def all(self):
import taichi as ti
ret = self.entries[0] != ti.expr_init(0)
for i in range(1, len(self.entries)):
ret = ret + (self.entries[i] != ti.expr_init(0))
return -(ret == ti.expr_init(-len(self.entries)))

def dot(self, other):
assert self.m == 1 and other.m == 1
return (self.transposed(self) @ other).subscript(0, 0)
Expand Down
10 changes: 10 additions & 0 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,16 @@ def ti_min(*args):
return ti_min(args[0], ti_min(*args[1:]))


def ti_any(a):
assert hasattr(a, 'any')
return a.any()


def ti_all(a):
assert hasattr(a, 'all')
return a.all()


def append(l, indices, val):
import taichi as ti
a = ti.expr_init(
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,10 @@ def visit_Call(self, node):
node.func = self.parse_expr('ti.ti_int')
elif func_name == 'float':
node.func = self.parse_expr('ti.ti_float')
elif func_name == 'any':
node.func = self.parse_expr('ti.ti_any')
elif func_name == 'all':
node.func = self.parse_expr('ti.ti_all')
else:
pass
return node
Expand Down
33 changes: 33 additions & 0 deletions tests/python/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,36 @@ def fill():
assert m2[0][j, i] == int(i + 3 * j + 1)
assert m3[0][i, j] == int(i + 3 * j + 1)
assert m4[0][j, i] == int(i + 3 * j + 1)


@ti.all_archs
def test_any_all():
a = ti.Matrix(2, 2, dt=ti.i32, shape=())
b = ti.var(dt=ti.i32, shape=())

@ti.kernel
def func_any():
b[None] = any(a[None])

@ti.kernel
def func_all():
b[None] = all(a[None])

for i in range(2):
for j in range(2):
a[None][0, 0] = i
a[None][1, 0] = j
a[None][1, 1] = i
a[None][0, 1] = j

func_any()
if i == 1 or j == 1:
assert b[None] == 1
else:
assert b[None] == 0

func_all()
if i == 1 and j == 1:
assert b[None] == 1
else:
assert b[None] == 0