diff --git a/python/taichi/lang/ir_builder.py b/python/taichi/lang/ir_builder.py index b1ea38b227a0b..fae8d1427edba 100644 --- a/python/taichi/lang/ir_builder.py +++ b/python/taichi/lang/ir_builder.py @@ -30,12 +30,8 @@ def build_Assign(ctx, node): # The variable is introduced to support chained assignments. # Ref https://github.com/taichi-dev/taichi/issues/2659. for node_target in node.targets: - if isinstance(node_target, ast.Tuple): - IRBuilder.build_assign_unpack(ctx, node_target, node.value.ptr, - is_static_assign) - else: - IRBuilder.build_assign_basic(ctx, node_target, node.value.ptr, - is_static_assign) + IRBuilder.build_assign_unpack(ctx, node_target, node.value.ptr, + is_static_assign) return node @staticmethod @@ -50,7 +46,9 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign): values: A node/list representing the values. is_static_assign: A boolean value indicating whether this is a static assignment """ - + if not isinstance(node_target, ast.Tuple): + return IRBuilder.build_assign_basic(ctx, node_target, values, + is_static_assign) targets = node_target.elts tmp_tuple = values if is_static_assign else ti.expr_init_list( values, len(targets)) @@ -111,6 +109,64 @@ def build_List(ctx, node): node.ptr = [elt.ptr for elt in node.elts] return node + @staticmethod + def process_listcomp(ctx, node, result): + result.append(build_stmt(ctx, node.elt).ptr) + + @staticmethod + def process_dictcomp(ctx, node, result): + key = build_stmt(ctx, node.key).ptr + value = build_stmt(ctx, node.value).ptr + result[key] = value + + @staticmethod + def process_generators(ctx, node, now_comp, func, result): + if now_comp >= len(node.generators): + return func(ctx, node, result) + target = node.generators[now_comp].target = build_stmt( + ctx, node.generators[now_comp].target) + iter = node.generators[now_comp].iter = build_stmt( + ctx, node.generators[now_comp].iter) + for value in iter.ptr: + with ctx.variable_scope_guard(): + IRBuilder.build_assign_unpack(ctx, target, value, True) + node.generators[now_comp].ifs = build_stmts( + ctx, node.generators[now_comp].ifs) + IRBuilder.process_ifs(ctx, node, now_comp, 0, func, result) + + @staticmethod + def process_ifs(ctx, node, now_comp, now_if, func, result): + if now_if >= len(node.generators[now_comp].ifs): + return IRBuilder.process_generators(ctx, node, now_comp + 1, func, + result) + cond = node.generators[now_comp].ifs[now_if].ptr + if cond: + IRBuilder.process_ifs(ctx, node, now_comp, now_if + 1, func, + result) + + @staticmethod + def build_comprehension(ctx, node): + node.target = build_stmt(ctx, node.target) + node.iter = build_stmt(ctx, node.iter) + node.ifs = build_stmts(ctx, node.ifs) + return node + + @staticmethod + def build_ListComp(ctx, node): + result = [] + IRBuilder.process_generators(ctx, node, 0, IRBuilder.process_listcomp, + result) + node.ptr = result + return node + + @staticmethod + def build_DictComp(ctx, node): + result = {} + IRBuilder.process_generators(ctx, node, 0, IRBuilder.process_dictcomp, + result) + node.ptr = result + return node + @staticmethod def build_Index(ctx, node): node.value = build_stmt(ctx, node.value) diff --git a/tests/python/test_ast_refactor.py b/tests/python/test_ast_refactor.py index d3e71338b2ca0..1cd2839432535 100644 --- a/tests/python/test_ast_refactor.py +++ b/tests/python/test_ast_refactor.py @@ -754,6 +754,56 @@ def foo(x: tc.template()) -> tc.i32: assert foo(i) == fib[i] +@ti.test(experimental_ast_refactor=True) +def test_listcomp(): + @ti.func + def identity(dt, n: ti.template()): + return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)] + for i in range(n)], + disable_local_tensor=1) + + @ti.kernel + def foo(n: ti.template()) -> ti.i32: + a = identity(ti.i32, n) + b = [j for i in a for j in i] + ret = 0 + for i in ti.static(range(n)): + for j in ti.static(range(n)): + ret += i * j * b[i * n + j] + return ret + + assert foo(5) == 1 + 4 + 9 + 16 + + +@ti.test(experimental_ast_refactor=True) +def test_dictcomp(): + @ti.kernel + def foo(n: ti.template()) -> ti.i32: + a = {i: i * i for i in range(n) if i % 3 if i % 2} + ret = 0 + for i in ti.static(range(n)): + if ti.static(i % 3): + if ti.static(i % 2): + ret += a[i] + return ret + + assert foo(10) == 1 * 1 + 5 * 5 + 7 * 7 + + +@ti.test(experimental_ast_refactor=True) +def test_dictcomp_fail(): + @ti.kernel + def foo(n: ti.template(), m: ti.template()) -> ti.i32: + a = {i: i * i for i in range(n) if i % 3 if i % 2} + return a[m] + + with pytest.raises(KeyError): + foo(5, 2) + + with pytest.raises(KeyError): + foo(5, 3) + + @pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') @ti.test(exclude=ti.opengl, experimental_ast_refactor=True) def test_ndarray():