Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] [opt] Fix compilation crash when there is a container statement after an unconditional continue #1299

Merged
merged 4 commits into from
Jun 22, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions taichi/ir/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class ControlFlowGraph {
void reaching_definition_analysis(bool after_lower_access);

void simplify_graph();
// This pass cannot eliminate container statements properly for now.
bool unreachable_code_elimination();
bool store_to_load_forwarding(bool after_lower_access);
};
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ void Block::insert(std::unique_ptr<Stmt> &&stmt, int location) {

void Block::insert(VecStatement &&stmt, int location) {
if (location == -1) {
location = (int)statements.size() - 1;
location = (int)statements.size();
Copy link
Contributor Author

@xumingkuan xumingkuan Jun 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I wonder why this has been in the codebase for 5 months...)

So actually it should be

            if a:
                continue
            if a:
                if a:
                    continue

instead of

            if a:
                if a:
                    continue
            if a:
                continue

which causes ControlFlowGraph::unreachable_code_elimination to produce malformed IR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this! It's surprising that we have this issue for so long time without noticing...

}
for (int i = 0; i < stmt.size(); i++) {
insert(std::move(stmt[i]), location + i);
Expand Down
2 changes: 0 additions & 2 deletions taichi/transforms/cfg_optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ void cfg_optimization(IRNode *root, bool after_lower_access) {
while (true) {
bool modified = false;
cfg->simplify_graph();
if (cfg->unreachable_code_elimination())
modified = true;
if (cfg->store_to_load_forwarding(after_lower_access))
modified = true;
if (!modified)
Expand Down
3 changes: 2 additions & 1 deletion taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ class IRPrinter : public IRVisitor {

void visit(ContinueStmt *stmt) override {
if (stmt->scope) {
print("{} continue (scope={})", stmt->name(), stmt->name());
print("{} continue (scope={})", stmt->name(),
stmt->scope->name());
} else {
print("{} continue", stmt->name());
}
Expand Down
10 changes: 9 additions & 1 deletion taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,13 @@ class BasicBlockSimplify : public IRVisitor {
}

void visit(ContinueStmt *stmt) override {
return;
if (stmt != stmt->parent->back()) {
const int location = stmt->parent->locate(stmt);
while (location + 1 < (int)stmt->parent->size()) {
stmt->parent->erase(location + 1);
}
throw IRModified();
}
}

static bool is_global_write(Stmt *stmt) {
Expand Down Expand Up @@ -1201,6 +1207,8 @@ bool simplify(IRNode *root, Kernel *kernel) {
else
break;
}
if (modified)
fix_block_parents(root);
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
return modified;
}

Expand Down
64 changes: 64 additions & 0 deletions tests/python/test_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,67 @@ def run():
xs = x.to_numpy()
for i in range(n):
assert xs[i] == 0


@ti.all_archs
def test_kernel_continue_in_nested_if():
x = ti.var(ti.i32, shape=n)

@ti.kernel
def run(a: ti.i32):
for i in range(1):
if a:
if a:
continue
if a:
if a:
continue
x[i] = i

x[0] = 1
run(1)
assert x[0] == 1
run(0)
assert x[0] == 0


@ti.all_archs
def test_kernel_continue_in_nested_if_2():
x = ti.var(ti.i32, shape=n)

@ti.kernel
def run(a: ti.i32):
for i in range(1):
if a:
if a:
continue
if a:
continue
x[i] = i

x[0] = 1
run(1)
assert x[0] == 1
run(0)
assert x[0] == 0


@ti.all_archs
def test_kernel_continue_in_nested_if_3():
x = ti.var(ti.i32, shape=n)

@ti.kernel
def run(a: ti.i32):
for i in range(1):
if a:
continue
if a:
if a:
continue
x[i] = i

x[0] = 1
run(1)
assert x[0] == 1
run(0)
assert x[0] == 0