diff --git a/taichi/backends/wasm/codegen_wasm.cpp b/taichi/backends/wasm/codegen_wasm.cpp
index f9610fdde5968..c72ec01adbff0 100644
--- a/taichi/backends/wasm/codegen_wasm.cpp
+++ b/taichi/backends/wasm/codegen_wasm.cpp
@@ -8,6 +8,7 @@
 #include "taichi/ir/ir.h"
 #include "taichi/ir/statements.h"
 #include "taichi/util/statistics.h"
+#include "taichi/util/file_sequence_writer.h"
 
 namespace taichi {
 namespace lang {
@@ -84,16 +85,9 @@ class CodeGenLLVMWASM : public CodeGenLLVM {
   }
 
   void visit(OffloadedStmt *stmt) override {
-    stat.add("codegen_offloaded_tasks");
-    TI_ASSERT(current_offload == nullptr);
+    TI_ASSERT(current_offload == nullptr)
     current_offload = stmt;
     using Type = OffloadedStmt::TaskType;
-    auto offloaded_task_name = init_offloaded_task_function(stmt);
-    if (prog->config.kernel_profiler && arch_is_cpu(prog->config.arch)) {
-      call(
-          builder.get(), "LLVMRuntime_profiler_start",
-          {get_runtime(), builder->CreateGlobalStringPtr(offloaded_task_name)});
-    }
     if (stmt->task_type == Type::serial) {
       stmt->body->accept(this);
     } else if (stmt->task_type == Type::range_for) {
@@ -101,11 +95,241 @@ class CodeGenLLVMWASM : public CodeGenLLVM {
     } else {
       TI_NOT_IMPLEMENTED
     }
-    finalize_offloaded_task_function();
-    current_task->end();
-    current_task = nullptr;
     current_offload = nullptr;
   }
+
+  /**
+   * Extracts the original function name decorated by @ti.kernel.
+   *
+   * Evaluator kernels are returned unchanged. Otherwise the name is cut right
+   * before the third underscore counted from the end; fewer than three
+   * underscores trips the assertion.
+   *
+   * @param kernel_name The format is defined in
+   * https://github.com/taichi-dev/taichi/blob/734da3f8f4439ce7f6a5337df7c54fb6dc34def8/python/taichi/lang/kernel_impl.py#L360-L362
+   * @return kernel_name without its three trailing '_'-separated fields.
+   */
+  std::string extract_original_kernel_name(const std::string &kernel_name) {
+    if (kernel->is_evaluator)
+      return kernel_name;
+    const int redundant_count = 3;
+    int pos = static_cast<int>(kernel_name.length()) - 1;
+    int underline_count = 0;
+    for (; pos >= 0; --pos) {
+      if (kernel_name.at(pos) == '_') {
+        underline_count += 1;
+        if (underline_count == redundant_count)
+          break;
+      }
+    }
+    TI_ASSERT(underline_count == redundant_count)
+    return kernel_name.substr(0, pos);
+  }
+
+  /**
+   * Starts the externally linked "<original kernel name>_body" function,
+   * which takes one pointer to context_ty and returns void, creates its
+   * "entry" and "body" blocks and points the builder at the body block.
+   *
+   * @return The name of the created LLVM function.
+   */
+  std::string init_taichi_kernel_function() {
+    task_function_type =
+        llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
+                                {llvm::PointerType::get(context_ty, 0)}, false);
+
+    auto task_kernel_name =
+        fmt::format("{}_body", extract_original_kernel_name(kernel_name));
+    func = llvm::Function::Create(task_function_type,
+                                  llvm::Function::ExternalLinkage,
+                                  task_kernel_name, module.get());
+
+    for (auto &arg : func->args()) {
+      kernel_args.push_back(&arg);
+    }
+    kernel_args[0]->setName("context");
+
+    // entry_block has all the allocas
+    this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);
+
+    // The real function body
+    func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
+    builder->SetInsertPoint(func_body_bb);
+    return task_kernel_name;
+  }
+
+  /**
+   * Finishes the function started by init_taichi_kernel_function(): emits
+   * the return, makes the entry block branch to the body, dumps the module's
+   * LLVM IR when print_kernel_llvm_ir is set, and verifies the function.
+   */
+  void finalize_taichi_kernel_function() {
+    builder->CreateRetVoid();
+
+    // entry_block should jump to the body after all allocas are inserted
+    builder->SetInsertPoint(entry_block);
+    builder->CreateBr(func_body_bb);
+
+    if (prog->config.print_kernel_llvm_ir) {
+      static FileSequenceWriter writer(
+          "taichi_kernel_generic_llvm_ir_{:04d}.ll",
+          "unoptimized LLVM IR (generic)");
+      writer.write(module.get());
+    }
+    TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
+  }
+
+  /**
+   * Emits the externally linked "get_root_address" function: given a context
+   * pointer it returns, as a 32-bit integer, the pointer produced by
+   * LLVMRuntime_get_ptr_root for that context's runtime.
+   *
+   * gen() emits this function and keeps it when pruning the module.
+   *
+   * @return The name of the created LLVM function.
+   */
+  std::string create_taichi_get_root_address_function() {
+    auto task_function_type =
+        llvm::FunctionType::get(llvm::Type::getInt32Ty(*llvm_context),
+                                {llvm::PointerType::get(context_ty, 0)}, false);
+    const std::string task_kernel_name = "get_root_address";
+    auto func = llvm::Function::Create(task_function_type,
+                                       llvm::Function::ExternalLinkage,
+                                       task_kernel_name, module.get());
+
+    std::vector<llvm::Value *> kernel_args;
+    for (auto &arg : func->args()) {
+      kernel_args.push_back(&arg);
+    }
+    kernel_args[0]->setName("context");
+
+    auto entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);
+    auto func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
+    builder->SetInsertPoint(func_body_bb);
+
+    // memory reserved for Context object shouldn't be polluted
+    llvm::Value *runtime_ptr =
+        create_call("Context_get_runtime", {kernel_args[0]});
+    llvm::Value *runtime = builder->CreateBitCast(
+        runtime_ptr,
+        llvm::PointerType::get(get_runtime_type("LLVMRuntime"), 0));
+    llvm::Value *root_ptr = create_call("LLVMRuntime_get_ptr_root", {runtime});
+    llvm::Value *root_address = builder->CreatePtrToInt(
+        root_ptr, llvm::Type::getInt32Ty(*llvm_context));
+    builder->CreateRet(root_address);
+
+    builder->SetInsertPoint(entry_block);
+    builder->CreateBr(func_body_bb);
+
+    TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
+    return task_kernel_name;
+  }
+
+  // Context's address is passed in kernel_args[0], which is supposed to be 0
+  // by default. Runtime's address will be set to kernel_args[0] after the
+  // set_root() call, so the Context and Runtime objects overlap each other.
+  //
+  // Context           Runtime          Root Buffer
+  // +-----------+     +-------------+  +-------------+
+  // |runtime*   |     |    ...      |  |    ...      |
+  // |arg0       |     |    ...      |  +-------------+
+  // |arg1       |     |root buffer* |
+  // |  ...      |     |    ...      |
+  // +-----------+     +-------------+
+  //
+  // Emits the externally linked "set_root" function taking (context, root)
+  // and returns its name. The pointer handed to LLVMRuntime_set_root is the
+  // 32-bit value loaded from the start of the Context plus `root`.
+  std::string create_taichi_set_root_function() {
+    auto task_function_type =
+        llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
+                                {llvm::PointerType::get(context_ty, 0),
+                                 llvm::Type::getInt32Ty(*llvm_context)},
+                                false);
+    const std::string task_kernel_name = "set_root";
+    auto func = llvm::Function::Create(task_function_type,
+                                       llvm::Function::ExternalLinkage,
+                                       task_kernel_name, module.get());
+
+    std::vector<llvm::Value *> kernel_args;
+    for (auto &arg : func->args()) {
+      kernel_args.push_back(&arg);
+    }
+    kernel_args[0]->setName("context");
+    kernel_args[1]->setName("root");
+
+    auto entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);
+    auto func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
+    builder->SetInsertPoint(func_body_bb);
+
+    // memory reserved for Context object shouldn't be polluted
+    // make the runtime slot of the Context hold the Context's own address
+    // (expected to be 0, see above)
+    llvm::Value *runtime_address_ptr =
+        create_call("Context_get_ptr_runtime", {kernel_args[0]});
+    llvm::Value *runtime_address_val_ptr = builder->CreatePointerCast(
+        runtime_address_ptr, llvm::Type::getInt32PtrTy(*llvm_context));
+    llvm::Value *runtime_address_val = builder->CreatePtrToInt(
+        kernel_args[0], llvm::Type::getInt32Ty(*llvm_context));
+    builder->CreateStore(runtime_address_val, runtime_address_val_ptr);
+
+    llvm::Value *runtime_ptr =
+        create_call("Context_get_runtime", {kernel_args[0]});
+    llvm::Value *runtime = builder->CreateBitCast(
+        runtime_ptr,
+        llvm::PointerType::get(get_runtime_type("LLVMRuntime"), 0));
+
+    llvm::Value *root_base_ptr = builder->CreatePointerCast(
+        kernel_args[0], llvm::Type::getInt32PtrTy(*llvm_context));
+    llvm::Value *root_base_val = builder->CreateLoad(root_base_ptr);
+    llvm::Value *root_val = builder->CreateAdd(root_base_val, kernel_args[1]);
+    llvm::Value *root_ptr = builder->CreateIntToPtr(
+        root_val, llvm::Type::getInt8PtrTy(*llvm_context));
+    create_call("LLVMRuntime_set_root", {runtime, root_ptr});
+    builder->CreateRetVoid();
+
+    builder->SetInsertPoint(entry_block);
+    builder->CreateBr(func_body_bb);
+
+    TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
+    return task_kernel_name;
+  }
+
+  /**
+   * Generates the kernel body plus the "get_root_address" and "set_root"
+   * helpers, prunes every other function from the module, and hands the
+   * module over to tlctx.
+   *
+   * @return A launcher that calls the generated kernel with a Context.
+   */
+  FunctionType gen() override {
+    TI_AUTO_PROF
+    // emit_to_module
+    stat.add("codegen_taichi_kernel_function");
+    auto offloaded_task_name = init_taichi_kernel_function();
+    ir->accept(this);
+    finalize_taichi_kernel_function();
+
+    auto get_root_address_name = create_taichi_get_root_address_function();
+    auto set_root_name = create_taichi_set_root_function();
+
+    // compile_module_to_executable
+    // only keep the functions generated above
+    TaichiLLVMContext::eliminate_unused_functions(
+        module.get(), [&](const std::string &func_name) {
+          return offloaded_task_name == func_name ||
+                 get_root_address_name == func_name ||
+                 set_root_name == func_name;
+        });
+    tlctx->add_module(std::move(module));
+    auto kernel_symbol = tlctx->lookup_function_pointer(offloaded_task_name);
+    return [=](Context &context) {
+      TI_TRACE("Launching Taichi Kernel Function");
+      // The kernel is generated with a void return type and one pointer arg.
+      auto func = reinterpret_cast<void (*)(void *)>(kernel_symbol);
+      func(&context);
+    };
+  }
 };
 
 FunctionType CodeGenWASM::codegen() {