From 58d34173ec4b0f634ce3fb07e07319c72512efcf Mon Sep 17 00:00:00 2001 From: Lin Jiang <90667349+lin-hitonami@users.noreply.github.com> Date: Fri, 4 Mar 2022 16:45:28 +0800 Subject: [PATCH] [bug] [lang] Enable break in the outermost for not in the outermost scope (#4447) --- taichi/ir/frontend_ir.cpp | 12 ++++++------ taichi/transforms/lower_ast.cpp | 10 +++++++--- tests/python/test_loops.py | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 2e18ad7c76b05..cbbf7c089d0b2 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -909,19 +909,19 @@ void ASTBuilder::insert_expr_stmt(const Expr &val) { void ASTBuilder::create_scope(std::unique_ptr &list, LoopType tp) { TI_ASSERT(list == nullptr); - list = std::make_unique(); - if (!stack_.empty()) { - list->parent_stmt = get_last_stmt(); - } - stack_.push_back(list.get()); LoopState prev = loop_state_stack_.back(); if (tp == NotLoop) { loop_state_stack_.push_back(prev); - } else if (tp == For && prev == None) { + } else if (tp == For && stack_.size() == 1) { loop_state_stack_.push_back(Outermost); } else { loop_state_stack_.push_back(Inner); } + list = std::make_unique(); + if (!stack_.empty()) { + list->parent_stmt = get_last_stmt(); + } + stack_.push_back(list.get()); } void ASTBuilder::pop_scope() { diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 88995f468a217..81daee000c54d 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -33,6 +33,7 @@ class LowerAST : public IRVisitor { Stmt *capturing_loop_; std::unordered_set detected_fors_with_break_; Block *current_block_; + int current_block_depth_; FlattenContext make_flatten_ctx() { FlattenContext fctx; @@ -43,7 +44,8 @@ class LowerAST : public IRVisitor { public: explicit LowerAST(const std::unordered_set &_detected_fors_with_break) : detected_fors_with_break_(_detected_fors_with_break), - current_block_(nullptr) { + current_block_(nullptr), + current_block_depth_(0) { // TODO: change this to false allow_undefined_visitor = true; capturing_loop_ = nullptr; @@ -53,9 +55,11 @@ class LowerAST : public IRVisitor { auto backup_block = this->current_block_; this->current_block_ = stmt_list; auto stmts = make_raw_pointer_list(stmt_list->statements); + current_block_depth_++; for (auto &stmt : stmts) { stmt->accept(this); } + current_block_depth_--; this->current_block_ = backup_block; } @@ -201,8 +205,8 @@ class LowerAST : public IRVisitor { flatten_rvalue(begin, &fctx); flatten_rvalue(end, &fctx); bool is_good_range_for = - capturing_loop_ == nullptr || detected_fors_with_break_.find(stmt) == - detected_fors_with_break_.end(); + current_block_depth_ == 1 || detected_fors_with_break_.find(stmt) == + detected_fors_with_break_.end(); // #578: a good range for is a range for that doesn't contains a break // statement if (is_good_range_for) { diff --git a/tests/python/test_loops.py b/tests/python/test_loops.py index 588bf302d34ca..81b0456df8cd3 100644 --- a/tests/python/test_loops.py +++ b/tests/python/test_loops.py @@ -172,3 +172,18 @@ def func(): x[None] = 1 func() assert x[None] == 1 + + +@test_utils.test() +def test_break_in_outermost_for_not_in_outermost_scope(): + @ti.kernel + def foo() -> ti.i32: + a = 0 + if True: + for i in range(1000): + if i == 100: + break + a += 1 + return a + + assert foo() == 100