From 824c94b3416510a01655afd91fe4f9c7355084f8 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 22 Jun 2020 17:22:53 -0400 Subject: [PATCH] [Bug] [opt] Fix compilation crash when there is a container statement after an unconditional continue (#1299) * [Bug] [opt] Fix compilation crash when there is a container statement after an unconditional continue * [skip ci] enforce code format * retrigger CI * retrigger CI Co-authored-by: Taichi Gardener --- taichi/ir/control_flow_graph.h | 1 + taichi/ir/ir.cpp | 2 +- taichi/transforms/cfg_optimization.cpp | 2 - taichi/transforms/ir_printer.cpp | 2 +- taichi/transforms/simplify.cpp | 10 +++- tests/python/test_continue.py | 64 ++++++++++++++++++++++++++ 6 files changed, 76 insertions(+), 5 deletions(-) diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 8b7969c7dd299..1c88bbaa928e4 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -65,6 +65,7 @@ class ControlFlowGraph { void reaching_definition_analysis(bool after_lower_access); void simplify_graph(); + // This pass cannot eliminate container statements properly for now. bool unreachable_code_elimination(); bool store_to_load_forwarding(bool after_lower_access); }; diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 18bdc6e531a0a..097e466adf44a 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -671,7 +671,7 @@ Stmt *Block::insert(VecStatement &&stmt, int location) { stmt_ptr = stmt.back().get(); } if (location == -1) { - location = (int)statements.size() - 1; + location = (int)statements.size(); } for (int i = 0; i < stmt.size(); i++) { insert(std::move(stmt[i]), location + i); diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 9943ec860836a..42ad39fb4b77d 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -12,8 +12,6 @@ void cfg_optimization(IRNode *root, bool after_lower_access) { while (true) { bool modified = false; cfg->simplify_graph(); - if (cfg->unreachable_code_elimination()) - modified = true; if (cfg->store_to_load_forwarding(after_lower_access)) modified = true; if (!modified) diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index d98c2ff1cc9c7..793e9c6ced020 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -250,7 +250,7 @@ class IRPrinter : public IRVisitor { void visit(ContinueStmt *stmt) override { if (stmt->scope) { - print("{} continue (scope={})", stmt->name(), stmt->name()); + print("{} continue (scope={})", stmt->name(), stmt->scope->name()); } else { print("{} continue", stmt->name()); } diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 079db6664e15b..c596dd7148ddd 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -958,7 +958,13 @@ class BasicBlockSimplify : public IRVisitor { } void visit(ContinueStmt *stmt) override { - return; + if (stmt != stmt->parent->back()) { + const int location = stmt->parent->locate(stmt); + while (location + 1 < (int)stmt->parent->size()) { + stmt->parent->erase(location + 1); + } + throw IRModified(); + } } static bool is_global_write(Stmt *stmt) { @@ -1205,6 +1211,8 @@ bool simplify(IRNode *root, Kernel *kernel) { else break; } + if (modified) + fix_block_parents(root); return modified; } diff --git a/tests/python/test_continue.py b/tests/python/test_continue.py index ae533ed1e4e6d..6f82035e8e912 100644 --- a/tests/python/test_continue.py +++ b/tests/python/test_continue.py @@ -82,3 +82,67 @@ def run(): xs = x.to_numpy() for i in range(n): assert xs[i] == 0 + + +@ti.all_archs +def test_kernel_continue_in_nested_if(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(a: ti.i32): + for i in range(1): + if a: + if a: + continue + if a: + if a: + continue + x[i] = i + + x[0] = 1 + run(1) + assert x[0] == 1 + run(0) + assert x[0] == 0 + + +@ti.all_archs +def test_kernel_continue_in_nested_if_2(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(a: ti.i32): + for i in range(1): + if a: + if a: + continue + if a: + continue + x[i] = i + + x[0] = 1 + run(1) + assert x[0] == 1 + run(0) + assert x[0] == 0 + + +@ti.all_archs +def test_kernel_continue_in_nested_if_3(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(a: ti.i32): + for i in range(1): + if a: + continue + if a: + if a: + continue + x[i] = i + + x[0] = 1 + run(1) + assert x[0] == 1 + run(0) + assert x[0] == 0