From 9665e88918f1d9bb3098e3ab6188451d6305f264 Mon Sep 17 00:00:00 2001 From: Mike He Date: Fri, 11 Nov 2022 23:10:59 -0500 Subject: [PATCH] [lang] Add fused foreach check (#6525) ### Brief Summary Foreach check in PR #6425 is verbose. This PR simplifies the usage. Co-authored-by: Yi Xu --- python/taichi/lang/matrix_ops.py | 17 +++--- python/taichi/lang/matrix_ops_utils.py | 75 ++++++++++++++------------ 2 files changed, 48 insertions(+), 44 deletions(-) diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 5ccfbfe4495ad..0062a7cf7e09e 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -5,9 +5,9 @@ from taichi.lang.impl import static from taichi.lang.kernel_impl import func, pyfunc from taichi.lang.matrix import Matrix, Vector -from taichi.lang.matrix_ops_utils import (Or, arg_at, assert_list, - assert_tensor, assert_vector, - check_matmul, dim_lt, foreach, +from taichi.lang.matrix_ops_utils import (arg_at, arg_foreach_check, + assert_list, assert_tensor, + assert_vector, check_matmul, dim_lt, is_int_const, preconditions, same_shapes, square_matrix) from taichi.lang.util import cook_dtype @@ -79,13 +79,12 @@ def _rotation2d_matrix(alpha): @preconditions( - arg_at( + arg_at(0, lambda xs: same_shapes(*xs)), + arg_foreach_check( 0, - foreach( - Or(assert_vector(), - assert_list, - msg="Cols/rows must be a list of lists, or a list of vectors")), - same_shapes)) + fns=[assert_vector(), assert_list], + logic='or', + msg="Cols/rows must be a list of lists, or a list of vectors")) def rows(rows): # pylint: disable=W0621 if isinstance(rows[0], (Matrix, Expr)): shape = rows[0].get_shape() diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index 4fa5d2269c86f..9a8c7458bff45 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ b/python/taichi/lang/matrix_ops_utils.py @@ -27,27 +27,17 @@ def wrapper(*args, **kwargs): return decorator -def arg_at(i, *fns): +def arg_at(indices, *fns): def check(*args, **kwargs): - if i in kwargs: - arg = kwargs[i] - else: - try: + nonlocal indices + if isinstance(indices, int): + indices = [indices] + for i in indices: + if i in kwargs: + arg = kwargs[i] + else: arg = args[i] - except IndexError: - raise - ok, msg = do_check(fns, arg) - if not ok: - return False, msg - return True, None - - return check - - -def foreach(*fns): - def check(args): - for x in args: - ok, msg = do_check(fns, x) + ok, msg = do_check(fns, arg) if not ok: return False, msg return True, None @@ -55,28 +45,14 @@ def check(args): return check -def Or(f, g, msg=None): - def check(*args, **kwargs): - ok, msg_f = do_check([f], *args, **kwargs) - if not ok: - ok, msg_g = do_check([g], *args, **kwargs) - if not ok: - return False, f'Both violated: {msg_f} {msg_g}' - return True, None - - return check - - def assert_tensor(m, msg='not tensor type: {}'): if isinstance(m, Matrix): return True, None if isinstance(m, Expr) and m.is_tensor(): return True, None - raise TaichiCompilationError(msg.format(type(m))) + return False, msg.format(type(m)) -# TODO(zhanlue): rearrange to more generic checker functions -# for example: "assert_is_instance(args, indices=[], instances=[], logic='or')" def assert_vector(msg='expected a vector, got {}'): def check(v): if (isinstance(v, Expr) or isinstance(v, Matrix)) and len( @@ -90,7 +66,36 @@ def check(v): def assert_list(x, msg='not a list: {}'): if isinstance(x, list): return True, None - raise TaichiCompilationError(msg.format(type(x))) + return False, msg.format(type(x)) + + +def arg_foreach_check(*arg_indices, fns=[], logic='or', msg=None): + def check(*args, **kwargs): + for i in arg_indices: + if i in kwargs: + arg = kwargs[i] + else: + arg = args[i] + if logic == 'or': + for a in arg: + passed = False + for fn in fns: + ok, _ = do_check([fn], a) + if ok: + passed = True + break + if not passed: + return False, msg + elif logic == 'and': + for a in arg: + ok, _ = do_check(fns, a) + if not ok: + return False, msg + else: + raise ValueError(f'Unknown logic: {logic}') + return True, None + + return check def same_shapes(*xs):