Skip to content

Commit

Permalink
[Lang] Support continue on all backends (#716)
Browse files Browse the repository at this point in the history
* [Lang] Support `continue` on all backends

* [skip ci] enforce code format

* move

* [skip ci] assert != nullptr

* [skip ci] print scope

* [skip ci] more asserts

* comments

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
k-ye and taichi-gardener authored Apr 7, 2020
1 parent 213533d commit d28533c
Show file tree
Hide file tree
Showing 16 changed files with 355 additions and 42 deletions.
4 changes: 1 addition & 3 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,7 @@ def visit_Break(self, node):
return self.parse_stmt('ti.core.insert_break_stmt()')

def visit_Continue(self, node):
raise TaichiSyntaxError(
'"continue" is not yet supported in Taichi kernels. Please walk around by changing loop conditions.'
)
return self.parse_stmt('ti.core.insert_continue_stmt()')

def visit_Call(self, node):
if not (isinstance(node.func, ast.Attribute)
Expand Down
99 changes: 82 additions & 17 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ void OffloadedTask::compile() {
func = (task_fp_type)kernel_symbol;
}

// FunctionCreationGuard

// TODO(k-ye): Hide FunctionCreationGuard inside cpp file
FunctionCreationGuard::FunctionCreationGuard(
CodeGenLLVM *mb,
std::vector<llvm::Type *> arguments)
Expand Down Expand Up @@ -75,6 +74,45 @@ FunctionCreationGuard::~FunctionCreationGuard() {
}
}

namespace {

class CodeGenStmtGuard {
public:
using Getter = std::function<llvm::BasicBlock *(void)>;
using Setter = std::function<void(llvm::BasicBlock *)>;

explicit CodeGenStmtGuard(Getter getter, Setter setter)
: saved_stmt_(getter()), setter_(std::move(setter)) {
}

~CodeGenStmtGuard() {
setter_(saved_stmt_);
}

CodeGenStmtGuard(CodeGenStmtGuard &&) = default;
CodeGenStmtGuard &operator=(CodeGenStmtGuard &&) = default;

private:
llvm::BasicBlock *saved_stmt_;
Setter setter_;
};

CodeGenStmtGuard make_loop_reentry_guard(CodeGenLLVM *cg) {
return CodeGenStmtGuard([cg]() { return cg->current_loop_reentry; },
[cg](llvm::BasicBlock *saved_stmt) {
cg->current_loop_reentry = saved_stmt;
});
}

CodeGenStmtGuard make_while_after_loop_guard(CodeGenLLVM *cg) {
return CodeGenStmtGuard([cg]() { return cg->current_while_after_loop; },
[cg](llvm::BasicBlock *saved_stmt) {
cg->current_while_after_loop = saved_stmt;
});
}

} // namespace

// CodeGenLLVM

void CodeGenLLVM::visit(Block *stmt_list) {
Expand Down Expand Up @@ -664,29 +702,44 @@ void CodeGenLLVM::visit(ConstStmt *stmt) {
void CodeGenLLVM::visit(WhileControlStmt *stmt) {
BasicBlock *after_break =
BasicBlock::Create(*llvm_context, "after_break", func);
TI_ASSERT(while_after_loop);
TI_ASSERT(current_while_after_loop);
auto cond =
builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0));
builder->CreateCondBr(cond, while_after_loop, after_break);
builder->CreateCondBr(cond, current_while_after_loop, after_break);
builder->SetInsertPoint(after_break);
}

void CodeGenLLVM::visit(ContinueStmt *stmt) {
if (stmt->as_return()) {
builder->CreateRetVoid();
} else {
TI_ASSERT(current_loop_reentry != nullptr);
builder->CreateBr(current_loop_reentry);
}
// Stmts after continue are useless, so we switch the insertion point to
// /dev/null. In LLVM IR, the "after_continue" label shows "No predecessors!".
BasicBlock *after_continue =
BasicBlock::Create(*llvm_context, "after_continue", func);
builder->SetInsertPoint(after_continue);
}

void CodeGenLLVM::visit(WhileStmt *stmt) {
BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func);
builder->CreateBr(body);
builder->SetInsertPoint(body);
auto lrg = make_loop_reentry_guard(this);
current_loop_reentry = body;

BasicBlock *after_loop =
BasicBlock::Create(*llvm_context, "after_while", func);
auto old_while_after_loop = while_after_loop;
while_after_loop = after_loop;
auto walg = make_while_after_loop_guard(this);
current_while_after_loop = after_loop;

stmt->body->accept(this);

builder->CreateBr(body); // jump to head

builder->SetInsertPoint(after_loop);
while_after_loop = old_while_after_loop;
}

llvm::Value *CodeGenLLVM::cast_pointer(llvm::Value *val,
Expand Down Expand Up @@ -734,9 +787,12 @@ void CodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) {
}

void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
BasicBlock *body = BasicBlock::Create(*llvm_context, "loop_body", func);
BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "block", func);
BasicBlock *test = BasicBlock::Create(*llvm_context, "test", func);
BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func);
BasicBlock *loop_inc =
BasicBlock::Create(*llvm_context, "for_loop_inc", func);
BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func);
BasicBlock *loop_test =
BasicBlock::Create(*llvm_context, "for_loop_test", func);
if (!for_stmt->reversed) {
builder->CreateStore(llvm_val[for_stmt->begin],
llvm_val[for_stmt->loop_var]);
Expand All @@ -745,11 +801,11 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)),
llvm_val[for_stmt->loop_var]);
}
builder->CreateBr(test);
builder->CreateBr(loop_test);

{
// test block
builder->SetInsertPoint(test);
builder->SetInsertPoint(loop_test);
llvm::Value *cond;
if (!for_stmt->reversed) {
cond =
Expand All @@ -766,17 +822,25 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
}

{
// body cfg
builder->SetInsertPoint(body);
{
auto lrg = make_loop_reentry_guard(this);
// The continue stmt should jump to the loop-increment block!
current_loop_reentry = loop_inc;
// body cfg
builder->SetInsertPoint(body);

for_stmt->body->accept(this);
}

for_stmt->body->accept(this);
builder->CreateBr(loop_inc);
builder->SetInsertPoint(loop_inc);

if (!for_stmt->reversed) {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(1));
} else {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(-1));
}
builder->CreateBr(test);
builder->CreateBr(loop_test);
}

// next cfg
Expand Down Expand Up @@ -1150,7 +1214,8 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {

std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
std::string suffix) {
while_after_loop = nullptr;
current_loop_reentry = nullptr;
current_while_after_loop = nullptr;
current_offloaded_stmt = stmt;

task_function_type =
Expand Down
7 changes: 6 additions & 1 deletion taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
llvm::Type *context_ty;
llvm::Type *physical_coordinate_ty;
llvm::Value *current_coordinates;
llvm::BasicBlock *while_after_loop;
// Mainly for supporting continue stmt
llvm::BasicBlock *current_loop_reentry;
// Mainly for supporting break stmt
llvm::BasicBlock *current_while_after_loop;
llvm::FunctionType *task_function_type;
OffloadedStmt *current_offloaded_stmt;
SNodeAttributes &snode_attr;
Expand Down Expand Up @@ -154,6 +157,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {

void visit(WhileControlStmt *stmt) override;

void visit(ContinueStmt *stmt) override;

void visit(WhileStmt *stmt) override;

void visit(RangeForStmt *for_stmt) override;
Expand Down
8 changes: 8 additions & 0 deletions taichi/codegen/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,14 @@ class KernelCodegen : public IRVisitor {
emit("if (!{}) break;", stmt->cond->raw_name());
}

void visit(ContinueStmt *stmt) override {
if (stmt->as_return()) {
emit("return;");
} else {
emit("continue;");
}
}

void visit(WhileStmt *stmt) override {
emit("while (true) {{");
stmt->body->accept(this);
Expand Down
8 changes: 8 additions & 0 deletions taichi/codegen/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,14 @@ class KernelGen : public IRVisitor {
emit("if ({} == 0) break;", stmt->cond->short_name());
}

void visit(ContinueStmt *stmt) override {
if (stmt->as_return()) {
emit("return;");
} else {
emit("continue;");
}
}

void visit(WhileStmt *stmt) override {
emit("while (true) {{");
stmt->body->accept(this);
Expand Down
4 changes: 3 additions & 1 deletion taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ PER_STATEMENT(FrontendForStmt)
PER_STATEMENT(FrontendPrintStmt)
PER_STATEMENT(FrontendWhileStmt)
PER_STATEMENT(FrontendBreakStmt)
PER_STATEMENT(FrontendContinueStmt)
PER_STATEMENT(FrontendAllocaStmt)
PER_STATEMENT(FrontendAssignStmt)
PER_STATEMENT(FrontendAtomicStmt)
PER_STATEMENT(FrontendEvalStmt)
PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
PER_STATEMENT(FrontendArgStoreStmt)
PER_STATEMENT(FrontendFuncDefStmt)

// Middle-end statement

Expand All @@ -20,8 +22,8 @@ PER_STATEMENT(StructForStmt)
PER_STATEMENT(IfStmt)
PER_STATEMENT(WhileStmt)
PER_STATEMENT(WhileControlStmt)
PER_STATEMENT(ContinueStmt)
PER_STATEMENT(FuncBodyStmt)
PER_STATEMENT(FrontendFuncDefStmt)
PER_STATEMENT(FuncCallStmt)

PER_STATEMENT(ArgLoadStmt)
Expand Down
11 changes: 11 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <unordered_map>

#include "taichi/ir/frontend.h"
#include "taichi/ir/statements.h"

TLANG_NAMESPACE_BEGIN

Expand Down Expand Up @@ -546,4 +547,14 @@ std::string OffloadedStmt::task_type_name(TaskType tt) {
return m.find(tt)->second;
}

bool ContinueStmt::as_return() const {
TI_ASSERT(scope != nullptr);
if (auto *offl = scope->cast<OffloadedStmt>(); offl) {
TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
offl->task_type == OffloadedStmt::TaskType::struct_for);
return true;
}
return false;
}

TLANG_NAMESPACE_END
49 changes: 49 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ std::vector<Stmt *> gather_statements(IRNode *root,
const std::function<bool(Stmt *)> &test);
bool same_statements(IRNode *root1, IRNode *root2);
std::unordered_set<Stmt *> detect_fors_with_break(IRNode *root);
std::unordered_set<Stmt *> detect_loops_with_continue(IRNode *root);
void compile_to_offloads(IRNode *ir,
CompileConfig config,
bool vectorize,
Expand Down Expand Up @@ -941,6 +942,43 @@ class WhileControlStmt : public Stmt {
DEFINE_ACCEPT;
};

class ContinueStmt : public Stmt {
public:
// This is the loop on which this continue stmt has effects. It can be either
// an offloaded task, or a for/while loop inside the kernel.
Stmt *scope;

ContinueStmt() : scope(nullptr) {
TI_STMT_REG_FIELDS;
}

// For top-level loops, since they are parallelized to multiple threads (on
// either CPU or GPU), `continue` becomes semantically equivalent to `return`.
//
// Caveat:
// We should wrap each backend's kernel body into a function (as LLVM does).
// The reason is that, each thread may handle more than one element,
// depending on the backend's implementation.
//
// For example, CUDA uses gride-stride loops, the snippet below illustrates
// the idea:
//
// __global__ foo_kernel(...) {
// for (int i = lower; i < upper; i += gridDim) {
// auto coord = compute_coords(i);
// // run_foo_kernel is produced by codegen
// run_foo_kernel(coord);
// }
// }
//
// If run_foo_kernel() is directly inlined within foo_kernel(), `return`
// could prematurely terminate the entire kernel.
bool as_return() const;

TI_STMT_DEF_FIELDS(scope);
DEFINE_ACCEPT;
};

class UnaryOpStmt : public Stmt {
public:
UnaryOpType op_type;
Expand Down Expand Up @@ -2072,6 +2110,17 @@ class FrontendBreakStmt : public Stmt {
DEFINE_ACCEPT
};

class FrontendContinueStmt : public Stmt {
public:
FrontendContinueStmt() = default;

bool is_container_statement() const override {
return false;
}

DEFINE_ACCEPT
};

class FrontendWhileStmt : public Stmt {
public:
Expr cond;
Expand Down
4 changes: 4 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ void export_lang(py::module &m) {
current_ast_builder().insert(Stmt::make<FrontendBreakStmt>());
});

m.def("insert_continue_stmt", [&]() {
current_ast_builder().insert(Stmt::make<FrontendContinueStmt>());
});

m.def("begin_func", [&](const std::string &funcid) {
auto stmt_unique = std::make_unique<FrontendFuncDefStmt>(funcid);
auto stmt = stmt_unique.get();
Expand Down
12 changes: 12 additions & 0 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class IRPrinter : public IRVisitor {
print("break");
}

void visit(FrontendContinueStmt *stmt) override {
print("continue");
}

void visit(FrontendAssignStmt *assign) override {
print("{} = {}", assign->lhs->serialize(), assign->rhs->serialize());
}
Expand Down Expand Up @@ -202,6 +206,14 @@ class IRPrinter : public IRVisitor {
print("while control {}, {}", stmt->mask->name(), stmt->cond->name());
}

void visit(ContinueStmt *stmt) override {
if (stmt->scope) {
print("{} continue (scope={})", stmt->name(), stmt->name());
} else {
print("{} continue", stmt->name());
}
}

void visit(FuncCallStmt *stmt) override {
print("{}{} = call \"{}\"", stmt->type_hint(), stmt->name(), stmt->funcid);
}
Expand Down
Loading

0 comments on commit d28533c

Please sign in to comment.