From 73c62ff653ba4b692eec64596aaa0cc57c7b51cd Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Mon, 6 Apr 2020 19:57:00 +0900 Subject: [PATCH 1/7] [Lang] Support `continue` on all backends --- python/taichi/lang/transformer.py | 4 +- taichi/codegen/codegen_llvm.cpp | 96 ++++++++++++++++++++++++------ taichi/codegen/codegen_llvm.h | 7 ++- taichi/codegen/codegen_metal.cpp | 8 +++ taichi/codegen/codegen_opengl.cpp | 8 +++ taichi/inc/statements.inc.h | 4 +- taichi/ir/ir.h | 29 +++++++++ taichi/python/export_lang.cpp | 51 +++++++++------- taichi/transforms/ir_printer.cpp | 12 ++++ taichi/transforms/lower_ast.cpp | 4 ++ taichi/transforms/make_adjoint.cpp | 25 ++++---- taichi/transforms/offload.cpp | 77 ++++++++++++++++++++++-- taichi/transforms/simplify.cpp | 4 ++ tests/python/test_continue.py | 84 ++++++++++++++++++++++++++ tests/python/test_syntax_errors.py | 16 ----- 15 files changed, 354 insertions(+), 75 deletions(-) create mode 100644 tests/python/test_continue.py diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index b38b1773be9d0..cbab3a94e7ed7 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -534,9 +534,7 @@ def visit_Break(self, node): return self.parse_stmt('ti.core.insert_break_stmt()') def visit_Continue(self, node): - raise TaichiSyntaxError( - '"continue" is not yet supported in Taichi kernels. Please walk around by changing loop conditions.' - ) + return self.parse_stmt('ti.core.insert_continue_stmt()') def visit_Call(self, node): if not (isinstance(node.func, ast.Attribute) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 03fdfb221a6d8..e11b26e48d289 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -31,8 +31,7 @@ void OffloadedTask::compile() { func = (task_fp_type)kernel_symbol; } -// FunctionCreationGuard - +// TODO(k-ye): Hide FunctionCreationGuard inside cpp file FunctionCreationGuard::FunctionCreationGuard( CodeGenLLVM *mb, std::vector arguments) @@ -75,6 +74,42 @@ FunctionCreationGuard::~FunctionCreationGuard() { } } +namespace { + +class CodeGenStmtGuard { + public: + using Getter = std::function; + using Setter = std::function; + + explicit CodeGenStmtGuard(Getter getter, Setter setter) + : saved_stmt_(getter()), setter_(std::move(setter)) { + } + + ~CodeGenStmtGuard() { + setter_(saved_stmt_); + } + + protected: + llvm::BasicBlock *saved_stmt_; + Setter setter_; +}; + +CodeGenStmtGuard make_loop_reentry_guard(CodeGenLLVM *cg) { + return CodeGenStmtGuard([cg]() { return cg->current_loop_reentry; }, + [cg](llvm::BasicBlock *saved_stmt) { + cg->current_loop_reentry = saved_stmt; + }); +} + +CodeGenStmtGuard make_while_after_loop_guard(CodeGenLLVM *cg) { + return CodeGenStmtGuard([cg]() { return cg->current_while_after_loop; }, + [cg](llvm::BasicBlock *saved_stmt) { + cg->current_while_after_loop = saved_stmt; + }); +} + +} // namespace + // CodeGenLLVM void CodeGenLLVM::visit(Block *stmt_list) { @@ -664,29 +699,44 @@ void CodeGenLLVM::visit(ConstStmt *stmt) { void CodeGenLLVM::visit(WhileControlStmt *stmt) { BasicBlock *after_break = BasicBlock::Create(*llvm_context, "after_break", func); - TI_ASSERT(while_after_loop); + TI_ASSERT(current_while_after_loop); auto cond = builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0)); - builder->CreateCondBr(cond, while_after_loop, after_break); + builder->CreateCondBr(cond, current_while_after_loop, after_break); builder->SetInsertPoint(after_break); } +void CodeGenLLVM::visit(ContinueStmt *stmt) { + if (stmt->as_return()) { + builder->CreateRetVoid(); + } else { + TI_ASSERT(current_loop_reentry != nullptr); + builder->CreateBr(current_loop_reentry); + } + // Stmts after continue are useless, so we switch the insertion point to + // /dev/null. In LLVM IR, the "after_continue" label shows "No predecessors!". + BasicBlock *after_continue = + BasicBlock::Create(*llvm_context, "after_continue", func); + builder->SetInsertPoint(after_continue); +} + void CodeGenLLVM::visit(WhileStmt *stmt) { BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func); builder->CreateBr(body); builder->SetInsertPoint(body); + auto lrg = make_loop_reentry_guard(this); + current_loop_reentry = body; BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_while", func); - auto old_while_after_loop = while_after_loop; - while_after_loop = after_loop; + auto walg = make_while_after_loop_guard(this); + current_while_after_loop = after_loop; stmt->body->accept(this); builder->CreateBr(body); // jump to head builder->SetInsertPoint(after_loop); - while_after_loop = old_while_after_loop; } llvm::Value *CodeGenLLVM::cast_pointer(llvm::Value *val, @@ -734,9 +784,12 @@ void CodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) { } void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { - BasicBlock *body = BasicBlock::Create(*llvm_context, "loop_body", func); - BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "block", func); - BasicBlock *test = BasicBlock::Create(*llvm_context, "test", func); + BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func); + BasicBlock *loop_inc = + BasicBlock::Create(*llvm_context, "for_loop_inc", func); + BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func); + BasicBlock *loop_test = + BasicBlock::Create(*llvm_context, "for_loop_test", func); if (!for_stmt->reversed) { builder->CreateStore(llvm_val[for_stmt->begin], llvm_val[for_stmt->loop_var]); @@ -745,11 +798,11 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { builder->CreateSub(llvm_val[for_stmt->end], tlctx->get_constant(1)), llvm_val[for_stmt->loop_var]); } - builder->CreateBr(test); + builder->CreateBr(loop_test); { // test block - builder->SetInsertPoint(test); + builder->SetInsertPoint(loop_test); llvm::Value *cond; if (!for_stmt->reversed) { cond = @@ -766,17 +819,25 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { } { - // body cfg - builder->SetInsertPoint(body); + { + auto lrg = make_loop_reentry_guard(this); + // The continue stmt should jump to the loop-increment block! + current_loop_reentry = loop_inc; + // body cfg + builder->SetInsertPoint(body); + + for_stmt->body->accept(this); + } - for_stmt->body->accept(this); + builder->CreateBr(loop_inc); + builder->SetInsertPoint(loop_inc); if (!for_stmt->reversed) { create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(1)); } else { create_increment(llvm_val[for_stmt->loop_var], tlctx->get_constant(-1)); } - builder->CreateBr(test); + builder->CreateBr(loop_test); } // next cfg @@ -1150,7 +1211,8 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, std::string suffix) { - while_after_loop = nullptr; + current_loop_reentry = nullptr; + current_while_after_loop = nullptr; current_offloaded_stmt = stmt; task_function_type = diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 163bf34d5d53c..df9404018628e 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -57,7 +57,10 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { llvm::Type *context_ty; llvm::Type *physical_coordinate_ty; llvm::Value *current_coordinates; - llvm::BasicBlock *while_after_loop; + // Mainly for supporting continue stmt + llvm::BasicBlock *current_loop_reentry; + // Mainly for supporting break stmt + llvm::BasicBlock *current_while_after_loop; llvm::FunctionType *task_function_type; OffloadedStmt *current_offloaded_stmt; SNodeAttributes &snode_attr; @@ -154,6 +157,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { void visit(WhileControlStmt *stmt) override; + void visit(ContinueStmt *stmt) override; + void visit(WhileStmt *stmt) override; void visit(RangeForStmt *for_stmt) override; diff --git a/taichi/codegen/codegen_metal.cpp b/taichi/codegen/codegen_metal.cpp index f5f70b2d6c23c..68d6a1d5a4a29 100644 --- a/taichi/codegen/codegen_metal.cpp +++ b/taichi/codegen/codegen_metal.cpp @@ -466,6 +466,14 @@ class KernelCodegen : public IRVisitor { emit("if (!{}) break;", stmt->cond->raw_name()); } + void visit(ContinueStmt *stmt) override { + if (stmt->as_return()) { + emit("return;"); + } else { + emit("continue;"); + } + } + void visit(WhileStmt *stmt) override { emit("while (true) {{"); stmt->body->accept(this); diff --git a/taichi/codegen/codegen_opengl.cpp b/taichi/codegen/codegen_opengl.cpp index 50ae992951c1d..acf3f3476b127 100644 --- a/taichi/codegen/codegen_opengl.cpp +++ b/taichi/codegen/codegen_opengl.cpp @@ -671,6 +671,14 @@ class KernelGen : public IRVisitor { emit("if ({} == 0) break;", stmt->cond->short_name()); } + void visit(ContinueStmt *stmt) override { + if (stmt->as_return()) { + emit("return;"); + } else { + emit("continue;"); + } + } + void visit(WhileStmt *stmt) override { emit("while (true) {{"); stmt->body->accept(this); diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index b31bececc37d7..49e5277afd074 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -4,6 +4,7 @@ PER_STATEMENT(FrontendForStmt) PER_STATEMENT(FrontendPrintStmt) PER_STATEMENT(FrontendWhileStmt) PER_STATEMENT(FrontendBreakStmt) +PER_STATEMENT(FrontendContinueStmt) PER_STATEMENT(FrontendAllocaStmt) PER_STATEMENT(FrontendAssignStmt) PER_STATEMENT(FrontendAtomicStmt) @@ -11,6 +12,7 @@ PER_STATEMENT(FrontendEvalStmt) PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear PER_STATEMENT(FrontendAssertStmt) PER_STATEMENT(FrontendArgStoreStmt) +PER_STATEMENT(FrontendFuncDefStmt) // Middle-end statement @@ -20,8 +22,8 @@ PER_STATEMENT(StructForStmt) PER_STATEMENT(IfStmt) PER_STATEMENT(WhileStmt) PER_STATEMENT(WhileControlStmt) +PER_STATEMENT(ContinueStmt) PER_STATEMENT(FuncBodyStmt) -PER_STATEMENT(FrontendFuncDefStmt) PER_STATEMENT(FuncCallStmt) PER_STATEMENT(ArgLoadStmt) diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index c8884b75174f9..d7281062467c0 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -117,6 +117,7 @@ std::vector gather_statements(IRNode *root, const std::function &test); bool same_statements(IRNode *root1, IRNode *root2); std::unordered_set detect_fors_with_break(IRNode *root); +std::unordered_set detect_loops_with_continue(IRNode *root); void compile_to_offloads(IRNode *ir, CompileConfig config, bool vectorize, @@ -940,6 +941,23 @@ class WhileControlStmt : public Stmt { DEFINE_ACCEPT; }; +class ContinueStmt : public Stmt { + public: + // This is the loop on which this continue stmt has effects. It can be either + // an offloaded task, or a for/while loop inside the kernel. + Stmt *scope; + + ContinueStmt() : scope(nullptr) { + TI_STMT_REG_FIELDS; + } + + bool as_return() const { + return scope->is(); + } + TI_STMT_DEF_FIELDS(scope); + DEFINE_ACCEPT; +}; + class UnaryOpStmt : public Stmt { public: UnaryOpType op_type; @@ -2067,6 +2085,17 @@ class FrontendBreakStmt : public Stmt { DEFINE_ACCEPT }; +class FrontendContinueStmt : public Stmt { + public: + FrontendContinueStmt() = default; + + bool is_container_statement() const override { + return false; + } + + DEFINE_ACCEPT +}; + class FrontendWhileStmt : public Stmt { public: Expr cond; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 186aa3ba0b74e..df3c99628d2e1 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -106,9 +106,10 @@ void export_lang(py::module &m) { m.def("reset_default_compile_config", [&]() { default_compile_config = CompileConfig(); }); - m.def("default_compile_config", - [&]() -> CompileConfig & { return default_compile_config; }, - py::return_value_policy::reference); + m.def( + "default_compile_config", + [&]() -> CompileConfig & { return default_compile_config; }, + py::return_value_policy::reference); py::class_(m, "Program") .def(py::init<>()) @@ -125,11 +126,12 @@ void export_lang(py::module &m) { return (void *)(program->get_profiler()); }) .def("finalize", &Program::finalize) - .def("get_root", - [&](Program *program) -> SNode * { - return program->snode_root.get(); - }, - py::return_value_policy::reference) + .def( + "get_root", + [&](Program *program) -> SNode * { + return program->snode_root.get(); + }, + py::return_value_policy::reference) .def("get_total_compilation_time", &Program::get_total_compilation_time) .def("print_snode_tree", &Program::print_snode_tree) .def("synchronize", &Program::synchronize); @@ -137,9 +139,10 @@ void export_lang(py::module &m) { m.def("get_current_program", get_current_program, py::return_value_policy::reference); - m.def("current_compile_config", - [&]() -> CompileConfig & { return get_current_program().config; }, - py::return_value_policy::reference); + m.def( + "current_compile_config", + [&]() -> CompileConfig & { return get_current_program().config; }, + py::return_value_policy::reference); py::class_(m, "Index").def(py::init()); py::class_(m, "SNode") @@ -169,9 +172,10 @@ void export_lang(py::module &m) { .def("data_type", [](SNode *snode) { return snode->dt; }) .def("get_num_ch", [](SNode *snode) -> int { return (int)snode->ch.size(); }) - .def("get_ch", - [](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); }, - py::return_value_policy::reference) + .def( + "get_ch", + [](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); }, + py::return_value_policy::reference) .def("lazy_grad", &SNode::lazy_grad) .def("read_int", &SNode::read_int) .def("read_uint", &SNode::read_uint) @@ -219,13 +223,14 @@ void export_lang(py::module &m) { py::class_(m, "Stmt"); py::class_(m, "KernelProxy") - .def("define", - [](Program::KernelProxy *ker, - const std::function &func) -> Kernel & { - py::gil_scoped_release release; - return ker->def(func); - }, - py::return_value_policy::reference); + .def( + "define", + [](Program::KernelProxy *ker, + const std::function &func) -> Kernel & { + py::gil_scoped_release release; + return ker->def(func); + }, + py::return_value_policy::reference); m.def("insert_deactivate", [](SNode *snode, const ExprGroup &indices) { return Deactivate(snode, indices); @@ -300,6 +305,10 @@ void export_lang(py::module &m) { current_ast_builder().insert(Stmt::make()); }); + m.def("insert_continue_stmt", [&]() { + current_ast_builder().insert(Stmt::make()); + }); + m.def("begin_func", [&](const std::string &funcid) { auto stmt_unique = std::make_unique(funcid); auto stmt = stmt_unique.get(); diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 03c69f3d9f5b1..a3e03fcd31421 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -53,6 +53,10 @@ class IRPrinter : public IRVisitor { print("break"); } + void visit(FrontendContinueStmt *stmt) override { + print("continue"); + } + void visit(FrontendAssignStmt *assign) override { print("{} = {}", assign->lhs->serialize(), assign->rhs->serialize()); } @@ -202,6 +206,14 @@ class IRPrinter : public IRVisitor { print("while control {}, {}", stmt->mask->name(), stmt->cond->name()); } + void visit(ContinueStmt *stmt) override { + if (stmt->scope) { + print("{} continue (as_return={})", stmt->name(), stmt->as_return()); + } else { + print("{} continue", stmt->name()); + } + } + void visit(FuncCallStmt *stmt) override { print("{}{} = call \"{}\"", stmt->type_hint(), stmt->name(), stmt->funcid); } diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 4efaed77cd2a7..1fe10783ca6d7 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -113,6 +113,10 @@ class LowerAST : public IRVisitor { throw IRModified(); } + void visit(FrontendContinueStmt *stmt) override { + stmt->parent->replace_with(stmt, Stmt::make()); + } + void visit(FrontendWhileStmt *stmt) override { // transform into a structure as // while (1) { cond; if (no active) break; original body...} diff --git a/taichi/transforms/make_adjoint.cpp b/taichi/transforms/make_adjoint.cpp index 56744319dc94e..982933b54ba25 100644 --- a/taichi/transforms/make_adjoint.cpp +++ b/taichi/transforms/make_adjoint.cpp @@ -12,18 +12,15 @@ class ConvertLocalVar : public BasicStmtVisitor { void visit(AllocaStmt *alloc) override { TI_ASSERT(alloc->width() == 1); - bool load_only = irpass::gather_statements( - alloc->parent, - [&](Stmt *s) { - if (auto store = s->cast()) - return store->ptr == alloc; - else if (auto atomic = s->cast()) { - return atomic->dest == alloc; - } else { - return false; - } - }) - .empty(); + bool load_only = irpass::gather_statements(alloc->parent, [&](Stmt *s) { + if (auto store = s->cast()) + return store->ptr == alloc; + else if (auto atomic = s->cast()) { + return atomic->dest == alloc; + } else { + return false; + } + }).empty(); if (!load_only) { alloc->replace_with( Stmt::make(alloc->ret_type.data_type, 16)); @@ -337,6 +334,10 @@ class MakeAdjoint : public IRVisitor { TI_NOT_IMPLEMENTED } + void visit(ContinueStmt *stmt) override { + TI_NOT_IMPLEMENTED; + } + void visit(WhileStmt *stmt) override { TI_NOT_IMPLEMENTED } diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 3b5ebb8ba42ad..5a77cc7a48d70 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -5,8 +5,8 @@ TLANG_NAMESPACE_BEGIN namespace irpass { namespace { + using StmtToOffsetMap = decltype(OffloadedResult::local_to_global_offset); -} // namespace // Break kernel into multiple parts and emit struct for listgens class Offloader { @@ -404,6 +404,74 @@ void insert_gc(IRNode *root) { } } +class AssociateContinueScope : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + using Parent = BasicStmtVisitor; + + void visit(WhileStmt *stmt) override { + auto *old_loop = cur_internal_loop_; + cur_internal_loop_ = stmt; + Parent::visit(stmt); + cur_internal_loop_ = old_loop; + } + + void visit(RangeForStmt *stmt) override { + auto *old_loop = cur_internal_loop_; + cur_internal_loop_ = stmt; + Parent::visit(stmt); + cur_internal_loop_ = old_loop; + } + + void visit(StructForStmt *stmt) override { + TI_ERROR("struct_for cannot be nested inside a kernel, stmt={}", + stmt->name()); + } + + void visit(OffloadedStmt *stmt) override { + TI_ASSERT(cur_offloaded_stmt_ == nullptr); + TI_ASSERT(cur_internal_loop_ == nullptr); + cur_offloaded_stmt_ = stmt; + Parent::visit(stmt); + cur_offloaded_stmt_ = nullptr; + } + + void visit(ContinueStmt *stmt) override { + if (stmt->scope == nullptr) { + if (cur_internal_loop_ != nullptr) { + stmt->scope = cur_internal_loop_; + } else { + stmt->scope = cur_offloaded_stmt_; + } + modified_ = true; + } + TI_ASSERT(stmt->scope != nullptr); + } + + static void run(IRNode *root) { + while (true) { + AssociateContinueScope pass; + root->accept(&pass); + if (!pass.modified_) { + break; + } + } + } + + private: + explicit AssociateContinueScope() + : modified_(false), + cur_offloaded_stmt_(nullptr), + cur_internal_loop_(nullptr) { + } + + bool modified_; + OffloadedStmt *cur_offloaded_stmt_; + Stmt *cur_internal_loop_; +}; + +} // namespace + OffloadedResult offload(IRNode *root) { OffloadedResult result; Offloader _(root); @@ -414,9 +482,10 @@ OffloadedResult offload(IRNode *root) { PromoteIntermediate::run(root, result.local_to_global_offset); PromoteLocals::run(root, result.local_to_global_offset); } - irpass::insert_gc(root); - irpass::typecheck(root); - irpass::re_id(root); + insert_gc(root); + AssociateContinueScope::run(root); + typecheck(root); + re_id(root); return result; } diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 1e363bfd472d2..8e1c68961c1f1 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -1066,6 +1066,10 @@ class BasicBlockSimplify : public IRVisitor { return; } + void visit(ContinueStmt *stmt) override { + return; + } + static bool is_global_write(Stmt *stmt) { return stmt->is() || stmt->is(); } diff --git a/tests/python/test_continue.py b/tests/python/test_continue.py new file mode 100644 index 0000000000000..ae533ed1e4e6d --- /dev/null +++ b/tests/python/test_continue.py @@ -0,0 +1,84 @@ +import taichi as ti + +n = 1000 + + +@ti.all_archs +def test_for_continue(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(): + # Launch just one thread + for _ in range(1): + for j in range(n): + if j % 2 == 0: + continue + x[j] = j + + run() + xs = x.to_numpy() + for i in range(n): + expect = 0 if i % 2 == 0 else i + assert xs[i] == expect + + +@ti.all_archs +def test_while_continue(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(): + # Launch just one thread + for _ in range(1): + j = 0 + while j < n: + oj = j + j += 1 + if oj % 2 == 0: + continue + x[oj] = oj + + run() + xs = x.to_numpy() + for i in range(n): + expect = 0 if i % 2 == 0 else i + assert xs[i] == expect + + +@ti.all_archs +def test_kernel_continue(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(): + for i in range(n): + if i % 2 == 0: + # At kernel level, this is the same as return + continue + x[i] = i + + run() + xs = x.to_numpy() + for i in range(n): + expect = 0 if i % 2 == 0 else i + assert xs[i] == expect + + +@ti.all_archs +def test_unconditional_continue(): + x = ti.var(ti.i32, shape=n) + + @ti.kernel + def run(): + # Launch just one thread + for _ in range(1): + for j in range(n): + continue + # pylint: disable=unreachable + x[j] = j + + run() + xs = x.to_numpy() + for i in range(n): + assert xs[i] == 0 diff --git a/tests/python/test_syntax_errors.py b/tests/python/test_syntax_errors.py index 66bae69375a2e..d8f0963547970 100644 --- a/tests/python/test_syntax_errors.py +++ b/tests/python/test_syntax_errors.py @@ -70,22 +70,6 @@ def func(): func() -@ti.must_throw(ti.TaichiSyntaxError) -def test_continue(): - x = ti.var(ti.f32) - - @ti.layout - def layout(): - ti.root.dense(ti.i, 1).place(x) - - @ti.kernel - def func(): - while True: - continue - - func() - - @ti.must_throw(ti.TaichiSyntaxError) def test_loop_var_range(): x = ti.var(ti.f32) From 6aad0080650a9548c6d41249c32a2506771ba7a3 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Mon, 6 Apr 2020 09:22:38 -0400 Subject: [PATCH 2/7] [skip ci] enforce code format --- taichi/python/export_lang.cpp | 47 +++++++++++++----------------- taichi/transforms/make_adjoint.cpp | 21 +++++++------ 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index df3c99628d2e1..d1ffb4f882b5f 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -106,10 +106,9 @@ void export_lang(py::module &m) { m.def("reset_default_compile_config", [&]() { default_compile_config = CompileConfig(); }); - m.def( - "default_compile_config", - [&]() -> CompileConfig & { return default_compile_config; }, - py::return_value_policy::reference); + m.def("default_compile_config", + [&]() -> CompileConfig & { return default_compile_config; }, + py::return_value_policy::reference); py::class_(m, "Program") .def(py::init<>()) @@ -126,12 +125,11 @@ void export_lang(py::module &m) { return (void *)(program->get_profiler()); }) .def("finalize", &Program::finalize) - .def( - "get_root", - [&](Program *program) -> SNode * { - return program->snode_root.get(); - }, - py::return_value_policy::reference) + .def("get_root", + [&](Program *program) -> SNode * { + return program->snode_root.get(); + }, + py::return_value_policy::reference) .def("get_total_compilation_time", &Program::get_total_compilation_time) .def("print_snode_tree", &Program::print_snode_tree) .def("synchronize", &Program::synchronize); @@ -139,10 +137,9 @@ void export_lang(py::module &m) { m.def("get_current_program", get_current_program, py::return_value_policy::reference); - m.def( - "current_compile_config", - [&]() -> CompileConfig & { return get_current_program().config; }, - py::return_value_policy::reference); + m.def("current_compile_config", + [&]() -> CompileConfig & { return get_current_program().config; }, + py::return_value_policy::reference); py::class_(m, "Index").def(py::init()); py::class_(m, "SNode") @@ -172,10 +169,9 @@ void export_lang(py::module &m) { .def("data_type", [](SNode *snode) { return snode->dt; }) .def("get_num_ch", [](SNode *snode) -> int { return (int)snode->ch.size(); }) - .def( - "get_ch", - [](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); }, - py::return_value_policy::reference) + .def("get_ch", + [](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); }, + py::return_value_policy::reference) .def("lazy_grad", &SNode::lazy_grad) .def("read_int", &SNode::read_int) .def("read_uint", &SNode::read_uint) @@ -223,14 +219,13 @@ void export_lang(py::module &m) { py::class_(m, "Stmt"); py::class_(m, "KernelProxy") - .def( - "define", - [](Program::KernelProxy *ker, - const std::function &func) -> Kernel & { - py::gil_scoped_release release; - return ker->def(func); - }, - py::return_value_policy::reference); + .def("define", + [](Program::KernelProxy *ker, + const std::function &func) -> Kernel & { + py::gil_scoped_release release; + return ker->def(func); + }, + py::return_value_policy::reference); m.def("insert_deactivate", [](SNode *snode, const ExprGroup &indices) { return Deactivate(snode, indices); diff --git a/taichi/transforms/make_adjoint.cpp b/taichi/transforms/make_adjoint.cpp index 982933b54ba25..d195f810b308d 100644 --- a/taichi/transforms/make_adjoint.cpp +++ b/taichi/transforms/make_adjoint.cpp @@ -12,15 +12,18 @@ class ConvertLocalVar : public BasicStmtVisitor { void visit(AllocaStmt *alloc) override { TI_ASSERT(alloc->width() == 1); - bool load_only = irpass::gather_statements(alloc->parent, [&](Stmt *s) { - if (auto store = s->cast()) - return store->ptr == alloc; - else if (auto atomic = s->cast()) { - return atomic->dest == alloc; - } else { - return false; - } - }).empty(); + bool load_only = irpass::gather_statements( + alloc->parent, + [&](Stmt *s) { + if (auto store = s->cast()) + return store->ptr == alloc; + else if (auto atomic = s->cast()) { + return atomic->dest == alloc; + } else { + return false; + } + }) + .empty(); if (!load_only) { alloc->replace_with( Stmt::make(alloc->ret_type.data_type, 16)); From 717375fc326be3bcff543ebdea079045692caa29 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Mon, 6 Apr 2020 22:29:29 +0900 Subject: [PATCH 3/7] move --- taichi/codegen/codegen_llvm.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index e11b26e48d289..34061cbab088e 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -89,7 +89,10 @@ class CodeGenStmtGuard { setter_(saved_stmt_); } - protected: + CodeGenStmtGuard(CodeGenStmtGuard &&) = default; + CodeGenStmtGuard &operator=(CodeGenStmtGuard &&) = default; + + private: llvm::BasicBlock *saved_stmt_; Setter setter_; }; From 34018ad02df6bd234fc1b40e7c76d3ec171d8237 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Mon, 6 Apr 2020 23:31:40 +0900 Subject: [PATCH 4/7] [skip ci] assert != nullptr --- taichi/ir/ir.h | 1 + 1 file changed, 1 insertion(+) diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index d7281062467c0..22c673f8475af 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -952,6 +952,7 @@ class ContinueStmt : public Stmt { } bool as_return() const { + TI_ASSERT(scope != nullptr); return scope->is(); } TI_STMT_DEF_FIELDS(scope); From 28933411861a8f75e2604cab566974349bf06628 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Mon, 6 Apr 2020 23:35:50 +0900 Subject: [PATCH 5/7] [skip ci] print scope --- taichi/transforms/ir_printer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index a3e03fcd31421..8f94ab16a031b 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -208,7 +208,7 @@ class IRPrinter : public IRVisitor { void visit(ContinueStmt *stmt) override { if (stmt->scope) { - print("{} continue (as_return={})", stmt->name(), stmt->as_return()); + print("{} continue (scope={})", stmt->name(), stmt->name()); } else { print("{} continue", stmt->name()); } From 0a07dd2f6f83abe95b1606b8bc7870a44f991fea Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Mon, 6 Apr 2020 23:44:08 +0900 Subject: [PATCH 6/7] [skip ci] more asserts --- taichi/ir/ir.cpp | 11 +++++++++++ taichi/ir/ir.h | 6 ++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index e6519c8ba5fb2..849cb6b0d837b 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -7,6 +7,7 @@ #include #include "taichi/ir/frontend.h" +#include "taichi/ir/statements.h" TLANG_NAMESPACE_BEGIN @@ -546,4 +547,14 @@ std::string OffloadedStmt::task_type_name(TaskType tt) { return m.find(tt)->second; } +bool ContinueStmt::as_return() const { + TI_ASSERT(scope != nullptr); + if (auto *offl = scope->cast(); offl) { + TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for || + offl->task_type == OffloadedStmt::TaskType::struct_for); + return true; + } + return false; +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 22c673f8475af..8aa10ec2bb8ba 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -951,10 +951,8 @@ class ContinueStmt : public Stmt { TI_STMT_REG_FIELDS; } - bool as_return() const { - TI_ASSERT(scope != nullptr); - return scope->is(); - } + bool as_return() const; + TI_STMT_DEF_FIELDS(scope); DEFINE_ACCEPT; }; From ba50b3354ae4dc1b879d0b014820558c5f83e53b Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Tue, 7 Apr 2020 18:45:02 +0900 Subject: [PATCH 7/7] comments --- taichi/ir/ir.h | 21 +++++++++++++++++++++ taichi/transforms/offload.cpp | 2 ++ 2 files changed, 23 insertions(+) diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 8aa10ec2bb8ba..8e8de394d45ae 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -951,6 +951,27 @@ class ContinueStmt : public Stmt { TI_STMT_REG_FIELDS; } + // For top-level loops, since they are parallelized to multiple threads (on + // either CPU or GPU), `continue` becomes semantically equivalent to `return`. + // + // Caveat: + // We should wrap each backend's kernel body into a function (as LLVM does). + // The reason is that, each thread may handle more than one element, + // depending on the backend's implementation. + // + // For example, CUDA uses gride-stride loops, the snippet below illustrates + // the idea: + // + // __global__ foo_kernel(...) { + // for (int i = lower; i < upper; i += gridDim) { + // auto coord = compute_coords(i); + // // run_foo_kernel is produced by codegen + // run_foo_kernel(coord); + // } + // } + // + // If run_foo_kernel() is directly inlined within foo_kernel(), `return` + // could prematurely terminate the entire kernel. bool as_return() const; TI_STMT_DEF_FIELDS(scope); diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 5a77cc7a48d70..1189632bbc03d 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -483,6 +483,8 @@ OffloadedResult offload(IRNode *root) { PromoteLocals::run(root, result.local_to_global_offset); } insert_gc(root); + // TODO(k-ye): Move this into its own pass. However, we need to wait for all + // backends to integrate with https://github.com/taichi-dev/taichi/pull/700 AssociateContinueScope::run(root); typecheck(root); re_id(root);