diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp
index 7e8824be2203f..20d8f16c4731b 100644
--- a/taichi/backends/metal/codegen_metal.cpp
+++ b/taichi/backends/metal/codegen_metal.cpp
@@ -680,7 +680,16 @@ class KernelCodegenImpl : public IRVisitor {
   }
 
   void visit(ContinueStmt *stmt) override {
-    if (stmt->as_return()) {
+    auto stmt_in_off_for = [stmt]() {
+      TI_ASSERT(stmt->scope != nullptr);
+      if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
+        TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
+                  offl->task_type == OffloadedStmt::TaskType::struct_for);
+        return true;
+      }
+      return false;
+    };
+    if (stmt_in_off_for()) {
       emit("return;");
     } else {
       emit("continue;");
diff --git a/taichi/backends/vulkan/codegen_vulkan.cpp b/taichi/backends/vulkan/codegen_vulkan.cpp
index 7dca7ddb13407..b5ca87985c82f 100644
--- a/taichi/backends/vulkan/codegen_vulkan.cpp
+++ b/taichi/backends/vulkan/codegen_vulkan.cpp
@@ -1013,7 +1013,16 @@ class TaskCodegen : public IRVisitor {
   }
 
   void visit(ContinueStmt *stmt) override {
-    if (stmt->as_return()) {
+    auto stmt_in_off_for = [stmt]() {
+      TI_ASSERT(stmt->scope != nullptr);
+      if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
+        TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
+                  offl->task_type == OffloadedStmt::TaskType::struct_for);
+        return true;
+      }
+      return false;
+    };
+    if (stmt_in_off_for()) {
       // Return means end THIS main loop and start next loop, not exit kernel
       ir_->make_inst(spv::OpBranch, return_label());
     } else {
diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp
index 68446a8ba418c..e7b42f30159dd 100644
--- a/taichi/codegen/codegen_llvm.cpp
+++ b/taichi/codegen/codegen_llvm.cpp
@@ -875,7 +875,16 @@ void CodeGenLLVM::visit(WhileControlStmt *stmt) {
 
 void CodeGenLLVM::visit(ContinueStmt *stmt) {
   using namespace llvm;
-  if (stmt->as_return()) {
+  auto stmt_in_off_range_for = [stmt]() {
+    TI_ASSERT(stmt->scope != nullptr);
+    if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
+      TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
+                offl->task_type == OffloadedStmt::TaskType::struct_for);
+      return offl->task_type == OffloadedStmt::TaskType::range_for;
+    }
+    return false;
+  };
+  if (stmt_in_off_range_for()) {
     builder->CreateRetVoid();
   } else {
     TI_ASSERT(current_loop_reentry != nullptr);
@@ -1743,6 +1752,9 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {
   auto struct_for_body_bb =
       BasicBlock::Create(*llvm_context, "struct_for_body_body", func);
 
+  auto lrg = make_loop_reentry_guard(this);
+  current_loop_reentry = body_tail_bb;
+
   builder->CreateBr(loop_test_bb);
 
   {
diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp
index 924f05cea93d4..a1b86cd2763b8 100644
--- a/taichi/ir/statements.cpp
+++ b/taichi/ir/statements.cpp
@@ -4,16 +4,6 @@
 
 TLANG_NAMESPACE_BEGIN
 
-bool ContinueStmt::as_return() const {
-  TI_ASSERT(scope != nullptr);
-  if (auto *offl = scope->cast<OffloadedStmt>(); offl) {
-    TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
-              offl->task_type == OffloadedStmt::TaskType::struct_for);
-    return true;
-  }
-  return false;
-}
-
 UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand)
     : op_type(op_type), operand(operand) {
   TI_ASSERT(!operand->is<AllocaStmt>());
diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h
index 8dd9dc1dab445..c043582a4d379 100644
--- a/taichi/ir/statements.h
+++ b/taichi/ir/statements.h
@@ -91,7 +91,6 @@ class ContinueStmt : public Stmt {
   //
   // If run_foo_kernel() is directly inlined within foo_kernel(), `return`
   // could prematurely terminate the entire kernel.
-  bool as_return() const;
 
   TI_STMT_DEF_FIELDS(scope);
   TI_DEFINE_ACCEPT_AND_CLONE;
diff --git a/tests/python/test_struct_for.py b/tests/python/test_struct_for.py
index 60825eff1b37c..e2327d920140a 100644
--- a/tests/python/test_struct_for.py
+++ b/tests/python/test_struct_for.py
@@ -279,3 +279,36 @@ def count() -> int:
         return tot
 
     assert count() == 28
+
+
+@ti.test(require=ti.extension.sparse)
+def test_struct_for_continue():
+    # Related issue: https://github.com/taichi-dev/taichi/issues/3272
+    x = ti.field(dtype=ti.i32)
+    n = 4
+    ti.root.pointer(ti.i, n).dense(ti.i, n).place(x)
+
+    @ti.kernel
+    def init():
+        for i in range(n):
+            x[i * n + i] = 1
+
+    @ti.kernel
+    def struct_for_continue() -> ti.i32:
+        cnt = 0
+        for i in x:
+            if x[i]: continue
+            cnt += 1
+        return cnt
+
+    @ti.kernel
+    def range_for_continue() -> ti.i32:
+        cnt = 0
+        for i in range(n * n):
+            if x[i]: continue
+            cnt += 1
+        return cnt
+
+    init()
+    assert struct_for_continue() == n * (n - 1)
+    assert range_for_continue() == n * (n - 1)