Skip to content

Commit

Permalink
add link impl
Browse files Browse the repository at this point in the history
  • Loading branch information
galeselee committed Jan 16, 2023
1 parent 8192884 commit 2620a53
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
81 changes: 81 additions & 0 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,32 @@ std::unique_ptr<llvm::Module> TaichiLLVMContext::module_from_file(
// runtime_module->print(llvm::errs(), nullptr);
}

#ifdef TI_WITH_AMDGPU
auto patch_amdgpu_kernel_dim = [&](std::string name,
llvm::Value *lhs) {
std::string actual_name;
if (name == "block_dim")
actual_name = "__ockl_get_local_size";
else if (name == "grid_dim")
actual_name = "__ockl_get_num_groups";
else
TI_ERROR("Unknown patch function name");
auto func = module->getFunction(name);
auto actual_func = module->getFunction(actual_name);
if (!func || !actual_func) {
return;
}
func->deleteBody();
auto bb = llvm::BasicBlock::Create(*ctx, "entry", func);
IRBuilder<> builder(*ctx);
builder.SetInsertPoint(bb);
auto dim_ = builder.CreateCall(actual_func->getFunctionType(), actual_func, {lhs});
auto ret_ = builder.CreateTrunc(dim_, llvm::Type::getInt32Ty(*ctx));
builder.CreateRet(ret_);
TaichiLLVMContext::mark_inline(func);
};
#endif

if (arch_ == Arch::amdgpu) {
module->setTargetTriple("amdgcn-amd-amdhsa");
#ifdef TI_WITH_AMDGPU
Expand All @@ -498,6 +524,10 @@ std::unique_ptr<llvm::Module> TaichiLLVMContext::module_from_file(
function_pass_manager.doFinalization();
patch_intrinsic("thread_idx", llvm::Intrinsic::amdgcn_workitem_id_x);
patch_intrinsic("block_idx", llvm::Intrinsic::amdgcn_workgroup_id_x);

link_module_with_amdgpu_libdevice(module);
patch_dim("block_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0));
patch_dim("grid_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0));
#endif
}
}
Expand Down Expand Up @@ -537,6 +567,57 @@ void TaichiLLVMContext::link_module_with_cuda_libdevice(
}
}

void TaichiLLVMContext:::link_module_with_amdgpu_libdevice(
std::unique_ptr<llvm::Module> &module) {
TI_ASSERT(arch_ == Arch::amdgpu);
auto isa_version = AMDGPUContext::get_instance().get_mcpu().substr(3,4);
std::string libdevice_files[] = {
"ocml.bc",
"oclc_wavefrontsize64_off.bc",
"ockl.bc",
"oclc_abi_version_400.bc",
"oclc_correctly_rounded_sqrt_off.bc",
"oclc_daz_opt_off.bc",
"oclc_finite_only_off.bc",
"oclc_isa_version_" + isa_version + ".bc",
"oclc_unsafe_math_off.bc",
"opencl.bc"
};

for (auto &libdevice : libdevice_files) {
std::string lib_dir = runtime_lib_dir() + "/";
auto libdevice_module = module_from_bitcode_file(lib_dir + libdevice,
get_this_thread_context());

if (libdevice == "ocml.bc")
module->setDataLayout(libdevice_module->getDataLayout());

std::vector<std::string> libdevice_func_names;
for (auto &f : *libdevice_module) {
if (!f.isDeclaration()) {
libdevice_function_names.push_back(f.getName().str());
}
}

for (auto &f : libdevice_module->functions()) {
auto func_ = module->getFunction(f.getName());
if (!func_ && starts_with(f.getName().lower(), "__" + libdevice))
f.setLinkage(llvm::Function::CommonLinkage);`
}

bool failed = llvm::Linker::linkModules(*module, std::move(libdevice_module));
if (failed) {
TI_ERROR("AMDGPU libdevice linking failure.");
}

for (auto func_name : libdevice_function_names) {
auto func = module->getFunction(func_name);
if (func)
func->setLinkage(llvm::Function::InternalLinkage);
}
}
}

void TaichiLLVMContext::add_struct_module(std::unique_ptr<Module> module,
int tree_id) {
TI_AUTO_PROF;
Expand Down
2 changes: 2 additions & 0 deletions taichi/runtime/llvm/llvm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class TaichiLLVMContext {

void link_module_with_cuda_libdevice(std::unique_ptr<llvm::Module> &module);

void link_module_with_amdgpu_libdevice(std::unique_ptr<llvm::Module> &module);

static int num_instructions(llvm::Function *func);

void insert_nvvm_annotation(llvm::Function *func, std::string key, int val);
Expand Down

0 comments on commit 2620a53

Please sign in to comment.