Skip to content

Commit

Permalink
[bug] [lang] Enable break in the outermost for not in the outermost s…
Browse files Browse the repository at this point in the history
…cope (#4447)
  • Loading branch information
lin-hitonami authored Mar 4, 2022
1 parent 0fe0db2 commit 58d3417
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
12 changes: 6 additions & 6 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -909,19 +909,19 @@ void ASTBuilder::insert_expr_stmt(const Expr &val) {

void ASTBuilder::create_scope(std::unique_ptr<Block> &list, LoopType tp) {
TI_ASSERT(list == nullptr);
list = std::make_unique<Block>();
if (!stack_.empty()) {
list->parent_stmt = get_last_stmt();
}
stack_.push_back(list.get());
LoopState prev = loop_state_stack_.back();
if (tp == NotLoop) {
loop_state_stack_.push_back(prev);
} else if (tp == For && prev == None) {
} else if (tp == For && stack_.size() == 1) {
loop_state_stack_.push_back(Outermost);
} else {
loop_state_stack_.push_back(Inner);
}
list = std::make_unique<Block>();
if (!stack_.empty()) {
list->parent_stmt = get_last_stmt();
}
stack_.push_back(list.get());
}

void ASTBuilder::pop_scope() {
Expand Down
10 changes: 7 additions & 3 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class LowerAST : public IRVisitor {
Stmt *capturing_loop_;
std::unordered_set<Stmt *> detected_fors_with_break_;
Block *current_block_;
int current_block_depth_;

FlattenContext make_flatten_ctx() {
FlattenContext fctx;
Expand All @@ -43,7 +44,8 @@ class LowerAST : public IRVisitor {
public:
explicit LowerAST(const std::unordered_set<Stmt *> &_detected_fors_with_break)
: detected_fors_with_break_(_detected_fors_with_break),
current_block_(nullptr) {
current_block_(nullptr),
current_block_depth_(0) {
// TODO: change this to false
allow_undefined_visitor = true;
capturing_loop_ = nullptr;
Expand All @@ -53,9 +55,11 @@ class LowerAST : public IRVisitor {
auto backup_block = this->current_block_;
this->current_block_ = stmt_list;
auto stmts = make_raw_pointer_list(stmt_list->statements);
current_block_depth_++;
for (auto &stmt : stmts) {
stmt->accept(this);
}
current_block_depth_--;
this->current_block_ = backup_block;
}

Expand Down Expand Up @@ -201,8 +205,8 @@ class LowerAST : public IRVisitor {
flatten_rvalue(begin, &fctx);
flatten_rvalue(end, &fctx);
bool is_good_range_for =
capturing_loop_ == nullptr || detected_fors_with_break_.find(stmt) ==
detected_fors_with_break_.end();
current_block_depth_ == 1 || detected_fors_with_break_.find(stmt) ==
detected_fors_with_break_.end();
// #578: a good range for is a range for that doesn't contains a break
// statement
if (is_good_range_for) {
Expand Down
15 changes: 15 additions & 0 deletions tests/python/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,18 @@ def func():
x[None] = 1
func()
assert x[None] == 1


@test_utils.test()
def test_break_in_outermost_for_not_in_outermost_scope():
@ti.kernel
def foo() -> ti.i32:
a = 0
if True:
for i in range(1000):
if i == 100:
break
a += 1
return a

assert foo() == 100

0 comments on commit 58d3417

Please sign in to comment.