From 055d9b6ba84848fe964d0628071239b32ca0dfb5 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 5 Nov 2021 19:03:42 +0800 Subject: [PATCH 1/3] add ListComp and DictComp --- python/taichi/lang/ir_builder.py | 72 ++++++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/python/taichi/lang/ir_builder.py b/python/taichi/lang/ir_builder.py index 9e8a13b769369..b136ebd7c4c8b 100644 --- a/python/taichi/lang/ir_builder.py +++ b/python/taichi/lang/ir_builder.py @@ -29,12 +29,7 @@ def build_Assign(ctx, node): # The variable is introduced to support chained assignments. # Ref https://github.com/taichi-dev/taichi/issues/2659. for node_target in node.targets: - if isinstance(node_target, ast.Tuple): - IRBuilder.build_assign_unpack(ctx, node_target, node.value.ptr, - is_static_assign) - else: - IRBuilder.build_assign_basic(ctx, node_target, node.value.ptr, - is_static_assign) + IRBuilder.build_assign_unpack(ctx, node_target, node.value.ptr, is_static_assign) return node @staticmethod @@ -49,7 +44,9 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign): values: A node/list representing the values. is_static_assign: A boolean value indicating whether this is a static assignment """ - + if not isinstance(node_target, ast.Tuple): + IRBuilder.build_assign_basic(ctx, node_target, values, is_static_assign) + return targets = node_target.elts tmp_tuple = values if is_static_assign else ti.expr_init_list( values, len(targets)) @@ -102,6 +99,67 @@ def build_List(ctx, node): node.ptr = [elt.ptr for elt in node.elts] return node + @staticmethod + def process_listcomp(ctx, node, result): + result.append(build_stmt(ctx, node.elt).ptr) + + @staticmethod + def process_dictcomp(ctx, node, result): + key = build_stmt(ctx, node.key).ptr + value = build_stmt(ctx, node.value).ptr + result[key] = value + + @staticmethod + def process_generators(ctx, node, now_comp, func, result): + if now_comp >= len(node.generators): + return func(ctx, node, result) + comp = node.generators[now_comp] = build_stmt(ctx, node.generators[now_comp]) + for value in comp.iter.ptr: + with ctx.variable_scope_guard: + IRBuilder.build_assign_unpack(ctx, comp.target, value, True) + IRBuilder.process_ifs(ctx, node, now_comp, 0, func, result) + + @staticmethod + def process_ifs(ctx, node, now_comp, now_if, func, result): + if now_if >= len(node.generators[now_comp].ifs): + return IRBuilder.process_generators(ctx, node, now_comp + 1, func, result) + cond = node.generators[now_comp].ifs[now_if].ptr + ti.begin_frontend_if(cond) + ti.core.begin_frontend_if_true() + IRBuilder.process_ifs(ctx, node, now_comp, now_if + 1, func, result) + ti.core.pop_scope() + ti.core.begin_frontend_if_false() + ti.core.pop_scope() + + @staticmethod + def build_comprehension(ctx, node): + node.target = build_stmt(ctx, node.target) + node.iter = build_stmt(ctx, node.iter) + node.ifs = build_stmts(ctx, node.ifs) + return node + + @staticmethod + def build_ListComp(ctx, node): + result = [] + IRBuilder.process_generators(ctx, node, 0, IRBuilder.process_listcomp, result) + node.ptr = result + return node + + @staticmethod + def build_DictComp(ctx, node): + result = [] + IRBuilder.process_generators(ctx, node, 0, IRBuilder.process_dictcomp, result) + node.ptr = result + return node +''' +comp0=node.generators[0] +for tar0 in comp0.iter: + ctx.create_variable(comp0.target, tar0) + for tar1 in comp1.iter: + ctx.create_variable(comp1.target, tar1) + +''' + @staticmethod def build_Index(ctx, node): node.value = build_stmt(ctx, node.value) From 2df36a253afbd7c9545d51556a8e6b6aebe463a7 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 5 Nov 2021 19:07:56 +0800 Subject: [PATCH 2/3] fix --- python/taichi/lang/ir_builder.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/python/taichi/lang/ir_builder.py b/python/taichi/lang/ir_builder.py index b136ebd7c4c8b..9ca3cf333e829 100644 --- a/python/taichi/lang/ir_builder.py +++ b/python/taichi/lang/ir_builder.py @@ -29,7 +29,8 @@ def build_Assign(ctx, node): # The variable is introduced to support chained assignments. # Ref https://github.com/taichi-dev/taichi/issues/2659. for node_target in node.targets: - IRBuilder.build_assign_unpack(ctx, node_target, node.value.ptr, is_static_assign) + IRBuilder.build_assign_unpack(ctx, node_target, node.value.ptr, + is_static_assign) return node @staticmethod @@ -45,8 +46,8 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign): is_static_assign: A boolean value indicating whether this is a static assignment """ if not isinstance(node_target, ast.Tuple): - IRBuilder.build_assign_basic(ctx, node_target, values, is_static_assign) - return + return IRBuilder.build_assign_basic(ctx, node_target, values, + is_static_assign) targets = node_target.elts tmp_tuple = values if is_static_assign else ti.expr_init_list( values, len(targets)) @@ -113,7 +114,8 @@ def process_dictcomp(ctx, node, result): def process_generators(ctx, node, now_comp, func, result): if now_comp >= len(node.generators): return func(ctx, node, result) - comp = node.generators[now_comp] = build_stmt(ctx, node.generators[now_comp]) + comp = node.generators[now_comp] = build_stmt( + ctx, node.generators[now_comp]) for value in comp.iter.ptr: with ctx.variable_scope_guard: IRBuilder.build_assign_unpack(ctx, comp.target, value, True) @@ -122,7 +124,8 @@ def process_generators(ctx, node, now_comp, func, result): @staticmethod def process_ifs(ctx, node, now_comp, now_if, func, result): if now_if >= len(node.generators[now_comp].ifs): - return IRBuilder.process_generators(ctx, node, now_comp + 1, func, result) + return IRBuilder.process_generators(ctx, node, now_comp + 1, func, + result) cond = node.generators[now_comp].ifs[now_if].ptr ti.begin_frontend_if(cond) ti.core.begin_frontend_if_true() @@ -141,24 +144,18 @@ def build_comprehension(ctx, node): @staticmethod def build_ListComp(ctx, node): result = [] - IRBuilder.process_generators(ctx, node, 0, IRBuilder.process_listcomp, result) + IRBuilder.process_generators(ctx, node, 0, IRBuilder.process_listcomp, + result) node.ptr = result return node @staticmethod def build_DictComp(ctx, node): - result = [] - IRBuilder.process_generators(ctx, node, 0, IRBuilder.process_dictcomp, result) + result = {} + IRBuilder.process_generators(ctx, node, 0, IRBuilder.process_dictcomp, + result) node.ptr = result return node -''' -comp0=node.generators[0] -for tar0 in comp0.iter: - ctx.create_variable(comp0.target, tar0) - for tar1 in comp1.iter: - ctx.create_variable(comp1.target, tar1) - -''' @staticmethod def build_Index(ctx, node): From dea72d8c24e2772756db5ef5664add65df790c62 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 8 Nov 2021 15:52:27 +0800 Subject: [PATCH 3/3] add test --- python/taichi/lang/ir_builder.py | 23 +++++++------- tests/python/test_ast_refactor.py | 50 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/python/taichi/lang/ir_builder.py b/python/taichi/lang/ir_builder.py index 9ca3cf333e829..7de0bf32f7c1d 100644 --- a/python/taichi/lang/ir_builder.py +++ b/python/taichi/lang/ir_builder.py @@ -114,11 +114,15 @@ def process_dictcomp(ctx, node, result): def process_generators(ctx, node, now_comp, func, result): if now_comp >= len(node.generators): return func(ctx, node, result) - comp = node.generators[now_comp] = build_stmt( - ctx, node.generators[now_comp]) - for value in comp.iter.ptr: - with ctx.variable_scope_guard: - IRBuilder.build_assign_unpack(ctx, comp.target, value, True) + target = node.generators[now_comp].target = build_stmt( + ctx, node.generators[now_comp].target) + iter = node.generators[now_comp].iter = build_stmt( + ctx, node.generators[now_comp].iter) + for value in iter.ptr: + with ctx.variable_scope_guard(): + IRBuilder.build_assign_unpack(ctx, target, value, True) + node.generators[now_comp].ifs = build_stmts( + ctx, node.generators[now_comp].ifs) IRBuilder.process_ifs(ctx, node, now_comp, 0, func, result) @staticmethod @@ -127,12 +131,9 @@ def process_ifs(ctx, node, now_comp, now_if, func, result): return IRBuilder.process_generators(ctx, node, now_comp + 1, func, result) cond = node.generators[now_comp].ifs[now_if].ptr - ti.begin_frontend_if(cond) - ti.core.begin_frontend_if_true() - IRBuilder.process_ifs(ctx, node, now_comp, now_if + 1, func, result) - ti.core.pop_scope() - ti.core.begin_frontend_if_false() - ti.core.pop_scope() + if cond: + IRBuilder.process_ifs(ctx, node, now_comp, now_if + 1, func, + result) @staticmethod def build_comprehension(ctx, node): diff --git a/tests/python/test_ast_refactor.py b/tests/python/test_ast_refactor.py index 52e8a5327f73c..c80203629623c 100644 --- a/tests/python/test_ast_refactor.py +++ b/tests/python/test_ast_refactor.py @@ -415,3 +415,53 @@ def foo(x: np.template()) -> np.i32: for i in range(10): assert foo(i) == fib[i] + + +@ti.test(experimental_ast_refactor=True) +def test_listcomp(): + @ti.func + def identity(dt, n: ti.template()): + return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)] + for i in range(n)], + disable_local_tensor=1) + + @ti.kernel + def foo(n: ti.template()) -> ti.i32: + a = identity(ti.i32, n) + b = [j for i in a for j in i] + ret = 0 + for i in ti.static(range(n)): + for j in ti.static(range(n)): + ret += i * j * b[i * n + j] + return ret + + assert foo(5) == 1 + 4 + 9 + 16 + + +@ti.test(experimental_ast_refactor=True) +def test_dictcomp(): + @ti.kernel + def foo(n: ti.template()) -> ti.i32: + a = {i: i * i for i in range(n) if i % 3 if i % 2} + ret = 0 + for i in ti.static(range(n)): + if ti.static(i % 3): + if ti.static(i % 2): + ret += a[i] + return ret + + assert foo(10) == 1 * 1 + 5 * 5 + 7 * 7 + + +@ti.test(experimental_ast_refactor=True) +def test_dictcomp_fail(): + @ti.kernel + def foo(n: ti.template(), m: ti.template()) -> ti.i32: + a = {i: i * i for i in range(n) if i % 3 if i % 2} + return a[m] + + with pytest.raises(KeyError): + foo(5, 2) + + with pytest.raises(KeyError): + foo(5, 3)