Skip to content

Commit

Permalink
[lang] Add fused foreach check (taichi-dev#6525)
Browse files Browse the repository at this point in the history
### Brief Summary

Foreach check in PR taichi-dev#6425 is verbose. This PR simplifies the usage.

Co-authored-by: Yi Xu <[email protected]>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 8a82150 commit 8252721
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 44 deletions.
17 changes: 8 additions & 9 deletions python/taichi/lang/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from taichi.lang.impl import static
from taichi.lang.kernel_impl import func, pyfunc
from taichi.lang.matrix import Matrix, Vector
from taichi.lang.matrix_ops_utils import (Or, arg_at, assert_list,
assert_tensor, assert_vector,
check_matmul, dim_lt, foreach,
from taichi.lang.matrix_ops_utils import (arg_at, arg_foreach_check,
assert_list, assert_tensor,
assert_vector, check_matmul, dim_lt,
is_int_const, preconditions,
same_shapes, square_matrix)
from taichi.lang.util import cook_dtype
Expand Down Expand Up @@ -79,13 +79,12 @@ def _rotation2d_matrix(alpha):


@preconditions(
arg_at(
arg_at(0, lambda xs: same_shapes(*xs)),
arg_foreach_check(
0,
foreach(
Or(assert_vector(),
assert_list,
msg="Cols/rows must be a list of lists, or a list of vectors")),
same_shapes))
fns=[assert_vector(), assert_list],
logic='or',
msg="Cols/rows must be a list of lists, or a list of vectors"))
def rows(rows): # pylint: disable=W0621
if isinstance(rows[0], (Matrix, Expr)):
shape = rows[0].get_shape()
Expand Down
75 changes: 40 additions & 35 deletions python/taichi/lang/matrix_ops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,56 +27,32 @@ def wrapper(*args, **kwargs):
return decorator


def arg_at(i, *fns):
def arg_at(indices, *fns):
def check(*args, **kwargs):
if i in kwargs:
arg = kwargs[i]
else:
try:
nonlocal indices
if isinstance(indices, int):
indices = [indices]
for i in indices:
if i in kwargs:
arg = kwargs[i]
else:
arg = args[i]
except IndexError:
raise
ok, msg = do_check(fns, arg)
if not ok:
return False, msg
return True, None

return check


def foreach(*fns):
def check(args):
for x in args:
ok, msg = do_check(fns, x)
ok, msg = do_check(fns, arg)
if not ok:
return False, msg
return True, None

return check


def Or(f, g, msg=None):
def check(*args, **kwargs):
ok, msg_f = do_check([f], *args, **kwargs)
if not ok:
ok, msg_g = do_check([g], *args, **kwargs)
if not ok:
return False, f'Both violated: {msg_f} {msg_g}'
return True, None

return check


def assert_tensor(m, msg='not tensor type: {}'):
if isinstance(m, Matrix):
return True, None
if isinstance(m, Expr) and m.is_tensor():
return True, None
raise TaichiCompilationError(msg.format(type(m)))
return False, msg.format(type(m))


# TODO(zhanlue): rearrange to more generic checker functions
# for example: "assert_is_instance(args, indices=[], instances=[], logic='or')"
def assert_vector(msg='expected a vector, got {}'):
def check(v):
if (isinstance(v, Expr) or isinstance(v, Matrix)) and len(
Expand All @@ -90,7 +66,36 @@ def check(v):
def assert_list(x, msg='not a list: {}'):
if isinstance(x, list):
return True, None
raise TaichiCompilationError(msg.format(type(x)))
return False, msg.format(type(x))


def arg_foreach_check(*arg_indices, fns=[], logic='or', msg=None):
def check(*args, **kwargs):
for i in arg_indices:
if i in kwargs:
arg = kwargs[i]
else:
arg = args[i]
if logic == 'or':
for a in arg:
passed = False
for fn in fns:
ok, _ = do_check([fn], a)
if ok:
passed = True
break
if not passed:
return False, msg
elif logic == 'and':
for a in arg:
ok, _ = do_check(fns, a)
if not ok:
return False, msg
else:
raise ValueError(f'Unknown logic: {logic}')
return True, None

return check


def same_shapes(*xs):
Expand Down

0 comments on commit 8252721

Please sign in to comment.