diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 011b3227c38eb..5bceab2ee4637 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -112,13 +112,10 @@ def subscript(value, *indices): def chain_compare(comparators, ops): assert len(comparators) == len(ops) + 1, \ f'Chain comparison invoked with {len(comparators)} comparators but {len(ops)} operators' - evaluated_comparators = [] - for i in range(len(comparators)): - evaluated_comparators += [expr_init(comparators[i])] - ret = expr_init(True) + ret = True for i in range(len(ops)): - lhs = evaluated_comparators[i] - rhs = evaluated_comparators[i + 1] + lhs = comparators[i] + rhs = comparators[i + 1] if ops[i] == 'Lt': now = lhs < rhs elif ops[i] == 'LtE': @@ -133,7 +130,7 @@ def chain_compare(comparators, ops): now = lhs != rhs else: assert False, f'Unknown operator {ops[i]}' - ret = ret.logical_and(now) + ret = logical_and(ret, now) return ret diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 20ea0cd594fbe..552e029bcf6d4 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -179,9 +179,9 @@ Expr load(const Expr &ptr) { Expr ptr_if_global(const Expr &var) { if (var.is()) { // singleton global variable - TI_ASSERT_INFO( - var.snode()->num_active_indices == 0, - "Please always use 'x[None]' (instead of simply 'x') to access any 0-D tensor." + TI_ASSERT_INFO(var.snode()->num_active_indices == 0, + "Please always use 'x[None]' (instead of simply 'x') to " + "access any 0-D tensor."); return var[ExprGroup()]; } else { // may be any local or global expr diff --git a/tests/python/test_compare.py b/tests/python/test_compare.py index 61726291e5b5b..616d1fdea8395 100644 --- a/tests/python/test_compare.py +++ b/tests/python/test_compare.py @@ -97,6 +97,27 @@ def func(): assert a[2] # ti.append returns 0 +@ti.all_archs +def test_no_duplicate_eval_func(): + a = ti.var(ti.i32, ()) + b = ti.var(ti.i32, ()) + + @ti.func + def why_this_foo_fail(n): + return ti.atomic_add(b[None], n) + + def foo(n): + return ti.atomic_add(ti.subscript(b, None), n) + + @ti.kernel + def func(): + a[None] = 0 <= foo(2) < 1 + + func() + assert a[None] == 1 + assert b[None] == 2 + + @ti.require(ti.extension.sparse) @ti.all_archs def test_chain_compare():