Skip to content

Commit

Permalink
[IR] Fix continue statement in struct for and add a related test (#3282)
Browse files Browse the repository at this point in the history
* fix continue in struct for and add a test

* Auto Format

* specify changes only for llvm backend

* Auto Format

* remove as_return in statements

* Auto Format

* Update taichi/backends/vulkan/codegen_vulkan.cpp

Co-authored-by: Ye Kuang <[email protected]>

* Update taichi/backends/metal/codegen_metal.cpp

Co-authored-by: Ye Kuang <[email protected]>

* Update taichi/codegen/codegen_llvm.cpp

Co-authored-by: Ye Kuang <[email protected]>

* Update codegen_metal.cpp

* Update codegen_vulkan.cpp

* Update codegen_llvm.cpp

Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: Chang Yu <[email protected]>
Co-authored-by: Ye Kuang <[email protected]>
  • Loading branch information
4 people authored Nov 6, 2021
1 parent 445539e commit d4af2f8
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 14 deletions.
11 changes: 10 additions & 1 deletion taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,16 @@ class KernelCodegenImpl : public IRVisitor {
}

void visit(ContinueStmt *stmt) override {
if (stmt->as_return()) {
auto stmt_in_off_for = [stmt]() {
TI_ASSERT(stmt->scope != nullptr);
if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
offl->task_type == OffloadedStmt::TaskType::struct_for);
return true;
}
return false;
};
if (stmt_in_off_for()) {
emit("return;");
} else {
emit("continue;");
Expand Down
11 changes: 10 additions & 1 deletion taichi/backends/vulkan/codegen_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,16 @@ class TaskCodegen : public IRVisitor {
}

void visit(ContinueStmt *stmt) override {
if (stmt->as_return()) {
auto stmt_in_off_for = [stmt]() {
TI_ASSERT(stmt->scope != nullptr);
if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
offl->task_type == OffloadedStmt::TaskType::struct_for);
return true;
}
return false;
};
if (stmt_in_off_for()) {
// Return means end THIS main loop and start next loop, not exit kernel
ir_->make_inst(spv::OpBranch, return_label());
} else {
Expand Down
14 changes: 13 additions & 1 deletion taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,16 @@ void CodeGenLLVM::visit(WhileControlStmt *stmt) {

void CodeGenLLVM::visit(ContinueStmt *stmt) {
using namespace llvm;
if (stmt->as_return()) {
auto stmt_in_off_range_for = [stmt]() {
TI_ASSERT(stmt->scope != nullptr);
if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
offl->task_type == OffloadedStmt::TaskType::struct_for);
return offl->task_type == OffloadedStmt::TaskType::range_for;
}
return false;
};
if (stmt_in_off_range_for()) {
builder->CreateRetVoid();
} else {
TI_ASSERT(current_loop_reentry != nullptr);
Expand Down Expand Up @@ -1743,6 +1752,9 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {
auto struct_for_body_bb =
BasicBlock::Create(*llvm_context, "struct_for_body_body", func);

auto lrg = make_loop_reentry_guard(this);
current_loop_reentry = body_tail_bb;

builder->CreateBr(loop_test_bb);

{
Expand Down
10 changes: 0 additions & 10 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@

TLANG_NAMESPACE_BEGIN

bool ContinueStmt::as_return() const {
TI_ASSERT(scope != nullptr);
if (auto *offl = scope->cast<OffloadedStmt>(); offl) {
TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
offl->task_type == OffloadedStmt::TaskType::struct_for);
return true;
}
return false;
}

UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand)
: op_type(op_type), operand(operand) {
TI_ASSERT(!operand->is<AllocaStmt>());
Expand Down
1 change: 0 additions & 1 deletion taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class ContinueStmt : public Stmt {
//
// If run_foo_kernel() is directly inlined within foo_kernel(), `return`
// could prematurely terminate the entire kernel.
bool as_return() const;

TI_STMT_DEF_FIELDS(scope);
TI_DEFINE_ACCEPT_AND_CLONE;
Expand Down
33 changes: 33 additions & 0 deletions tests/python/test_struct_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,36 @@ def count() -> int:
return tot

assert count() == 28


@ti.test(require=ti.extension.sparse)
def test_struct_for_continue():
# Related issue: https://github.com/taichi-dev/taichi/issues/3272
x = ti.field(dtype=ti.i32)
n = 4
ti.root.pointer(ti.i, n).dense(ti.i, n).place(x)

@ti.kernel
def init():
for i in range(n):
x[i * n + i] = 1

@ti.kernel
def struct_for_continue() -> ti.i32:
cnt = 0
for i in x:
if x[i]: continue
cnt += 1
return cnt

@ti.kernel
def range_for_continue() -> ti.i32:
cnt = 0
for i in range(n * n):
if x[i]: continue
cnt += 1
return cnt

init()
assert struct_for_continue() == n * (n - 1)
assert range_for_continue() == n * (n - 1)

0 comments on commit d4af2f8

Please sign in to comment.