Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WASM] Add set_root() for each WASM module #2429

Merged
merged 13 commits into from
Jun 24, 2021
Merged
209 changes: 198 additions & 11 deletions taichi/backends/wasm/codegen_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/util/statistics.h"
#include "taichi/util/file_sequence_writer.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -84,28 +85,214 @@ class CodeGenLLVMWASM : public CodeGenLLVM {
}

/**
 * Generates code for one offloaded task.
 *
 * Only `serial` and `range_for` task types are handled here; any other task
 * type reaches TI_NOT_IMPLEMENTED.
 *
 * @param stmt The offloaded statement to lower.
 */
void visit(OffloadedStmt *stmt) override {
  stat.add("codegen_offloaded_tasks");
  // An offloaded task must not be visited while another one is in flight.
  TI_ASSERT(current_offload == nullptr);
  current_offload = stmt;
  using Type = OffloadedStmt::TaskType;
  auto offloaded_task_name = init_offloaded_task_function(stmt);
  // Emit a profiler-start call only when kernel profiling is enabled and the
  // configured arch is a CPU arch.
  if (prog->config.kernel_profiler && arch_is_cpu(prog->config.arch)) {
    call(
        builder.get(), "LLVMRuntime_profiler_start",
        {get_runtime(), builder->CreateGlobalStringPtr(offloaded_task_name)});
  }
  if (stmt->task_type == Type::serial) {
    stmt->body->accept(this);
  } else if (stmt->task_type == Type::range_for) {
    create_offload_range_for(stmt);
  } else {
    TI_NOT_IMPLEMENTED
  }
  finalize_offloaded_task_function();
  // Close the current task and clear the per-task state so that the next
  // offloaded statement starts from a clean slate.
  current_task->end();
  current_task = nullptr;
  current_offload = nullptr;
}

/**
 * Recovers the name of the function decorated by @ti.kernel from a generated
 * kernel name.
 *
 * Evaluator kernels are returned unchanged. For every other kernel the result
 * is the part of `kernel_name` that precedes the third underscore counted
 * from the end of the string; fewer than three underscores trips the
 * assertion.
 *
 * @param kernel_name The format is defined in
 * https://github.com/taichi-dev/taichi/blob/734da3f8f4439ce7f6a5337df7c54fb6dc34def8/python/taichi/lang/kernel_impl.py#L360-L362
 * @return The kernel name with the trailing generated components removed.
 */
std::string extract_original_kernel_name(const std::string &kernel_name) {
  if (kernel->is_evaluator) {
    return kernel_name;
  }

  const int redundant_count = 3;
  int underline_count = 0;
  // `cut` is always one past the region that still has to be searched; after
  // a hit it holds the index of the underscore that was just found.
  std::size_t cut = kernel_name.size();
  while (underline_count < redundant_count && cut > 0) {
    cut = kernel_name.rfind('_', cut - 1);
    if (cut == std::string::npos) {
      break;
    }
    ++underline_count;
  }
  TI_ASSERT(underline_count == redundant_count);
  return kernel_name.substr(0, cut);
}

std::string init_taichi_kernel_function() {
task_function_type =
llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
{llvm::PointerType::get(context_ty, 0)}, false);

auto task_kernel_name =
fmt::format("{}_body", extract_original_kernel_name(kernel_name));
func = llvm::Function::Create(task_function_type,
llvm::Function::ExternalLinkage,
task_kernel_name, module.get());

for (auto &arg : func->args()) {
kernel_args.push_back(&arg);
}
kernel_args[0]->setName("context");

// entry_block has all the allocas
this->entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);

// The real function body
func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
builder->SetInsertPoint(func_body_bb);
return task_kernel_name;
}

/**
 * Closes the kernel function opened by init_taichi_kernel_function().
 *
 * Emits the final `ret void` at the builder's current position, makes the
 * entry block branch to the body block, optionally dumps the unoptimized
 * module, and asserts that the function passes LLVM verification.
 */
void finalize_taichi_kernel_function() {
  builder->CreateRetVoid();

  // All allocas have been inserted by now, so the entry block can finally
  // jump to the body.
  builder->SetInsertPoint(entry_block);
  builder->CreateBr(func_body_bb);

  if (prog->config.print_kernel_llvm_ir) {
    // One writer for the whole process, so dumps get consecutive numbers.
    static FileSequenceWriter writer(
        "taichi_kernel_generic_llvm_ir_{:04d}.ll",
        "unoptimized LLVM IR (generic)");
    auto *unoptimized_module = module.get();
    writer.write(unoptimized_module);
  }

  TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
}

// This is unused
std::string create_taichi_get_root_address_function() {
squarefk marked this conversation as resolved.
Show resolved Hide resolved
auto task_function_type =
llvm::FunctionType::get(llvm::Type::getInt32Ty(*llvm_context),
{llvm::PointerType::get(context_ty, 0)}, false);
auto task_kernel_name = fmt::format("get_root_address");
auto func = llvm::Function::Create(task_function_type,
llvm::Function::ExternalLinkage,
task_kernel_name, module.get());

std::vector<llvm::Value *> kernel_args;
for (auto &arg : func->args()) {
kernel_args.push_back(&arg);
}
kernel_args[0]->setName("context");

auto entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);
auto func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
builder->SetInsertPoint(func_body_bb);

// memory reserved for Context object shouldn't be polluted
llvm::Value *runtime_ptr =
create_call("Context_get_runtime", {kernel_args[0]});
llvm::Value *runtime = builder->CreateBitCast(
runtime_ptr,
llvm::PointerType::get(get_runtime_type("LLVMRuntime"), 0));
llvm::Value *root_ptr = create_call("LLVMRuntime_get_ptr_root", {runtime});
llvm::Value *root_address = builder->CreatePtrToInt(
root_ptr, llvm::Type::getInt32Ty(*llvm_context));
builder->CreateRet(root_address);

builder->SetInsertPoint(entry_block);
builder->CreateBr(func_body_bb);

TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
return task_kernel_name;
}

// Context's address is pass by kernel_args[0] which is supposed to be 0 in
// default. Runtime's address will be set to kernel_args[0] after set_root()
// call. The objects of Context and Runtime are overlapped with each other.
//
// Context Runtime Root Buffer
// +-----------+ +-------------+ +-------------+
// |runtime* | | ... | | ... |
// |arg0 | | ... | +-------------+
// |arg1 | |root buffer* |
// | ... | | ... |
// +-----------+ +-------------+
std::string create_taichi_set_root_function() {
squarefk marked this conversation as resolved.
Show resolved Hide resolved
auto task_function_type =
llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context),
{llvm::PointerType::get(context_ty, 0),
llvm::Type::getInt32Ty(*llvm_context)},
false);
const std::string task_kernel_name = "set_root";
auto func = llvm::Function::Create(task_function_type,
llvm::Function::ExternalLinkage,
task_kernel_name, module.get());

std::vector<llvm::Value *> kernel_args;
for (auto &arg : func->args()) {
kernel_args.push_back(&arg);
}
kernel_args[0]->setName("context");
kernel_args[1]->setName("root");

auto entry_block = llvm::BasicBlock::Create(*llvm_context, "entry", func);
auto func_body_bb = llvm::BasicBlock::Create(*llvm_context, "body", func);
builder->SetInsertPoint(func_body_bb);

// memory reserved for Context object shouldn't be polluted
// set runtime address to zero
llvm::Value *runtime_address_ptr =
create_call("Context_get_ptr_runtime", {kernel_args[0]});
llvm::Value *runtime_address_val_ptr = builder->CreatePointerCast(
runtime_address_ptr, llvm::Type::getInt32PtrTy(*llvm_context));
llvm::Value *runtime_address_val = builder->CreatePtrToInt(
kernel_args[0], llvm::Type::getInt32Ty(*llvm_context));
builder->CreateStore(runtime_address_val, runtime_address_val_ptr);

llvm::Value *runtime_ptr =
create_call("Context_get_runtime", {kernel_args[0]});
llvm::Value *runtime = builder->CreateBitCast(
runtime_ptr,
llvm::PointerType::get(get_runtime_type("LLVMRuntime"), 0));

llvm::Value *root_base_ptr = builder->CreatePointerCast(
kernel_args[0], llvm::Type::getInt32PtrTy(*llvm_context));
llvm::Value *root_base_val = builder->CreateLoad(root_base_ptr);
llvm::Value *root_val = builder->CreateAdd(root_base_val, kernel_args[1]);
llvm::Value *root_ptr = builder->CreateIntToPtr(
root_val, llvm::Type::getInt8PtrTy(*llvm_context));
llvm::Value *ret_ptr =
create_call("LLVMRuntime_set_root", {runtime, root_ptr});
builder->CreateRetVoid();

builder->SetInsertPoint(entry_block);
builder->CreateBr(func_body_bb);

TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
return task_kernel_name;
}

/**
 * Generates the whole kernel as one LLVM function plus the two root helpers,
 * strips every other function from the module, registers the module and
 * returns a launcher for the kernel.
 *
 * @return A callable that invokes the compiled kernel body with the address
 *         of the given Context.
 */
FunctionType gen() override {
  TI_AUTO_PROF
  // emit_to_module
  stat.add("codegen_taichi_kernel_function");
  auto offloaded_task_name = init_taichi_kernel_function();
  ir->accept(this);
  finalize_taichi_kernel_function();

  auto get_root_address_name = create_taichi_get_root_address_function();
  auto set_root_name = create_taichi_set_root_function();

  // compile_module_to_executable
  // only keep the kernel body and the two root helpers
  TaichiLLVMContext::eliminate_unused_functions(
      module.get(), [&](const std::string &func_name) {
        return offloaded_task_name == func_name ||
               get_root_address_name == func_name ||
               set_root_name == func_name;
      });
  tlctx->add_module(std::move(module));
  auto kernel_symbol = tlctx->lookup_function_pointer(offloaded_task_name);
  return [=](Context &context) {
    TI_TRACE("Launching Taichi Kernel Function");
    // The kernel body is created with a void return type in
    // init_taichi_kernel_function(), so it must be called through a matching
    // function pointer type.
    using KernelBodyFn = void (*)(void *);
    auto func = reinterpret_cast<KernelBodyFn>(kernel_symbol);
    func(&context);
  };
}
};

FunctionType CodeGenWASM::codegen() {
Expand Down