Skip to content

Commit

Permalink
add while
Browse files Browse the repository at this point in the history
  • Loading branch information
lin-hitonami committed Nov 5, 2021
1 parent 97e051f commit 7e33579
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
24 changes: 24 additions & 0 deletions python/taichi/lang/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def build_Bytes(ctx, node):
node.ptr = node.s
return node

@staticmethod
def build_NameConstant(ctx, node):
node.ptr = node.value
return node

@staticmethod
def build_keyword(ctx, node):
node.value = build_stmt(ctx, node.value)
Expand Down Expand Up @@ -638,6 +643,25 @@ def build_For(ctx, node):
else: # Struct for
return IRBuilder.build_struct_for(ctx, node, is_grouped=False)

@staticmethod
def build_While(ctx, node):
if node.orelse:
raise TaichiSyntaxError(
"'else' clause for 'while' not supported in Taichi kernels")

with ctx.control_scope_guard():
ti.core.begin_frontend_while(ti.Expr(1).ptr)
while_cond = build_stmt(ctx, node.test).ptr
ti.begin_frontend_if(while_cond)
ti.core.begin_frontend_if_true()
ti.core.pop_scope()
ti.core.begin_frontend_if_false()
ti.core.insert_break_stmt()
ti.core.pop_scope()
node.body = build_stmts(ctx, node.body)
ti.core.pop_scope()
return node

@staticmethod
def build_If(ctx, node):
node.test = build_stmt(ctx, node.test)
Expand Down
40 changes: 40 additions & 0 deletions tests/python/test_ast_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,46 @@ def foo(a: ti.template()):
assert a[i, j] == 3


@ti.test(experimental_ast_refactor=True)
def test_while():
x = ti.field(ti.f32)

N = 1

ti.root.dense(ti.i, N).place(x)

@ti.kernel
def func():
i = 0
s = 0
while i < 10:
s += i
i += 1
x[0] = s

func()
assert x[0] == 45


@ti.test(experimental_ast_refactor=True)
def test_while_break():
ret = ti.field(ti.i32, shape=())

@ti.kernel
def func():
i = 0
s = 0
while True:
s += i
i += 1
if i > 10:
break
ret[None] = s

func()
assert ret[None] == 55


@ti.test(experimental_ast_refactor=True, print_preprocessed_ir=True)
def test_func():
@ti.func
Expand Down

0 comments on commit 7e33579

Please sign in to comment.