From d28533c503a1e4fd1101b549d63c6fb540618be2 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Tue, 7 Apr 2020 19:12:48 +0900 Subject: [PATCH] [Lang] Support `continue` on all backends (#716) * [Lang] Support `continue` on all backends * [skip ci] enforce code format * move * [skip ci] assert != nullptr * [skip ci] print scope * [skip ci] more asserts * comments Co-authored-by: Taichi Gardener --- python/taichi/lang/transformer.py | 4 +- taichi/codegen/codegen_llvm.cpp | 99 +++++++++++++++++++++++++----- taichi/codegen/codegen_llvm.h | 7 ++- taichi/codegen/codegen_metal.cpp | 8 +++ taichi/codegen/codegen_opengl.cpp | 8 +++ taichi/inc/statements.inc.h | 4 +- taichi/ir/ir.cpp | 11 ++++ taichi/ir/ir.h | 49 +++++++++++++++ taichi/python/export_lang.cpp | 4 ++ taichi/transforms/ir_printer.cpp | 12 ++++ taichi/transforms/lower_ast.cpp | 4 ++ taichi/transforms/make_adjoint.cpp | 4 ++ taichi/transforms/offload.cpp | 79 ++++++++++++++++++++++-- taichi/transforms/simplify.cpp | 4 ++ tests/python/test_continue.py | 84 +++++++++++++++++++++++++ tests/python/test_syntax_errors.py | 16 ----- 16 files changed, 355 insertions(+), 42 deletions(-) create mode 100644 tests/python/test_continue.py diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index b38b1773be9d0..cbab3a94e7ed7 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -534,9 +534,7 @@ def visit_Break(self, node): return self.parse_stmt('ti.core.insert_break_stmt()') def visit_Continue(self, node): - raise TaichiSyntaxError( - '"continue" is not yet supported in Taichi kernels. Please walk around by changing loop conditions.' - ) + return self.parse_stmt('ti.core.insert_continue_stmt()') def visit_Call(self, node): if not (isinstance(node.func, ast.Attribute) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 03fdfb221a6d8..34061cbab088e 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -31,8 +31,7 @@ void OffloadedTask::compile() { func = (task_fp_type)kernel_symbol; } -// FunctionCreationGuard - +// TODO(k-ye): Hide FunctionCreationGuard inside cpp file FunctionCreationGuard::FunctionCreationGuard( CodeGenLLVM *mb, std::vector arguments) @@ -75,6 +74,45 @@ FunctionCreationGuard::~FunctionCreationGuard() { } } +namespace { + +class CodeGenStmtGuard { + public: + using Getter = std::function; + using Setter = std::function; + + explicit CodeGenStmtGuard(Getter getter, Setter setter) + : saved_stmt_(getter()), setter_(std::move(setter)) { + } + + ~CodeGenStmtGuard() { + setter_(saved_stmt_); + } + + CodeGenStmtGuard(CodeGenStmtGuard &&) = default; + CodeGenStmtGuard &operator=(CodeGenStmtGuard &&) = default; + + private: + llvm::BasicBlock *saved_stmt_; + Setter setter_; +}; + +CodeGenStmtGuard make_loop_reentry_guard(CodeGenLLVM *cg) { + return CodeGenStmtGuard([cg]() { return cg->current_loop_reentry; }, + [cg](llvm::BasicBlock *saved_stmt) { + cg->current_loop_reentry = saved_stmt; + }); +} + +CodeGenStmtGuard make_while_after_loop_guard(CodeGenLLVM *cg) { + return CodeGenStmtGuard([cg]() { return cg->current_while_after_loop; }, + [cg](llvm::BasicBlock *saved_stmt) { + cg->current_while_after_loop = saved_stmt; + }); +} + +} // namespace + // CodeGenLLVM void CodeGenLLVM::visit(Block *stmt_list) { @@ -664,29 +702,44 @@ void CodeGenLLVM::visit(ConstStmt *stmt) { void CodeGenLLVM::visit(WhileControlStmt *stmt) { BasicBlock *after_break = BasicBlock::Create(*llvm_context, "after_break", func); - TI_ASSERT(while_after_loop); + TI_ASSERT(current_while_after_loop); auto cond = builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0)); - builder->CreateCondBr(cond, while_after_loop, after_break); + builder->CreateCondBr(cond, current_while_after_loop, after_break); builder->SetInsertPoint(after_break); } +void CodeGenLLVM::visit(ContinueStmt *stmt) { + if (stmt->as_return()) { + builder->CreateRetVoid(); + } else { + TI_ASSERT(current_loop_reentry != nullptr); + builder->CreateBr(current_loop_reentry); + } + // Stmts after continue are useless, so we switch the insertion point to + // /dev/null. In LLVM IR, the "after_continue" label shows "No predecessors!". + BasicBlock *after_continue = + BasicBlock::Create(*llvm_context, "after_continue", func); + builder->SetInsertPoint(after_continue); +} + void CodeGenLLVM::visit(WhileStmt *stmt) { BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func); builder->CreateBr(body); builder->SetInsertPoint(body); + auto lrg = make_loop_reentry_guard(this); + current_loop_reentry = body; BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_while", func); - auto old_while_after_loop = while_after_loop; - while_after_loop = after_loop; + auto walg = make_while_after_loop_guard(this); + current_while_after_loop = after_loop; stmt->body->accept(this); builder->CreateBr(body); // jump to head builder->SetInsertPoint(after_loop); - while_after_loop = old_while_after_loop; } llvm::Value *CodeGenLLVM::cast_pointer(llvm::Value *val, @@ -734,9 +787,12 @@ void CodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) { } void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { - BasicBlock *body = BasicBlock::Create(*llvm_context, "loop_body", func); - BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "block", func); - BasicBlock *test = BasicBlock::Create(*llvm_context, "test", func); + BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func); + BasicBlock *loop_inc = + BasicBlock::Create(*llvm_context, "for_loop_inc", func); + BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func); + BasicBlock *loop_test = + BasicBlock::Create(*llvm_context, "for_loop_test", func); if (!for_stmt->reversed) { builder->CreateStore(llvm_val[for_stmt->begin], llvm_val[for_stmt->loop_var]); @@ -745,11 +801,11 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)), llvm_val[for_stmt->loop_var]); } - builder->CreateBr(test); + builder->CreateBr(loop_test); { // test block - builder->SetInsertPoint(test); + builder->SetInsertPoint(loop_test); llvm::Value *cond; if (!for_stmt->reversed) { cond = @@ -766,17 +822,25 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { } { - // body cfg - builder->SetInsertPoint(body); + { + auto lrg = make_loop_reentry_guard(this); + // The continue stmt should jump to the loop-increment block! + current_loop_reentry = loop_inc; + // body cfg + builder->SetInsertPoint(body); + + for_stmt->body->accept(this); + } - for_stmt->body->accept(this); + builder->CreateBr(loop_inc); + builder->SetInsertPoint(loop_inc); if (!for_stmt->reversed) { create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(1)); } else { create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(-1)); } - builder->CreateBr(test); + builder->CreateBr(loop_test); } // next cfg @@ -1150,7 +1214,8 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, std::string suffix) { - while_after_loop = nullptr; + current_loop_reentry = nullptr; + current_while_after_loop = nullptr; current_offloaded_stmt = stmt; task_function_type = diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 163bf34d5d53c..df9404018628e 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -57,7 +57,10 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { llvm::Type *context_ty; llvm::Type *physical_coordinate_ty; llvm::Value *current_coordinates; - llvm::BasicBlock *while_after_loop; + // Mainly for supporting continue stmt + llvm::BasicBlock *current_loop_reentry; + // Mainly for supporting break stmt + llvm::BasicBlock *current_while_after_loop; llvm::FunctionType *task_function_type; OffloadedStmt *current_offloaded_stmt; SNodeAttributes &snode_attr; @@ -154,6 +157,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(WhileControlStmt *stmt) override; + void visit(ContinueStmt *stmt) override; + void visit(WhileStmt *stmt) override; void visit(RangeForStmt *for_stmt) override; diff --git a/taichi/codegen/codegen_metal.cpp b/taichi/codegen/codegen_metal.cpp index f5f70b2d6c23c..68d6a1d5a4a29 100644 --- a/taichi/codegen/codegen_metal.cpp +++ b/taichi/codegen/codegen_metal.cpp @@ -466,6 +466,14 @@ class KernelCodegen : public IRVisitor { emit("if (!{}) break;", stmt->cond->raw_name()); } + void visit(ContinueStmt *stmt) override { + if (stmt->as_return()) { + emit("return;"); + } else { + emit("continue;"); + } + } + void visit(WhileStmt *stmt) override { emit("while (true) {{"); stmt->body->accept(this); diff --git a/taichi/codegen/codegen_opengl.cpp b/taichi/codegen/codegen_opengl.cpp index bb7977f072db4..268938ba63148 100644 --- a/taichi/codegen/codegen_opengl.cpp +++ b/taichi/codegen/codegen_opengl.cpp @@ -681,6 +681,14 @@ class KernelGen : public IRVisitor { emit("if ({} == 0) break;", stmt->cond->short_name()); } + void visit(ContinueStmt *stmt) override { + if (stmt->as_return()) { + emit("return;"); + } else { + emit("continue;"); + } + } + void visit(WhileStmt *stmt) override { emit("while (true) {{"); stmt->body->accept(this); diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index b31bececc37d7..49e5277afd074 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -4,6 +4,7 @@ PER_STATEMENT(FrontendForStmt) PER_STATEMENT(FrontendPrintStmt) PER_STATEMENT(FrontendWhileStmt) PER_STATEMENT(FrontendBreakStmt) +PER_STATEMENT(FrontendContinueStmt) PER_STATEMENT(FrontendAllocaStmt) PER_STATEMENT(FrontendAssignStmt) PER_STATEMENT(FrontendAtomicStmt) @@ -11,6 +12,7 @@ PER_STATEMENT(FrontendEvalStmt) PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear PER_STATEMENT(FrontendAssertStmt) PER_STATEMENT(FrontendArgStoreStmt) +PER_STATEMENT(FrontendFuncDefStmt) // Middle-end statement @@ -20,8 +22,8 @@ PER_STATEMENT(StructForStmt) PER_STATEMENT(IfStmt) PER_STATEMENT(WhileStmt) PER_STATEMENT(WhileControlStmt) +PER_STATEMENT(ContinueStmt) PER_STATEMENT(FuncBodyStmt) -PER_STATEMENT(FrontendFuncDefStmt) PER_STATEMENT(FuncCallStmt) PER_STATEMENT(ArgLoadStmt) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index d8d58b1557fcb..18fe0e67c7ebd 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -7,6 +7,7 @@ #include #include "taichi/ir/frontend.h" +#include "taichi/ir/statements.h" TLANG_NAMESPACE_BEGIN @@ -546,4 +547,14 @@ std::string OffloadedStmt::task_type_name(TaskType tt) { return m.find(tt)->second; } +bool ContinueStmt::as_return() const { + TI_ASSERT(scope != nullptr); + if (auto *offl = scope->cast(); offl) { + TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for || + offl->task_type == OffloadedStmt::TaskType::struct_for); + return true; + } + return false; +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index b85d43cbb7654..39ef951c0d359 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -116,6 +116,7 @@ std::vector gather_statements(IRNode *root, const std::function &test); bool same_statements(IRNode *root1, IRNode *root2); std::unordered_set detect_fors_with_break(IRNode *root); +std::unordered_set detect_loops_with_continue(IRNode *root); void compile_to_offloads(IRNode *ir, CompileConfig config, bool vectorize, @@ -941,6 +942,43 @@ class WhileControlStmt : public Stmt { DEFINE_ACCEPT; }; +class ContinueStmt : public Stmt { + public: + // This is the loop on which this continue stmt has effects. It can be either + // an offloaded task, or a for/while loop inside the kernel. + Stmt *scope; + + ContinueStmt() : scope(nullptr) { + TI_STMT_REG_FIELDS; + } + + // For top-level loops, since they are parallelized to multiple threads (on + // either CPU or GPU), `continue` becomes semantically equivalent to `return`. + // + // Caveat: + // We should wrap each backend's kernel body into a function (as LLVM does). + // The reason is that, each thread may handle more than one element, + // depending on the backend's implementation. + // + // For example, CUDA uses gride-stride loops, the snippet below illustrates + // the idea: + // + // __global__ foo_kernel(...) { + // for (int i = lower; i < upper; i += gridDim) { + // auto coord = compute_coords(i); + // // run_foo_kernel is produced by codegen + // run_foo_kernel(coord); + // } + // } + // + // If run_foo_kernel() is directly inlined within foo_kernel(), `return` + // could prematurely terminate the entire kernel. + bool as_return() const; + + TI_STMT_DEF_FIELDS(scope); + DEFINE_ACCEPT; +}; + class UnaryOpStmt : public Stmt { public: UnaryOpType op_type; @@ -2072,6 +2110,17 @@ class FrontendBreakStmt : public Stmt { DEFINE_ACCEPT }; +class FrontendContinueStmt : public Stmt { + public: + FrontendContinueStmt() = default; + + bool is_container_statement() const override { + return false; + } + + DEFINE_ACCEPT +}; + class FrontendWhileStmt : public Stmt { public: Expr cond; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 186aa3ba0b74e..d1ffb4f882b5f 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -300,6 +300,10 @@ void export_lang(py::module &m) { current_ast_builder().insert(Stmt::make()); }); + m.def("insert_continue_stmt", [&]() { + current_ast_builder().insert(Stmt::make()); + }); + m.def("begin_func", [&](const std::string &funcid) { auto stmt_unique = std::make_unique(funcid); auto stmt = stmt_unique.get(); diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 03c69f3d9f5b1..8f94ab16a031b 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -53,6 +53,10 @@ class IRPrinter : public IRVisitor { print("break"); } + void visit(FrontendContinueStmt *stmt) override { + print("continue"); + } + void visit(FrontendAssignStmt *assign) override { print("{} = {}", assign->lhs->serialize(), assign->rhs->serialize()); } @@ -202,6 +206,14 @@ class IRPrinter : public IRVisitor { print("while control {}, {}", stmt->mask->name(), stmt->cond->name()); } + void visit(ContinueStmt *stmt) override { + if (stmt->scope) { + print("{} continue (scope={})", stmt->name(), stmt->name()); + } else { + print("{} continue", stmt->name()); + } + } + void visit(FuncCallStmt *stmt) override { print("{}{} = call \"{}\"", stmt->type_hint(), stmt->name(), stmt->funcid); } diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 4efaed77cd2a7..1fe10783ca6d7 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -113,6 +113,10 @@ class LowerAST : public IRVisitor { throw IRModified(); } + void visit(FrontendContinueStmt *stmt) override { + stmt->parent->replace_with(stmt, Stmt::make()); + } + void visit(FrontendWhileStmt *stmt) override { // transform into a structure as // while (1) { cond; if (no active) break; original body...} diff --git a/taichi/transforms/make_adjoint.cpp b/taichi/transforms/make_adjoint.cpp index 56744319dc94e..d195f810b308d 100644 --- a/taichi/transforms/make_adjoint.cpp +++ b/taichi/transforms/make_adjoint.cpp @@ -337,6 +337,10 @@ class MakeAdjoint : public IRVisitor { TI_NOT_IMPLEMENTED } + void visit(ContinueStmt *stmt) override { + TI_NOT_IMPLEMENTED; + } + void visit(WhileStmt *stmt) override { TI_NOT_IMPLEMENTED } diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 3b5ebb8ba42ad..1189632bbc03d 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -5,8 +5,8 @@ TLANG_NAMESPACE_BEGIN namespace irpass { namespace { + using StmtToOffsetMap = decltype(OffloadedResult::local_to_global_offset); -} // namespace // Break kernel into multiple parts and emit struct for listgens class Offloader { @@ -404,6 +404,74 @@ void insert_gc(IRNode *root) { } } +class AssociateContinueScope : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + using Parent = BasicStmtVisitor; + + void visit(WhileStmt *stmt) override { + auto *old_loop = cur_internal_loop_; + cur_internal_loop_ = stmt; + Parent::visit(stmt); + cur_internal_loop_ = old_loop; + } + + void visit(RangeForStmt *stmt) override { + auto *old_loop = cur_internal_loop_; + cur_internal_loop_ = stmt; + Parent::visit(stmt); + cur_internal_loop_ = old_loop; + } + + void visit(StructForStmt *stmt) override { + TI_ERROR("struct_for cannot be nested inside a kernel, stmt={}", + stmt->name()); + } + + void visit(OffloadedStmt *stmt) override { + TI_ASSERT(cur_offloaded_stmt_ == nullptr); + TI_ASSERT(cur_internal_loop_ == nullptr); + cur_offloaded_stmt_ = stmt; + Parent::visit(stmt); + cur_offloaded_stmt_ = nullptr; + } + + void visit(ContinueStmt *stmt) override { + if (stmt->scope == nullptr) { + if (cur_internal_loop_ != nullptr) { + stmt->scope = cur_internal_loop_; + } else { + stmt->scope = cur_offloaded_stmt_; + } + modified_ = true; + } + TI_ASSERT(stmt->scope != nullptr); + } + + static void run(IRNode *root) { + while (true) { + AssociateContinueScope pass; + root->accept(&pass); + if (!pass.modified_) { + break; + } + } + } + + private: + explicit AssociateContinueScope() + : modified_(false), + cur_offloaded_stmt_(nullptr), + cur_internal_loop_(nullptr) { + } + + bool modified_; + OffloadedStmt *cur_offloaded_stmt_; + Stmt *cur_internal_loop_; +}; + +} // namespace + OffloadedResult offload(IRNode *root) { OffloadedResult result; Offloader _(root); @@ -414,9 +482,12 @@ OffloadedResult offload(IRNode *root) { PromoteIntermediate::run(root, result.local_to_global_offset); PromoteLocals::run(root, result.local_to_global_offset); } - irpass::insert_gc(root); - irpass::typecheck(root); - irpass::re_id(root); + insert_gc(root); + // TODO(k-ye): Move this into its own pass. However, we need to wait for all + // backends to integrate with https://github.com/taichi-dev/taichi/pull/700 + AssociateContinueScope::run(root); + typecheck(root); + re_id(root); return result; } diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 1e363bfd472d2..8e1c68961c1f1 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -1066,6 +1066,10 @@ class BasicBlockSimplify : public IRVisitor { return; } + void visit(ContinueStmt *stmt) override { + return; + } + static bool is_global_write(Stmt *stmt) { return stmt->is() || stmt->is(); } diff --git a/tests/python/test_continue.py b/tests/python/test_continue.py new file mode 100644 index 0000000000000..ae533ed1e4e6d --- /dev/null +++ b/tests/python/test_continue.py @@ -0,0 +1,84 @@ +import taichi as ti + +n = 1000 + + +@ti.all_archs +def test_for_continue(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(): + # Launch just one thread + for _ in range(1): + for j in range(n): + if j % 2 == 0: + continue + x[j] = j + + run() + xs = x.to_numpy() + for i in range(n): + expect = 0 if i % 2 == 0 else i + assert xs[i] == expect + + +@ti.all_archs +def test_while_continue(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(): + # Launch just one thread + for _ in range(1): + j = 0 + while j < n: + oj = j + j += 1 + if oj % 2 == 0: + continue + x[oj] = oj + + run() + xs = x.to_numpy() + for i in range(n): + expect = 0 if i % 2 == 0 else i + assert xs[i] == expect + + +@ti.all_archs +def test_kernel_continue(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(): + for i in range(n): + if i % 2 == 0: + # At kernel level, this is the same as return + continue + x[i] = i + + run() + xs = x.to_numpy() + for i in range(n): + expect = 0 if i % 2 == 0 else i + assert xs[i] == expect + + +@ti.all_archs +def test_unconditional_continue(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(): + # Launch just one thread + for _ in range(1): + for j in range(n): + continue + # pylint: disable=unreachable + x[j] = j + + run() + xs = x.to_numpy() + for i in range(n): + assert xs[i] == 0 diff --git a/tests/python/test_syntax_errors.py b/tests/python/test_syntax_errors.py index 66bae69375a2e..d8f0963547970 100644 --- a/tests/python/test_syntax_errors.py +++ b/tests/python/test_syntax_errors.py @@ -70,22 +70,6 @@ def func(): func() -@ti.must_throw(ti.TaichiSyntaxError) -def test_continue(): - x = ti.var(ti.f32) - - @ti.layout - def layout(): - ti.root.dense(ti.i, 1).place(x) - - @ti.kernel - def func(): - while True: - continue - - func() - - @ti.must_throw(ti.TaichiSyntaxError) def test_loop_var_range(): x = ti.var(ti.f32)