Skip to content

Commit

Permalink
[refactor] Remove SNodeAttr and decouple SNode from LLVM (#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu authored Apr 19, 2020
1 parent 1c9f669 commit d5ce88b
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 81 deletions.
38 changes: 19 additions & 19 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "codegen_llvm.h"

#include "taichi/struct/struct_llvm.h"

TLANG_NAMESPACE_BEGIN

// TODO: sort function definitions to match declaration order in header
Expand Down Expand Up @@ -206,14 +208,6 @@ std::unique_ptr<RuntimeObject> CodeGenLLVM::emit_struct_meta_object(
TI_P(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED;
}
if (false) {
// auto ptr_type = llvm::Type::getInt8PtrTy(*llvm_context, 0);
auto ptr_type = llvm::PointerType::get(meta->type, 0);
auto ptr = meta->ptr; // builder->CreatePointerCast(meta->ptr, ptr_type);
auto struct_meta_size = tlctx->get_type_size(meta->type);
builder->CreateIntrinsic(llvm::Intrinsic::invariant_start, {ptr_type},
{tlctx->get_constant(struct_meta_size), ptr});
}
return meta;
}

Expand All @@ -223,13 +217,17 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name,
RuntimeObject common("StructMeta", this, builder.get(), node_meta);
std::size_t element_size;
if (snode->type == SNodeType::dense) {
auto element_ty = snode_attr[snode].llvm_body_type->getArrayElementType();
auto body_type =
StructCompilerLLVM::get_llvm_body_type(module.get(), snode);
auto element_ty = body_type->getArrayElementType();
element_size = tlctx->get_type_size(element_ty);
} else if (snode->type == SNodeType::pointer) {
auto element_ty = tlctx->snode_attr[snode->ch[0]].llvm_type;
auto element_ty = StructCompilerLLVM::get_llvm_node_type(
module.get(), snode->ch[0].get());
element_size = tlctx->get_type_size(element_ty);
} else {
auto element_ty = tlctx->snode_attr[snode].llvm_element_type;
auto element_ty =
StructCompilerLLVM::get_llvm_element_type(module.get(), snode);
element_size = tlctx->get_type_size(element_ty);
}
common.set("snode_id", tlctx->get_constant(snode->id));
Expand Down Expand Up @@ -266,13 +264,12 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name,
}

CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
// TODO: simplify ModuleBuilder ctor input
: ModuleBuilder(kernel->program.get_llvm_context(kernel->arch)
->clone_struct_module()),
// TODO: simplify LLVMModuleBuilder ctor input
: LLVMModuleBuilder(kernel->program.get_llvm_context(kernel->arch)
->clone_struct_module()),
kernel(kernel),
ir(ir),
prog(&kernel->program),
snode_attr(prog->get_llvm_context(kernel->arch)->snode_attr) {
prog(&kernel->program) {
if (ir == nullptr)
this->ir = kernel->ir;
initialize_context();
Expand Down Expand Up @@ -1117,8 +1114,9 @@ llvm::Value *CodeGenLLVM::call(SNode *snode,

void CodeGenLLVM::visit(GetRootStmt *stmt) {
llvm_val[stmt] = builder->CreateBitCast(
get_root(),
PointerType::get(snode_attr[prog->snode_root.get()].llvm_type, 0));
get_root(), PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), prog->snode_root.get()),
0));
}

void CodeGenLLVM::visit(OffsetAndExtractBitsStmt *stmt) {
Expand Down Expand Up @@ -1184,7 +1182,9 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
{builder->CreateBitCast(llvm_val[stmt->input_ptr],
PointerType::getInt8PtrTy(*llvm_context))});
llvm_val[stmt] = builder->CreateBitCast(
ch, PointerType::get(snode_attr[stmt->output_snode].llvm_type, 0));
ch, PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
0));
}

void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
Expand Down
5 changes: 2 additions & 3 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class FunctionCreationGuard {
~FunctionCreationGuard();
};

class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
public:
static uint64 task_counter;

Expand All @@ -66,15 +66,14 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
llvm::BasicBlock *current_while_after_loop;
llvm::FunctionType *task_function_type;
OffloadedStmt *current_offloaded_stmt;
SNodeAttributes &snode_attr;
std::unordered_map<Stmt *, llvm::Value *> llvm_val;
llvm::Function *func;
std::unique_ptr<OffloadedTask> current_task;
std::vector<OffloadedTask> offloaded_tasks;
BasicBlock *func_body_bb;

using IRVisitor::visit;
using ModuleBuilder::call;
using LLVMModuleBuilder::call;

CodeGenLLVM(Kernel *kernel, IRNode *ir = nullptr);

Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_llvm_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {

auto format_str = "[debug] " + stmt->str + " = " + format + "\n";

llvm_val[stmt] = ModuleBuilder::call(
llvm_val[stmt] = LLVMModuleBuilder::call(
builder.get(), "vprintf",
builder->CreateGlobalStringPtr(format_str, "format_string"),
builder->CreateBitCast(values,
Expand Down
24 changes: 0 additions & 24 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,28 +256,4 @@ class SNode {
uint64 fetch_reader_result(); // TODO: refactor
};

class SNodeAttribute {
public:
llvm::Type *llvm_type, *llvm_body_type, *llvm_aux_type;
llvm::Type *llvm_element_type;
};

class SNodeAttributes {
private:
std::map<SNode *, SNodeAttribute> snode_llvm_attr;

public:
SNodeAttribute &operator[](SNode &snode) {
return snode_llvm_attr[&snode];
}

SNodeAttribute &operator[](SNode *snode) {
return snode_llvm_attr[snode];
}

SNodeAttribute &operator[](const std::unique_ptr<SNode> &snode) {
return snode_llvm_attr[snode.get()];
}
};

TLANG_NAMESPACE_END
8 changes: 4 additions & 4 deletions taichi/llvm/llvm_codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ inline bool check_func_call_signature(llvm::Value *func, Args &&... args) {
return check_func_call_signature(func, {args...});
}

class ModuleBuilder {
class LLVMModuleBuilder {
public:
std::unique_ptr<llvm::Module> module;
llvm::BasicBlock *entry_block;
std::unique_ptr<llvm::IRBuilder<>> builder;
TaichiLLVMContext *tlctx;
llvm::LLVMContext *llvm_context;

ModuleBuilder(std::unique_ptr<llvm::Module> &&module)
LLVMModuleBuilder(std::unique_ptr<llvm::Module> &&module)
: module(std::move(module)) {
}

Expand Down Expand Up @@ -131,12 +131,12 @@ class RuntimeObject {
public:
std::string cls_name;
llvm::Value *ptr;
ModuleBuilder *mb;
LLVMModuleBuilder *mb;
llvm::Type *type;
llvm::IRBuilder<> *builder;

RuntimeObject(const std::string &cls_name,
ModuleBuilder *mb,
LLVMModuleBuilder *mb,
llvm::IRBuilder<> *builder,
llvm::Value *init = nullptr)
: cls_name(cls_name), mb(mb), builder(builder) {
Expand Down
2 changes: 0 additions & 2 deletions taichi/llvm/llvm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ class TaichiLLVMContext {
std::mutex mut;
Arch arch;

SNodeAttributes snode_attr;

TaichiLLVMContext(Arch arch);

std::unique_ptr<llvm::Module> get_init_module();
Expand Down
17 changes: 9 additions & 8 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "taichi/codegen/codegen_opengl.h"
#include "taichi/codegen/codegen_cpu.h"
#include "taichi/struct/struct.h"
#include "taichi/struct/struct_llvm.h"
#include "taichi/struct/struct_metal.h"
#include "taichi/struct/struct_opengl.h"
#include "taichi/system/unified_allocator.h"
Expand Down Expand Up @@ -215,15 +216,15 @@ void Program::initialize_runtime_system(StructCompiler *scomp) {
for (int i = 0; i < (int)snodes.size(); i++) {
if (is_gc_able(snodes[i]->type)) {
std::size_t node_size;
if (snodes[i]->type == SNodeType::pointer)
node_size = tlctx->get_type_size(
scomp->snode_attr[snodes[i]].llvm_element_type);
else {
auto element_size =
tlctx->get_type_size(StructCompilerLLVM::get_llvm_element_type(
tlctx->struct_module.get(), snodes[i]));
if (snodes[i]->type == SNodeType::pointer) {
// pointer. Allocators are for single elements
node_size = element_size;
} else {
// dynamic. Allocators are for the chunks
node_size = sizeof(void *) +
tlctx->get_type_size(
scomp->snode_attr[snodes[i]].llvm_element_type) *
snodes[i]->chunk_size;
node_size = sizeof(void *) + element_size * snodes[i]->chunk_size;
}
TI_TRACE("Initializing allocator for snode {} (node size {})",
snodes[i]->id, node_size);
Expand Down
2 changes: 0 additions & 2 deletions taichi/struct/struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ class StructCompiler {
std::size_t root_size;
Program *prog;

SNodeAttributes snode_attr;

explicit StructCompiler(Program *prog);

virtual ~StructCompiler() = default;
Expand Down
72 changes: 56 additions & 16 deletions taichi/struct/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using namespace llvm;

StructCompilerLLVM::StructCompilerLLVM(Program *prog, Arch arch)
: StructCompiler(prog),
ModuleBuilder(prog->get_llvm_context(arch)->get_init_module()),
LLVMModuleBuilder(prog->get_llvm_context(arch)->get_init_module()),
arch(arch) {
tlctx = prog->get_llvm_context(arch);
llvm_ctx = tlctx->ctx.get();
Expand All @@ -22,23 +22,20 @@ StructCompilerLLVM::StructCompilerLLVM(Program *prog, Arch arch)
void StructCompilerLLVM::generate_types(SNode &snode) {
TI_AUTO_PROF;
auto type = snode.type;
llvm::Type *llvm_type = nullptr;
llvm::Type *node_type = nullptr;

auto ctx = llvm_ctx;

// create children type that supports forking...

std::vector<llvm::Type *> ch_types;
for (int i = 0; i < snode.ch.size(); i++) {
auto ch = snode_attr[snode.ch[i]].llvm_type;
auto ch = get_llvm_node_type(module.get(), snode.ch[i].get());
ch_types.push_back(ch);
}

auto ch_type =
llvm::StructType::create(*ctx, ch_types, snode.node_type_name + "_ch");
ch_type->setName(snode.node_type_name + "_ch");

snode_attr[snode].llvm_element_type = ch_type;

llvm::Type *body_type = nullptr, *aux_type = nullptr;
if (type == SNodeType::dense || type == SNodeType::bitmasked) {
Expand Down Expand Up @@ -69,15 +66,21 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
TI_NOT_IMPLEMENTED;
}
if (aux_type != nullptr) {
llvm_type = llvm::StructType::create(*ctx, {aux_type, body_type}, "");
node_type = llvm::StructType::create(*ctx, {aux_type, body_type}, "");
} else {
llvm_type = body_type;
node_type = body_type;
}

TI_ASSERT(llvm_type != nullptr);
snode_attr[snode].llvm_type = llvm_type;
snode_attr[snode].llvm_aux_type = aux_type;
snode_attr[snode].llvm_body_type = body_type;
TI_ASSERT(node_type != nullptr);
TI_ASSERT(body_type != nullptr);

// Here we create a stub holding 4 LLVM types as struct members.
// The aim is to give a **unique** name to the stub, so that we can look up
// these types using this name. This decouples them from the LLVM context.
// Note that body_type might not have a unique name, since literal structs
// (such as {i32, i32}) are uniqued in LLVM.
llvm::StructType::create(*ctx, {node_type, body_type, aux_type, ch_type},
type_stub_name(&snode));
}

void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) {
Expand Down Expand Up @@ -140,7 +143,7 @@ void StructCompilerLLVM::generate_child_accessors(SNode &snode) {
auto parent = snode.parent;

auto inp_type =
llvm::PointerType::get(snode_attr[parent].llvm_element_type, 0);
llvm::PointerType::get(get_llvm_element_type(module.get(), parent), 0);

auto ft =
llvm::FunctionType::get(llvm::Type::getInt8PtrTy(*llvm_ctx),
Expand Down Expand Up @@ -174,6 +177,10 @@ void StructCompilerLLVM::generate_child_accessors(SNode &snode) {
stack.pop_back();
}

std::string StructCompilerLLVM::type_stub_name(SNode *snode) {
return snode->node_type_name + "_type_stubs";
}

void StructCompilerLLVM::run(SNode &root, bool host) {
TI_AUTO_PROF;
// bottom to top
Expand All @@ -197,11 +204,44 @@ void StructCompilerLLVM::run(SNode &root, bool host) {

TI_ASSERT((int)snodes.size() <= taichi_max_num_snodes);

root_size =
tlctx->get_data_layout().getTypeAllocSize(snode_attr[root].llvm_type);
auto node_type = get_llvm_node_type(module.get(), &root);
root_size = tlctx->get_data_layout().getTypeAllocSize(node_type);

tlctx->set_struct_module(module);
tlctx->snode_attr = snode_attr;
}

llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module,
SNode *snode,
uint32 index) {
TI_ASSERT(module);
TI_ASSERT(snode);
auto stub = module->getTypeByName(type_stub_name(snode));
TI_ASSERT(stub);
TI_ASSERT(stub->getStructNumElements() == 4);
TI_ASSERT(0 <= index && index < 4);
auto type = stub->getContainedType(index);
TI_ASSERT(type);
return type;
}

llvm::Type *StructCompilerLLVM::get_llvm_node_type(llvm::Module *module,
SNode *snode) {
return get_stub(module, snode, 0);
}

llvm::Type *StructCompilerLLVM::get_llvm_body_type(llvm::Module *module,
SNode *snode) {
return get_stub(module, snode, 1);
}

llvm::Type *StructCompilerLLVM::get_llvm_aux_type(llvm::Module *module,
SNode *snode) {
return get_stub(module, snode, 2);
}

llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module,
SNode *snode) {
return get_stub(module, snode, 3);
}

std::unique_ptr<StructCompiler> StructCompiler::make(Program *prog, Arch arch) {
Expand Down
16 changes: 14 additions & 2 deletions taichi/struct/struct_llvm.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Codegen for the hierarchical data structure (LLVM)

#include "struct.h"
#include "taichi/struct/struct.h"
#include "taichi/llvm/llvm_codegen_utils.h"

TLANG_NAMESPACE_BEGIN

class StructCompilerLLVM : public StructCompiler, public ModuleBuilder {
class StructCompilerLLVM : public StructCompiler, public LLVMModuleBuilder {
public:
StructCompilerLLVM(Program *prog, Arch arch);

Expand All @@ -20,6 +20,18 @@ class StructCompilerLLVM : public StructCompiler, public ModuleBuilder {
void run(SNode &node, bool host) override;

void generate_refine_coordinates(SNode *snode);

static std::string type_stub_name(SNode *snode);

static llvm::Type *get_stub(llvm::Module *module, SNode *snode, uint32 index);

static llvm::Type *get_llvm_node_type(llvm::Module *module, SNode *snode);

static llvm::Type *get_llvm_body_type(llvm::Module *module, SNode *snode);

static llvm::Type *get_llvm_aux_type(llvm::Module *module, SNode *snode);

static llvm::Type *get_llvm_element_type(llvm::Module *module, SNode *snode);
};

TLANG_NAMESPACE_END

0 comments on commit d5ce88b

Please sign in to comment.