[WIP][CPU] Move the runtime CPU block loop into codegen #6847

Closed
wants to merge 2 commits
40 changes: 26 additions & 14 deletions taichi/codegen/cpu/codegen_cpu.cpp
@@ -38,15 +38,27 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
// The loop body
llvm::Function *body;
{
// auto guard = get_function_creation_guard(
// {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
// llvm::Type::getInt8PtrTy(*llvm_context),
// tlctx->get_data_type<int>()});

// auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
// loop_vars_llvm[stmt].push_back(loop_var);
// builder->CreateStore(get_arg(2), loop_var);
// stmt->body->accept(this);
auto guard = get_function_creation_guard(
{llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0),
llvm::Type::getInt8PtrTy(*llvm_context),
llvm::Type::getInt8PtrTy(*llvm_context), tlctx->get_data_type<int>(),
tlctx->get_data_type<int>()});

auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(2), loop_var);
stmt->body->accept(this);
auto begin_var = builder->CreateAlloca(
tlctx->get_data_type(PrimitiveType::i32), (unsigned)0, nullptr);
auto end_var = builder->CreateAlloca(
tlctx->get_data_type(PrimitiveType::i32), (unsigned)0, nullptr);
builder->CreateStore(get_arg(2), begin_var);
builder->CreateStore(get_arg(3), end_var);
create_cpu_block_range_for(stmt, begin_var, end_var);

body = guard.body;
}
@@ -55,15 +67,15 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {

auto [begin, end] = get_range_for_bounds(stmt);

// adaptive block_dim
if (prog->this_thread_config().cpu_block_dim_adaptive) {
int num_items = (stmt->end_value - stmt->begin_value) / std::abs(step);
int num_threads = stmt->num_cpu_threads;
int items_per_thread = std::max(1, num_items / (num_threads * 32));
// keep at least 512 items per task to amortize scheduler overhead
// also saturate the value to 1024 for better load balancing
stmt->block_dim = std::min(1024, std::max(512, items_per_thread));
}
// // adaptive block_dim
// if (prog->this_thread_config().cpu_block_dim_adaptive) {
// int num_items = (stmt->end_value - stmt->begin_value) / std::abs(step);
// int num_threads = stmt->num_cpu_threads;
// int items_per_thread = std::max(1, num_items / (num_threads * 32));
// // keep at least 512 items per task to amortize scheduler overhead
// // also saturate the value to 1024 for better load balancing
// stmt->block_dim = std::min(1024, std::max(512, items_per_thread));
// }

call("cpu_parallel_range_for", get_arg(0),
tlctx->get_constant(stmt->num_cpu_threads), begin, end,
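Taken together, the hunks above change the shape of the emitted task function: the function creation guard now declares extra i32 block-bound parameters, and the per-index loop is generated by create_cpu_block_range_for instead of being driven by the runtime. A minimal sketch (hypothetical name, not code from this PR) of what the emitted body now behaves like, matching the new RangeForTaskFunc signature in runtime.cpp:

struct RuntimeContext;  // opaque here; defined by the Taichi runtime

// Hypothetical rendering of the function the codegen now emits: the
// runtime hands over a half-open block range, and the index loop lives
// inside the generated code rather than in cpu_parallel_range_for_task.
void emitted_task_body(RuntimeContext *context, const char *tls,
                       int block_begin, int block_end) {
  for (int i = block_begin; i < block_end; i++) {
    // ... offloaded kernel body for index i (stmt->body) ...
  }
}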
50 changes: 50 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -1128,6 +1128,56 @@ void TaskCodeGenLLVM::create_increment(llvm::Value *ptr, llvm::Value *value) {
builder->CreateStore(builder->CreateAdd(original_value, value), ptr);
}

void TaskCodeGenLLVM::create_cpu_block_range_for(OffloadedStmt *stmt,
llvm::Value *begin_var,
llvm::Value *end_var) {
using namespace llvm;
BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func);
BasicBlock *loop_inc =
BasicBlock::Create(*llvm_context, "for_loop_inc", func);
BasicBlock *after_loop = BasicBlock::Create(*llvm_context, "after_for", func);
BasicBlock *loop_test =
BasicBlock::Create(*llvm_context, "for_loop_test", func);

auto loop_var_ty = tlctx->get_data_type(PrimitiveType::i32);
auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(builder->CreateLoad(loop_var_ty, begin_var), loop_var);

builder->CreateBr(loop_test);
{
// test block
builder->SetInsertPoint(loop_test);
llvm::Value *cond;
cond = builder->CreateICmp(llvm::CmpInst::Predicate::ICMP_SLT,
builder->CreateLoad(loop_var_ty, loop_var),
builder->CreateLoad(loop_var_ty, end_var));
// loop_var, end_var);
builder->CreateCondBr(cond, body, after_loop);
}
{
{
auto lrg = make_loop_reentry_guard(this);
// The continue stmt should jump to the loop-increment block!
current_loop_reentry = loop_inc;
// body cfg
builder->SetInsertPoint(body);
stmt->body->accept(this);
}
if (!returned) {
builder->CreateBr(loop_inc);
} else {
returned = false;
}
builder->SetInsertPoint(loop_inc);

create_increment(loop_var, tlctx->get_constant(1));
builder->CreateBr(loop_test);
}
// next cfg
builder->SetInsertPoint(after_loop);
}

void TaskCodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) {
using namespace llvm;
BasicBlock *body = BasicBlock::Create(*llvm_context, "for_loop_body", func);
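For readability, the control-flow graph built by create_cpu_block_range_for above corresponds to the following plain C++ loop. This is a sketch, not generated code; the labels are named after the BasicBlocks created in the function, and begin_var/end_var stand in for the i32 stack slots the caller stores the block bounds into:

// C++ equivalent of the IR emitted by create_cpu_block_range_for (sketch).
void block_range_for_equivalent(int *begin_var, int *end_var) {
  int loop_var = *begin_var;     // entry: loop_var = load(begin_var)
for_loop_test:
  if (!(loop_var < *end_var))    // ICMP_SLT, signed i32 compare
    goto after_for;
  // for_loop_body runs here; a `continue` in the body branches to
  // for_loop_inc via current_loop_reentry, and an early return skips
  // the back-edge (the `returned` flag above).
  loop_var += 1;                 // for_loop_inc: create_increment(loop_var, 1)
  goto for_loop_test;
after_for:
  return;                        // codegen resumes at after_for
}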
4 changes: 4 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.h
@@ -164,6 +164,10 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
// Direct translation
void create_naive_range_for(RangeForStmt *for_stmt);

void create_cpu_block_range_for(OffloadedStmt *stmt,
llvm::Value *begin_var,
llvm::Value *end_var);

static std::string get_runtime_snode_name(SNode *snode);

void visit(Block *stmt_list) override;
39 changes: 22 additions & 17 deletions taichi/runtime/llvm/runtime_module/runtime.cpp
@@ -40,7 +40,10 @@ using host_vsnprintf_type = int (*)(char *,
const char *,
std::va_list);
using vm_allocator_type = void *(*)(void *, std::size_t, std::size_t);
using RangeForTaskFunc = void(RuntimeContext *, const char *tls, int i);
using RangeForTaskFunc = void(RuntimeContext *,
const char *tls,
int block_begin,
int block_end);
using MeshForTaskFunc = void(RuntimeContext *, const char *tls, uint32_t i);
using parallel_for_type = void (*)(void *thread_pool,
int splits,
@@ -1466,14 +1469,16 @@ void cpu_parallel_range_for_task(void *range_context,
if (ctx.step == 1) {
int block_start = ctx.begin + task_id * ctx.block_size;
int block_end = std::min(block_start + ctx.block_size, ctx.end);
for (int i = block_start; i < block_end; i++) {
ctx.body(&this_thread_context, tls_ptr, i);
}
// for (int i = block_start; i < block_end; i++) {
// ctx.body(&this_thread_context, tls_ptr, i);
// }
printf("@@@@@ block_start %d, block_end %d\n", block_start, block_end);
ctx.body(&this_thread_context, tls_ptr, block_start, block_end);
} else if (ctx.step == -1) {
int block_start = ctx.end - task_id * ctx.block_size;
int block_end = std::max(ctx.begin, block_start - ctx.block_size);
for (int i = block_start - 1; i >= block_end; i--) {
ctx.body(&this_thread_context, tls_ptr, i);
// ctx.body(&this_thread_context, tls_ptr, i);
}
}
if (ctx.epilogue)
@@ -1517,17 +1522,17 @@ void gpu_parallel_range_for(RuntimeContext *context,
RangeForTaskFunc *func,
range_for_xlogue epilogue,
const std::size_t tls_size) {
int idx = thread_idx() + block_dim() * block_idx() + begin;
alignas(8) char tls_buffer[tls_size];
auto tls_ptr = &tls_buffer[0];
if (prologue)
prologue(context, tls_ptr);
while (idx < end) {
func(context, tls_ptr, idx);
idx += block_dim() * grid_dim();
}
if (epilogue)
epilogue(context, tls_ptr);
// int idx = thread_idx() + block_dim() * block_idx() + begin;
// alignas(8) char tls_buffer[tls_size];
// auto tls_ptr = &tls_buffer[0];
// if (prologue)
// prologue(context, tls_ptr);
// while (idx < end) {
// func(context, tls_ptr, idx);
// idx += block_dim() * grid_dim();
// }
// if (epilogue)
// epilogue(context, tls_ptr);
}

struct mesh_task_helper_context {
@@ -1556,7 +1561,7 @@ void cpu_parallel_mesh_for_task(void *range_context,
for (int idx = block_start; idx < block_end; idx++) {
if (ctx.prologue)
ctx.prologue(ctx.context, tls_ptr, idx);
ctx.body(&this_thread_context, tls_ptr, idx);
// ctx.body(&this_thread_context, tls_ptr, idx);
if (ctx.epilogue)
ctx.epilogue(ctx.context, tls_ptr, idx);
}
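With the body now taking a block range, the step == 1 path in cpu_parallel_range_for_task above reduces to computing one [block_start, block_end) interval per task and making a single call, which is also why RangeForTaskFunc gained the block_begin/block_end parameters at the top of the file. A condensed sketch (parameter names assumed; they mirror the ctx fields used in the hunk above):

#include <algorithm>

struct RuntimeContext;  // opaque here; defined by the Taichi runtime
using RangeForTaskFunc = void(RuntimeContext *, const char *, int, int);

// Condensed sketch of the step == 1 dispatch after this change:
// one body call per block instead of one call per index.
void dispatch_one_block(RuntimeContext *thread_ctx, char *tls_ptr,
                        RangeForTaskFunc *body, int begin, int end,
                        int block_size, int task_id) {
  int block_start = begin + task_id * block_size;
  int block_end = std::min(block_start + block_size, end);
  if (block_start < block_end)  // guard against a trailing empty block
    body(thread_ctx, tls_ptr, block_start, block_end);
}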