Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[amdgpu] Enable struct for on amdgpu backend #7247

Merged
merged 21 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/assets
14 changes: 13 additions & 1 deletion taichi/codegen/amdgpu/codegen_amdgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
} else if (stmt->task_type == Type::range_for) {
create_offload_range_for(stmt);
} else if (stmt->task_type == Type::struct_for) {
create_offload_struct_for(stmt, true);
create_offload_struct_for(stmt);
} else if (stmt->task_type == Type::mesh_for) {
create_offload_mesh_for(stmt);
} else if (stmt->task_type == Type::listgen) {
Expand Down Expand Up @@ -395,6 +395,18 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
}
}
}

private:
std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
auto thread_idx =
builder->CreateIntrinsic(Intrinsic::amdgcn_workitem_id_x, {}, {});
auto workgroup_dim_ = call(
"__ockl_get_local_size",
llvm::ConstantInt::get(llvm::Type::getInt32Ty(*llvm_context), 0));
auto block_dim = builder->CreateTrunc(workgroup_dim_,
llvm::Type::getInt32Ty(*llvm_context));
return std::make_tuple(thread_idx, block_dim);
}
};

LLVMCompiledTask KernelCodeGenAMDGPU::compile_task(
Expand Down
7 changes: 7 additions & 0 deletions taichi/codegen/cpu/codegen_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,13 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
TI_NOT_IMPLEMENTED
}
}

private:
std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
auto thread_idx = tlctx->get_constant(0);
auto block_dim = tlctx->get_constant(1);
return std::make_tuple(thread_idx, block_dim);
}
};

} // namespace
Expand Down
11 changes: 10 additions & 1 deletion taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
} else if (stmt->task_type == Type::range_for) {
create_offload_range_for(stmt);
} else if (stmt->task_type == Type::struct_for) {
create_offload_struct_for(stmt, true);
create_offload_struct_for(stmt);
} else if (stmt->task_type == Type::mesh_for) {
create_offload_mesh_for(stmt);
} else if (stmt->task_type == Type::listgen) {
Expand Down Expand Up @@ -584,6 +584,15 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
}
}

private:
std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
auto thread_idx =
builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {});
auto block_dim = builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_ntid_x,
{}, {});
return std::make_tuple(thread_idx, block_dim);
}
};

LLVMCompiledTask KernelCodeGenCUDA::compile_task(
Expand Down
7 changes: 7 additions & 0 deletions taichi/codegen/dx12/codegen_dx12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,13 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {
TI_NOT_IMPLEMENTED
}
}

private:
std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
auto thread_idx = tlctx->get_constant(0);
auto block_dim = tlctx->get_constant(1);
return std::make_tuple(thread_idx, block_dim);
}
};

} // namespace
Expand Down
24 changes: 5 additions & 19 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2023,8 +2023,7 @@ std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
return std::tuple(begin, end);
}

void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
bool spmd) {
void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt) {
using namespace llvm;
// TODO: instead of constructing tons of LLVM IR, writing the logic in
// runtime.cpp may be a cleaner solution. See
Expand Down Expand Up @@ -2124,18 +2123,9 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
call("block_barrier"); // "__syncthreads()"
}

llvm::Value *thread_idx = nullptr, *block_dim = nullptr;

if (spmd) {
thread_idx =
builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {});
block_dim = builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_ntid_x,
{}, {});
builder->CreateStore(builder->CreateAdd(thread_idx, lower_bound),
loop_index);
} else {
builder->CreateStore(lower_bound, loop_index);
}
auto [thread_idx, block_dim] = this->get_spmd_info();
builder->CreateStore(builder->CreateAdd(thread_idx, lower_bound),
loop_index);

auto loop_test_bb = BasicBlock::Create(*llvm_context, "loop_test", func);
auto loop_body_bb = BasicBlock::Create(*llvm_context, "loop_body", func);
Expand Down Expand Up @@ -2218,11 +2208,7 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
// body tail: increment loop_index and jump to loop_test
builder->SetInsertPoint(body_tail_bb);

if (spmd) {
create_increment(loop_index, block_dim);
} else {
create_increment(loop_index, tlctx->get_constant(1));
}
create_increment(loop_index, block_dim);
builder->CreateBr(loop_test_bb);

builder->SetInsertPoint(func_exit);
Expand Down
5 changes: 4 additions & 1 deletion taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
TI_NOT_IMPLEMENTED;
}

void create_offload_struct_for(OffloadedStmt *stmt, bool spmd = false);
void create_offload_struct_for(OffloadedStmt *stmt);

void visit(LoopIndexStmt *stmt) override;

Expand Down Expand Up @@ -410,6 +410,9 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
const Type *current_type,
int &current_element,
std::vector<llvm::Value *> &current_index);

// Backend hook used by create_offload_struct_for(). Implementations return
// the pair (thread_idx, block_dim): the struct-for loop index is initialized
// to `thread_idx + lower_bound` and advanced by `block_dim` per iteration.
// Every backend deriving from TaskCodeGenLLVM must override this.
virtual std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() = 0;

};

} // namespace taichi::lang
Expand Down
5 changes: 5 additions & 0 deletions taichi/codegen/llvm/llvm_codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

#if defined(TI_WITH_AMDGPU)
#include "llvm/IR/IntrinsicsAMDGPU.h"
#endif

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
Expand Down
5 changes: 5 additions & 0 deletions taichi/codegen/wasm/codegen_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {
res.module = std::move(this->module);
return res;
}

private:
// The WASM backend provides no (thread index, block dimension) pair; this
// override only satisfies the pure-virtual interface and reports
// "not implemented" if it is ever reached.
// NOTE(review): presumably TI_NOT_IMPLEMENTED does not return normally,
// since no value is returned after it -- confirm against the macro.
std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() override {
TI_NOT_IMPLEMENTED;
}
};

FunctionType KernelCodeGenWASM::compile_to_function() {
Expand Down
60 changes: 14 additions & 46 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@
#include "llvm_context.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#include "taichi/codegen/codegen_utils.h"
#ifdef TI_WITH_AMDGPU

#include "taichi/runtime/llvm/llvm_context_pass.h"
#endif

#ifdef _WIN32
// Travis CI doesn't seem to support <filesystem>...
Expand Down Expand Up @@ -1033,52 +1032,21 @@ void TaichiLLVMContext::add_struct_for_func(llvm::Module *module,
if (module->getFunction(func_name)) {
return;
}
auto struct_for_func = module->getFunction("parallel_struct_for");
auto &llvm_context = module->getContext();
auto value_map = llvm::ValueToValueMapTy();
auto patched_struct_for_func =
llvm::CloneFunction(struct_for_func, value_map);
patched_struct_for_func->setName(func_name);

int num_found_alloca = 0;
llvm::AllocaInst *alloca = nullptr;

auto char_type = llvm::Type::getInt8Ty(llvm_context);

// Find the "1" in "char tls_buffer[1]" and replace it with
// "tls_buffer_size"
for (auto &bb : *patched_struct_for_func) {
for (llvm::Instruction &inst : bb) {
auto now_alloca = llvm::dyn_cast<AllocaInst>(&inst);
if (!now_alloca || now_alloca->getAlign().value() != 8)
continue;
auto alloca_type = now_alloca->getAllocatedType();
// Allocated type should be array [1 x i8]
if (alloca_type->isArrayTy() && alloca_type->getArrayNumElements() == 1 &&
alloca_type->getArrayElementType() == char_type) {
alloca = now_alloca;
num_found_alloca++;
}
}
}
// There should be **exactly** one replacement.
TI_ASSERT(num_found_alloca == 1 && alloca);
auto new_type = llvm::ArrayType::get(char_type, tls_size);
{
llvm::IRBuilder<> builder(alloca);
auto *new_alloca = builder.CreateAlloca(new_type);
new_alloca->setAlignment(Align(8));
TI_ASSERT(alloca->hasOneUse());
auto *gep = llvm::cast<llvm::GetElementPtrInst>(alloca->user_back());
TI_ASSERT(gep->getPointerOperand() == alloca);
std::vector<Value *> indices(gep->idx_begin(), gep->idx_end());
builder.SetInsertPoint(gep);
auto *new_gep = builder.CreateInBoundsGEP(new_type, new_alloca, indices);
gep->replaceAllUsesWith(new_gep);
gep->eraseFromParent();
alloca->eraseFromParent();
llvm::legacy::PassManager module_pass_manager;
if (config_.arch == Arch::amdgpu) {
#ifdef TI_WITH_AMDGPU
module_pass_manager.add(
new AMDGPUAddStructForFuncPass(func_name, tls_size));
module_pass_manager.run(*module);
#else
TI_NOT_IMPLEMENTED
#endif
} else {
module_pass_manager.add(new AddStructForFuncPass(func_name, tls_size));
module_pass_manager.run(*module);
}
}

// Builds the name of the struct-for runtime function specialized for a
// thread-local-storage buffer of `tls_size`: the fixed prefix
// "parallel_struct_for_" followed by the decimal value of `tls_size`.
std::string TaichiLLVMContext::get_struct_for_func_name(int tls_size) {
  std::string specialized_name = "parallel_struct_for_";
  specialized_name += std::to_string(tls_size);
  return specialized_name;
}
Expand Down
122 changes: 122 additions & 0 deletions taichi/runtime/llvm/llvm_context_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include "llvm/Transforms/Utils/Cloning.h"

#if defined(TI_WITH_AMDGPU)
#include "taichi/rhi/amdgpu/amdgpu_context.h"
Expand All @@ -18,6 +20,63 @@
namespace taichi {
namespace lang {
using namespace llvm;

struct AddStructForFuncPass : public ModulePass {
static inline char ID{0};
std::string func_name_;
int tls_size_;
AddStructForFuncPass(std::string func_name, int tls_size) : ModulePass(ID) {
func_name_ = func_name;
tls_size_ = tls_size;
}
bool runOnModule(llvm::Module &M) override {
auto struct_for_func = M.getFunction("parallel_struct_for");
auto &llvm_context = M.getContext();
auto value_map = llvm::ValueToValueMapTy();
auto patched_struct_for_func =
llvm::CloneFunction(struct_for_func, value_map);
patched_struct_for_func->setName(func_name_);

int num_found_alloca = 0;
llvm::AllocaInst *alloca = nullptr;

auto char_type = llvm::Type::getInt8Ty(llvm_context);

// Find the "1" in "char tls_buffer[1]" and replace it with
// "tls_buffer_size"
for (auto &bb : *patched_struct_for_func) {
for (llvm::Instruction &inst : bb) {
auto now_alloca = llvm::dyn_cast<AllocaInst>(&inst);
if (!now_alloca || now_alloca->getAlign().value() != 8)
continue;
auto alloca_type = now_alloca->getAllocatedType();
// Allocated type should be array [1 x i8]
if (alloca_type->isArrayTy() &&
alloca_type->getArrayNumElements() == 1 &&
alloca_type->getArrayElementType() == char_type) {
alloca = now_alloca;
num_found_alloca++;
}
}
}
TI_ASSERT(num_found_alloca == 1 && alloca);
auto new_type = llvm::ArrayType::get(char_type, tls_size_);
llvm::IRBuilder<> builder(alloca);
auto *new_alloca = builder.CreateAlloca(new_type);
new_alloca->setAlignment(Align(8));
TI_ASSERT(alloca->hasOneUse());
auto *gep = llvm::cast<llvm::GetElementPtrInst>(alloca->user_back());
TI_ASSERT(gep->getPointerOperand() == alloca);
std::vector<Value *> indices(gep->idx_begin(), gep->idx_end());
builder.SetInsertPoint(gep);
auto *new_gep = builder.CreateInBoundsGEP(new_type, new_alloca, indices);
gep->replaceAllUsesWith(new_gep);
gep->eraseFromParent();
alloca->eraseFromParent();
return false;
}
};

#if defined(TI_WITH_AMDGPU)
struct AMDGPUConvertAllocaInstAddressSpacePass : public FunctionPass {
static inline char ID{0};
Expand Down Expand Up @@ -52,6 +111,69 @@ struct AMDGPUConvertAllocaInstAddressSpacePass : public FunctionPass {
}
};

struct AMDGPUAddStructForFuncPass : public ModulePass {
static inline char ID{0};
std::string func_name_;
int tls_size_;
AMDGPUAddStructForFuncPass(std::string func_name, int tls_size)
: ModulePass(ID) {
func_name_ = func_name;
tls_size_ = tls_size;
}
bool runOnModule(llvm::Module &M) override {
auto struct_for_func = M.getFunction("parallel_struct_for");
auto &llvm_context = M.getContext();
auto value_map = llvm::ValueToValueMapTy();
auto patched_struct_for_func =
llvm::CloneFunction(struct_for_func, value_map);
patched_struct_for_func->setName(func_name_);

int num_found_alloca = 0;
llvm::AllocaInst *alloca = nullptr;

auto char_type = llvm::Type::getInt8Ty(llvm_context);

// Find the "1" in "char tls_buffer[1]" and replace it with
// "tls_buffer_size"
for (auto &bb : *patched_struct_for_func) {
for (llvm::Instruction &inst : bb) {
auto now_alloca = llvm::dyn_cast<AllocaInst>(&inst);
if (!now_alloca || now_alloca->getAlign().value() != 8)
continue;
auto alloca_type = now_alloca->getAllocatedType();
// Allocated type should be array [1 x i8]
if (alloca_type->isArrayTy() &&
alloca_type->getArrayNumElements() == 1 &&
alloca_type->getArrayElementType() == char_type) {
alloca = now_alloca;
num_found_alloca++;
}
}
}
TI_ASSERT(num_found_alloca == 1 && alloca);
auto new_type = llvm::ArrayType::get(char_type, tls_size_);
llvm::IRBuilder<> builder(alloca);
auto *new_alloca = builder.CreateAlloca(new_type, (unsigned)5);
new_alloca->setAlignment(Align(8));
auto new_ty = llvm::PointerType::get(new_type, unsigned(0));
auto *new_cast = builder.CreateAddrSpaceCast(new_alloca, new_ty);
new_alloca->setAlignment(Align(8));
TI_ASSERT(alloca->hasOneUse());
auto *cast = llvm::cast<llvm::AddrSpaceCastInst>(alloca->user_back());
TI_ASSERT(cast->hasOneUse());
auto *gep = llvm::cast<llvm::GetElementPtrInst>(cast->user_back());
TI_ASSERT(gep->getPointerOperand() == cast);
std::vector<Value *> indices(gep->idx_begin(), gep->idx_end());
builder.SetInsertPoint(gep);
auto *new_gep = builder.CreateInBoundsGEP(new_type, new_cast, indices);
gep->replaceAllUsesWith(new_gep);
gep->eraseFromParent();
cast->eraseFromParent();
alloca->eraseFromParent();
return false;
}
};

struct AMDGPUConvertFuncParamAddressSpacePass : public ModulePass {
static inline char ID{0};
AMDGPUConvertFuncParamAddressSpacePass() : ModulePass(ID) {
Expand Down