Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Support continue on all backends #716

Merged
merged 7 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,7 @@ def visit_Break(self, node):
return self.parse_stmt('ti.core.insert_break_stmt()')

def visit_Continue(self, node):
raise TaichiSyntaxError(
'"continue" is not yet supported in Taichi kernels. Please walk around by changing loop conditions.'
)
return self.parse_stmt('ti.core.insert_continue_stmt()')

def visit_Call(self, node):
if not (isinstance(node.func, ast.Attribute)
Expand Down
99 changes: 82 additions & 17 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ void OffloadedTask::compile() {
func = (task_fp_type)kernel_symbol;
}

// FunctionCreationGuard

// TODO(k-ye): Hide FunctionCreationGuard inside cpp file
FunctionCreationGuard::FunctionCreationGuard(
CodeGenLLVM *mb,
std::vector<llvm::Type *> arguments)
Expand Down Expand Up @@ -75,6 +74,45 @@ FunctionCreationGuard::~FunctionCreationGuard() {
}
}

namespace {

class CodeGenStmtGuard {
public:
using Getter = std::function<llvm::BasicBlock *(void)>;
using Setter = std::function<void(llvm::BasicBlock *)>;

explicit CodeGenStmtGuard(Getter getter, Setter setter)
: saved_stmt_(getter()), setter_(std::move(setter)) {
}

~CodeGenStmtGuard() {
setter_(saved_stmt_);
}

CodeGenStmtGuard(CodeGenStmtGuard &&) = default;
CodeGenStmtGuard &operator=(CodeGenStmtGuard &&) = default;

private:
llvm::BasicBlock *saved_stmt_;
Setter setter_;
};

CodeGenStmtGuard make_loop_reentry_guard(CodeGenLLVM *cg) {
return CodeGenStmtGuard([cg]() { return cg->current_loop_reentry; },
[cg](llvm::BasicBlock *saved_stmt) {
cg->current_loop_reentry = saved_stmt;
});
}

CodeGenStmtGuard make_while_after_loop_guard(CodeGenLLVM *cg) {
return CodeGenStmtGuard([cg]() { return cg->current_while_after_loop; },
[cg](llvm::BasicBlock *saved_stmt) {
cg->current_while_after_loop = saved_stmt;
});
}

} // namespace

// CodeGenLLVM

void CodeGenLLVM::visit(Block *stmt_list) {
Expand Down Expand Up @@ -664,29 +702,44 @@ void CodeGenLLVM::visit(ConstStmt *stmt) {
void CodeGenLLVM::visit(WhileControlStmt *stmt) {
BasicBlock *after_break =
BasicBlock::Create(*llvm_context, "after_break", func);
TI_ASSERT(while_after_loop);
TI_ASSERT(current_while_after_loop);
auto cond =
builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0));
builder->CreateCondBr(cond, while_after_loop, after_break);
builder->CreateCondBr(cond, current_while_after_loop, after_break);
builder->SetInsertPoint(after_break);
}

void CodeGenLLVM::visit(ContinueStmt *stmt) {
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
if (stmt->as_return()) {
builder->CreateRetVoid();
} else {
TI_ASSERT(current_loop_reentry != nullptr);
builder->CreateBr(current_loop_reentry);
}
// Stmts after continue are useless, so we switch the insertion point to
// /dev/null. In LLVM IR, the "after_continue" label shows "No predecessors!".
BasicBlock *after_continue =
BasicBlock::Create(*llvm_context, "after_continue", func);
builder->SetInsertPoint(after_continue);
}

void CodeGenLLVM::visit(WhileStmt *stmt) {
BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func);
builder->CreateBr(body);
builder->SetInsertPoint(body);
auto lrg = make_loop_reentry_guard(this);
current_loop_reentry = body;

BasicBlock *after_loop =
BasicBlock::Create(*llvm_context, "after_while", func);
auto old_while_after_loop = while_after_loop;
while_after_loop = after_loop;
auto walg = make_while_after_loop_guard(this);
current_while_after_loop = after_loop;

stmt->body->accept(this);

builder->CreateBr(body); // jump to head

builder->SetInsertPoint(after_loop);
while_after_loop = old_while_after_loop;
}

llvm::Value *CodeGenLLVM::cast_pointer(llvm::Value *val,
Expand Down Expand Up @@ -734,9 +787,12 @@ void CodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) {
}

void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
BasicBlock *body = BasicBlock::Create(*llvm_context, "loop_body", func);
BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "block", func);
BasicBlock *test = BasicBlock::Create(*llvm_context, "test", func);
BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func);
BasicBlock *loop_inc =
BasicBlock::Create(*llvm_context, "for_loop_inc", func);
BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func);
BasicBlock *loop_test =
BasicBlock::Create(*llvm_context, "for_loop_test", func);
if (!for_stmt->reversed) {
builder->CreateStore(llvm_val[for_stmt->begin],
llvm_val[for_stmt->loop_var]);
Expand All @@ -745,11 +801,11 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)),
llvm_val[for_stmt->loop_var]);
}
builder->CreateBr(test);
builder->CreateBr(loop_test);

{
// test block
builder->SetInsertPoint(test);
builder->SetInsertPoint(loop_test);
llvm::Value *cond;
if (!for_stmt->reversed) {
cond =
Expand All @@ -766,17 +822,25 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
}

{
// body cfg
builder->SetInsertPoint(body);
{
auto lrg = make_loop_reentry_guard(this);
// The continue stmt should jump to the loop-increment block!
current_loop_reentry = loop_inc;
// body cfg
builder->SetInsertPoint(body);

for_stmt->body->accept(this);
}

for_stmt->body->accept(this);
builder->CreateBr(loop_inc);
builder->SetInsertPoint(loop_inc);

if (!for_stmt->reversed) {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(1));
} else {
create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(-1));
}
builder->CreateBr(test);
builder->CreateBr(loop_test);
}

// next cfg
Expand Down Expand Up @@ -1150,7 +1214,8 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {

std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
std::string suffix) {
while_after_loop = nullptr;
current_loop_reentry = nullptr;
current_while_after_loop = nullptr;
current_offloaded_stmt = stmt;

task_function_type =
Expand Down
7 changes: 6 additions & 1 deletion taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
llvm::Type *context_ty;
llvm::Type *physical_coordinate_ty;
llvm::Value *current_coordinates;
llvm::BasicBlock *while_after_loop;
// Mainly for supporting continue stmt
llvm::BasicBlock *current_loop_reentry;
// Mainly for supporting break stmt
llvm::BasicBlock *current_while_after_loop;
llvm::FunctionType *task_function_type;
OffloadedStmt *current_offloaded_stmt;
SNodeAttributes &snode_attr;
Expand Down Expand Up @@ -154,6 +157,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {

void visit(WhileControlStmt *stmt) override;

void visit(ContinueStmt *stmt) override;

void visit(WhileStmt *stmt) override;

void visit(RangeForStmt *for_stmt) override;
Expand Down
8 changes: 8 additions & 0 deletions taichi/codegen/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,14 @@ class KernelCodegen : public IRVisitor {
emit("if (!{}) break;", stmt->cond->raw_name());
}

void visit(ContinueStmt *stmt) override {
if (stmt->as_return()) {
emit("return;");
} else {
emit("continue;");
}
}

void visit(WhileStmt *stmt) override {
emit("while (true) {{");
stmt->body->accept(this);
Expand Down
8 changes: 8 additions & 0 deletions taichi/codegen/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,14 @@ class KernelGen : public IRVisitor {
emit("if ({} == 0) break;", stmt->cond->short_name());
}

void visit(ContinueStmt *stmt) override {
k-ye marked this conversation as resolved.
Show resolved Hide resolved
if (stmt->as_return()) {
emit("return;");
} else {
emit("continue;");
}
}

void visit(WhileStmt *stmt) override {
emit("while (true) {{");
stmt->body->accept(this);
Expand Down
4 changes: 3 additions & 1 deletion taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ PER_STATEMENT(FrontendForStmt)
PER_STATEMENT(FrontendPrintStmt)
PER_STATEMENT(FrontendWhileStmt)
PER_STATEMENT(FrontendBreakStmt)
PER_STATEMENT(FrontendContinueStmt)
PER_STATEMENT(FrontendAllocaStmt)
PER_STATEMENT(FrontendAssignStmt)
PER_STATEMENT(FrontendAtomicStmt)
PER_STATEMENT(FrontendEvalStmt)
PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
PER_STATEMENT(FrontendArgStoreStmt)
PER_STATEMENT(FrontendFuncDefStmt)

// Middle-end statement

Expand All @@ -20,8 +22,8 @@ PER_STATEMENT(StructForStmt)
PER_STATEMENT(IfStmt)
PER_STATEMENT(WhileStmt)
PER_STATEMENT(WhileControlStmt)
PER_STATEMENT(ContinueStmt)
PER_STATEMENT(FuncBodyStmt)
PER_STATEMENT(FrontendFuncDefStmt)
PER_STATEMENT(FuncCallStmt)

PER_STATEMENT(ArgLoadStmt)
Expand Down
11 changes: 11 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <unordered_map>

#include "taichi/ir/frontend.h"
#include "taichi/ir/statements.h"

TLANG_NAMESPACE_BEGIN

Expand Down Expand Up @@ -546,4 +547,14 @@ std::string OffloadedStmt::task_type_name(TaskType tt) {
return m.find(tt)->second;
}

bool ContinueStmt::as_return() const {
TI_ASSERT(scope != nullptr);
if (auto *offl = scope->cast<OffloadedStmt>(); offl) {
TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
offl->task_type == OffloadedStmt::TaskType::struct_for);
return true;
}
return false;
}

TLANG_NAMESPACE_END
28 changes: 28 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ std::vector<Stmt *> gather_statements(IRNode *root,
const std::function<bool(Stmt *)> &test);
bool same_statements(IRNode *root1, IRNode *root2);
std::unordered_set<Stmt *> detect_fors_with_break(IRNode *root);
std::unordered_set<Stmt *> detect_loops_with_continue(IRNode *root);
void compile_to_offloads(IRNode *ir,
CompileConfig config,
bool vectorize,
Expand Down Expand Up @@ -940,6 +941,22 @@ class WhileControlStmt : public Stmt {
DEFINE_ACCEPT;
};

class ContinueStmt : public Stmt {
public:
// This is the loop on which this continue stmt has effects. It can be either
// an offloaded task, or a for/while loop inside the kernel.
Stmt *scope;

ContinueStmt() : scope(nullptr) {
TI_STMT_REG_FIELDS;
}

bool as_return() const;

TI_STMT_DEF_FIELDS(scope);
DEFINE_ACCEPT;
};

class UnaryOpStmt : public Stmt {
public:
UnaryOpType op_type;
Expand Down Expand Up @@ -2067,6 +2084,17 @@ class FrontendBreakStmt : public Stmt {
DEFINE_ACCEPT
};

class FrontendContinueStmt : public Stmt {
public:
FrontendContinueStmt() = default;

bool is_container_statement() const override {
return false;
}

DEFINE_ACCEPT
};

class FrontendWhileStmt : public Stmt {
public:
Expr cond;
Expand Down
4 changes: 4 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ void export_lang(py::module &m) {
current_ast_builder().insert(Stmt::make<FrontendBreakStmt>());
});

m.def("insert_continue_stmt", [&]() {
current_ast_builder().insert(Stmt::make<FrontendContinueStmt>());
});

m.def("begin_func", [&](const std::string &funcid) {
auto stmt_unique = std::make_unique<FrontendFuncDefStmt>(funcid);
auto stmt = stmt_unique.get();
Expand Down
12 changes: 12 additions & 0 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class IRPrinter : public IRVisitor {
print("break");
}

void visit(FrontendContinueStmt *stmt) override {
print("continue");
}

void visit(FrontendAssignStmt *assign) override {
print("{} = {}", assign->lhs->serialize(), assign->rhs->serialize());
}
Expand Down Expand Up @@ -202,6 +206,14 @@ class IRPrinter : public IRVisitor {
print("while control {}, {}", stmt->mask->name(), stmt->cond->name());
}

void visit(ContinueStmt *stmt) override {
if (stmt->scope) {
print("{} continue (scope={})", stmt->name(), stmt->name());
k-ye marked this conversation as resolved.
Show resolved Hide resolved
} else {
print("{} continue", stmt->name());
}
}

void visit(FuncCallStmt *stmt) override {
print("{}{} = call \"{}\"", stmt->type_hint(), stmt->name(), stmt->funcid);
}
Expand Down
4 changes: 4 additions & 0 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ class LowerAST : public IRVisitor {
throw IRModified();
}

void visit(FrontendContinueStmt *stmt) override {
stmt->parent->replace_with(stmt, Stmt::make<ContinueStmt>());
}

void visit(FrontendWhileStmt *stmt) override {
// transform into a structure as
// while (1) { cond; if (no active) break; original body...}
Expand Down
4 changes: 4 additions & 0 deletions taichi/transforms/make_adjoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ class MakeAdjoint : public IRVisitor {
TI_NOT_IMPLEMENTED
}

void visit(ContinueStmt *stmt) override {
TI_NOT_IMPLEMENTED;
}

void visit(WhileStmt *stmt) override {
TI_NOT_IMPLEMENTED
}
Expand Down
Loading