Skip to content

Commit

Permalink
[Bug] [opt] Fix compilation crash when there is a container statement…
Browse files Browse the repository at this point in the history
… after an unconditional continue (#1299)

* [Bug] [opt] Fix compilation crash when there is a container statement after an unconditional continue

* [skip ci] enforce code format

* retrigger CI

* retrigger CI

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
xumingkuan and taichi-gardener authored Jun 22, 2020
1 parent 4c25993 commit 824c94b
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 5 deletions.
1 change: 1 addition & 0 deletions taichi/ir/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class ControlFlowGraph {
void reaching_definition_analysis(bool after_lower_access);

void simplify_graph();
// This pass cannot eliminate container statements properly for now.
bool unreachable_code_elimination();
bool store_to_load_forwarding(bool after_lower_access);
};
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ Stmt *Block::insert(VecStatement &&stmt, int location) {
stmt_ptr = stmt.back().get();
}
if (location == -1) {
location = (int)statements.size() - 1;
location = (int)statements.size();
}
for (int i = 0; i < stmt.size(); i++) {
insert(std::move(stmt[i]), location + i);
Expand Down
2 changes: 0 additions & 2 deletions taichi/transforms/cfg_optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ void cfg_optimization(IRNode *root, bool after_lower_access) {
while (true) {
bool modified = false;
cfg->simplify_graph();
if (cfg->unreachable_code_elimination())
modified = true;
if (cfg->store_to_load_forwarding(after_lower_access))
modified = true;
if (!modified)
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class IRPrinter : public IRVisitor {

void visit(ContinueStmt *stmt) override {
if (stmt->scope) {
print("{} continue (scope={})", stmt->name(), stmt->name());
print("{} continue (scope={})", stmt->name(), stmt->scope->name());
} else {
print("{} continue", stmt->name());
}
Expand Down
10 changes: 9 additions & 1 deletion taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,13 @@ class BasicBlockSimplify : public IRVisitor {
}

void visit(ContinueStmt *stmt) override {
return;
if (stmt != stmt->parent->back()) {
const int location = stmt->parent->locate(stmt);
while (location + 1 < (int)stmt->parent->size()) {
stmt->parent->erase(location + 1);
}
throw IRModified();
}
}

static bool is_global_write(Stmt *stmt) {
Expand Down Expand Up @@ -1205,6 +1211,8 @@ bool simplify(IRNode *root, Kernel *kernel) {
else
break;
}
if (modified)
fix_block_parents(root);
return modified;
}

Expand Down
64 changes: 64 additions & 0 deletions tests/python/test_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,67 @@ def run():
xs = x.to_numpy()
for i in range(n):
assert xs[i] == 0


@ti.all_archs
def test_kernel_continue_in_nested_if():
x = ti.var(ti.i32, shape=n)

@ti.kernel
def run(a: ti.i32):
for i in range(1):
if a:
if a:
continue
if a:
if a:
continue
x[i] = i

x[0] = 1
run(1)
assert x[0] == 1
run(0)
assert x[0] == 0


@ti.all_archs
def test_kernel_continue_in_nested_if_2():
x = ti.var(ti.i32, shape=n)

@ti.kernel
def run(a: ti.i32):
for i in range(1):
if a:
if a:
continue
if a:
continue
x[i] = i

x[0] = 1
run(1)
assert x[0] == 1
run(0)
assert x[0] == 0


@ti.all_archs
def test_kernel_continue_in_nested_if_3():
x = ti.var(ti.i32, shape=n)

@ti.kernel
def run(a: ti.i32):
for i in range(1):
if a:
continue
if a:
if a:
continue
x[i] = i

x[0] = 1
run(1)
assert x[0] == 1
run(0)
assert x[0] == 0

0 comments on commit 824c94b

Please sign in to comment.