Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove SNodeAttr and decouple SNode from LLVM #817

Merged
merged 3 commits into from
Apr 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "codegen_llvm.h"

#include "taichi/struct/struct_llvm.h"

TLANG_NAMESPACE_BEGIN

// TODO: sort function definitions to match declaration order in header
Expand Down Expand Up @@ -206,14 +208,6 @@ std::unique_ptr<RuntimeObject> CodeGenLLVM::emit_struct_meta_object(
TI_P(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED;
}
if (false) {
// auto ptr_type = llvm::Type::getInt8PtrTy(*llvm_context, 0);
auto ptr_type = llvm::PointerType::get(meta->type, 0);
auto ptr = meta->ptr; // builder->CreatePointerCast(meta->ptr, ptr_type);
auto struct_meta_size = tlctx->get_type_size(meta->type);
builder->CreateIntrinsic(llvm::Intrinsic::invariant_start, {ptr_type},
{tlctx->get_constant(struct_meta_size), ptr});
}
return meta;
}

Expand All @@ -223,13 +217,17 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name,
RuntimeObject common("StructMeta", this, builder.get(), node_meta);
std::size_t element_size;
if (snode->type == SNodeType::dense) {
auto element_ty = snode_attr[snode].llvm_body_type->getArrayElementType();
auto body_type =
StructCompilerLLVM::get_llvm_body_type(module.get(), snode);
auto element_ty = body_type->getArrayElementType();
element_size = tlctx->get_type_size(element_ty);
} else if (snode->type == SNodeType::pointer) {
auto element_ty = tlctx->snode_attr[snode->ch[0]].llvm_type;
auto element_ty = StructCompilerLLVM::get_llvm_node_type(
module.get(), snode->ch[0].get());
element_size = tlctx->get_type_size(element_ty);
} else {
auto element_ty = tlctx->snode_attr[snode].llvm_element_type;
auto element_ty =
StructCompilerLLVM::get_llvm_element_type(module.get(), snode);
element_size = tlctx->get_type_size(element_ty);
}
common.set("snode_id", tlctx->get_constant(snode->id));
Expand Down Expand Up @@ -266,13 +264,12 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name,
}

CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
// TODO: simplify ModuleBuilder ctor input
: ModuleBuilder(kernel->program.get_llvm_context(kernel->arch)
->clone_struct_module()),
// TODO: simplify LLVMModuleBuilder ctor input
: LLVMModuleBuilder(kernel->program.get_llvm_context(kernel->arch)
->clone_struct_module()),
kernel(kernel),
ir(ir),
prog(&kernel->program),
snode_attr(prog->get_llvm_context(kernel->arch)->snode_attr) {
prog(&kernel->program) {
if (ir == nullptr)
this->ir = kernel->ir;
initialize_context();
Expand Down Expand Up @@ -1117,8 +1114,9 @@ llvm::Value *CodeGenLLVM::call(SNode *snode,

void CodeGenLLVM::visit(GetRootStmt *stmt) {
llvm_val[stmt] = builder->CreateBitCast(
get_root(),
PointerType::get(snode_attr[prog->snode_root.get()].llvm_type, 0));
get_root(), PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), prog->snode_root.get()),
0));
}

void CodeGenLLVM::visit(OffsetAndExtractBitsStmt *stmt) {
Expand Down Expand Up @@ -1184,7 +1182,9 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
{builder->CreateBitCast(llvm_val[stmt->input_ptr],
PointerType::getInt8PtrTy(*llvm_context))});
llvm_val[stmt] = builder->CreateBitCast(
ch, PointerType::get(snode_attr[stmt->output_snode].llvm_type, 0));
ch, PointerType::get(StructCompilerLLVM::get_llvm_node_type(
module.get(), stmt->output_snode),
0));
}

void CodeGenLLVM::visit(ExternalPtrStmt *stmt) {
Expand Down
5 changes: 2 additions & 3 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class FunctionCreationGuard {
~FunctionCreationGuard();
};

class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
public:
static uint64 task_counter;

Expand All @@ -66,15 +66,14 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
llvm::BasicBlock *current_while_after_loop;
llvm::FunctionType *task_function_type;
OffloadedStmt *current_offloaded_stmt;
SNodeAttributes &snode_attr;
std::unordered_map<Stmt *, llvm::Value *> llvm_val;
llvm::Function *func;
std::unique_ptr<OffloadedTask> current_task;
std::vector<OffloadedTask> offloaded_tasks;
BasicBlock *func_body_bb;

using IRVisitor::visit;
using ModuleBuilder::call;
using LLVMModuleBuilder::call;

CodeGenLLVM(Kernel *kernel, IRNode *ir = nullptr);

Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_llvm_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {

auto format_str = "[debug] " + stmt->str + " = " + format + "\n";

llvm_val[stmt] = ModuleBuilder::call(
llvm_val[stmt] = LLVMModuleBuilder::call(
builder.get(), "vprintf",
builder->CreateGlobalStringPtr(format_str, "format_string"),
builder->CreateBitCast(values,
Expand Down
24 changes: 0 additions & 24 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,28 +256,4 @@ class SNode {
uint64 fetch_reader_result(); // TODO: refactor
};

class SNodeAttribute {
public:
llvm::Type *llvm_type, *llvm_body_type, *llvm_aux_type;
llvm::Type *llvm_element_type;
};

class SNodeAttributes {
private:
std::map<SNode *, SNodeAttribute> snode_llvm_attr;

public:
SNodeAttribute &operator[](SNode &snode) {
return snode_llvm_attr[&snode];
}

SNodeAttribute &operator[](SNode *snode) {
return snode_llvm_attr[snode];
}

SNodeAttribute &operator[](const std::unique_ptr<SNode> &snode) {
return snode_llvm_attr[snode.get()];
}
};

TLANG_NAMESPACE_END
8 changes: 4 additions & 4 deletions taichi/llvm/llvm_codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ inline bool check_func_call_signature(llvm::Value *func, Args &&... args) {
return check_func_call_signature(func, {args...});
}

class ModuleBuilder {
class LLVMModuleBuilder {
public:
std::unique_ptr<llvm::Module> module;
llvm::BasicBlock *entry_block;
std::unique_ptr<llvm::IRBuilder<>> builder;
TaichiLLVMContext *tlctx;
llvm::LLVMContext *llvm_context;

ModuleBuilder(std::unique_ptr<llvm::Module> &&module)
LLVMModuleBuilder(std::unique_ptr<llvm::Module> &&module)
: module(std::move(module)) {
}

Expand Down Expand Up @@ -131,12 +131,12 @@ class RuntimeObject {
public:
std::string cls_name;
llvm::Value *ptr;
ModuleBuilder *mb;
LLVMModuleBuilder *mb;
llvm::Type *type;
llvm::IRBuilder<> *builder;

RuntimeObject(const std::string &cls_name,
ModuleBuilder *mb,
LLVMModuleBuilder *mb,
llvm::IRBuilder<> *builder,
llvm::Value *init = nullptr)
: cls_name(cls_name), mb(mb), builder(builder) {
Expand Down
2 changes: 0 additions & 2 deletions taichi/llvm/llvm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ class TaichiLLVMContext {
JITModule *runtime_jit_module;
Arch arch;

SNodeAttributes snode_attr;

TaichiLLVMContext(Arch arch);

std::unique_ptr<llvm::Module> get_init_module();
Expand Down
17 changes: 9 additions & 8 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "taichi/codegen/codegen_opengl.h"
#include "taichi/codegen/codegen_cpu.h"
#include "taichi/struct/struct.h"
#include "taichi/struct/struct_llvm.h"
#include "taichi/struct/struct_metal.h"
#include "taichi/struct/struct_opengl.h"
#include "taichi/system/unified_allocator.h"
Expand Down Expand Up @@ -213,15 +214,15 @@ void Program::initialize_runtime_system(StructCompiler *scomp) {
for (int i = 0; i < (int)snodes.size(); i++) {
if (is_gc_able(snodes[i]->type)) {
std::size_t node_size;
if (snodes[i]->type == SNodeType::pointer)
node_size = tlctx->get_type_size(
scomp->snode_attr[snodes[i]].llvm_element_type);
else {
auto element_size =
tlctx->get_type_size(StructCompilerLLVM::get_llvm_element_type(
tlctx->struct_module.get(), snodes[i]));
if (snodes[i]->type == SNodeType::pointer) {
// pointer. Allocators are for single elements
node_size = element_size;
} else {
// dynamic. Allocators are for the chunks
node_size = sizeof(void *) +
tlctx->get_type_size(
scomp->snode_attr[snodes[i]].llvm_element_type) *
snodes[i]->chunk_size;
node_size = sizeof(void *) + element_size * snodes[i]->chunk_size;
}
TI_TRACE("Initializing allocator for snode {} (node size {})",
snodes[i]->id, node_size);
Expand Down
2 changes: 0 additions & 2 deletions taichi/struct/struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ class StructCompiler {
std::size_t root_size;
Program *prog;

SNodeAttributes snode_attr;

explicit StructCompiler(Program *prog);

virtual ~StructCompiler() = default;
Expand Down
72 changes: 56 additions & 16 deletions taichi/struct/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using namespace llvm;

StructCompilerLLVM::StructCompilerLLVM(Program *prog, Arch arch)
: StructCompiler(prog),
ModuleBuilder(prog->get_llvm_context(arch)->get_init_module()),
LLVMModuleBuilder(prog->get_llvm_context(arch)->get_init_module()),
arch(arch) {
tlctx = prog->get_llvm_context(arch);
llvm_ctx = tlctx->ctx.get();
Expand All @@ -22,23 +22,20 @@ StructCompilerLLVM::StructCompilerLLVM(Program *prog, Arch arch)
void StructCompilerLLVM::generate_types(SNode &snode) {
TI_AUTO_PROF;
auto type = snode.type;
llvm::Type *llvm_type = nullptr;
llvm::Type *node_type = nullptr;

auto ctx = llvm_ctx;

// create children type that supports forking...

std::vector<llvm::Type *> ch_types;
for (int i = 0; i < snode.ch.size(); i++) {
auto ch = snode_attr[snode.ch[i]].llvm_type;
auto ch = get_llvm_node_type(module.get(), snode.ch[i].get());
ch_types.push_back(ch);
}

auto ch_type =
llvm::StructType::create(*ctx, ch_types, snode.node_type_name + "_ch");
ch_type->setName(snode.node_type_name + "_ch");

snode_attr[snode].llvm_element_type = ch_type;

llvm::Type *body_type = nullptr, *aux_type = nullptr;
if (type == SNodeType::dense || type == SNodeType::bitmasked) {
Expand Down Expand Up @@ -69,15 +66,21 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
TI_NOT_IMPLEMENTED;
}
if (aux_type != nullptr) {
llvm_type = llvm::StructType::create(*ctx, {aux_type, body_type}, "");
node_type = llvm::StructType::create(*ctx, {aux_type, body_type}, "");
} else {
llvm_type = body_type;
node_type = body_type;
}

TI_ASSERT(llvm_type != nullptr);
snode_attr[snode].llvm_type = llvm_type;
snode_attr[snode].llvm_aux_type = aux_type;
snode_attr[snode].llvm_body_type = body_type;
TI_ASSERT(node_type != nullptr);
TI_ASSERT(body_type != nullptr);

// Here we create a stub holding 4 LLVM types as struct members.
// The aim is to give a **unique** name to the stub, so that we can look up
// these types using this name. This decouples them from the LLVM context.
// Note that body_type might not have a unique name, since literal structs
// (such as {i32, i32}) are uniqued in LLVM.
llvm::StructType::create(*ctx, {node_type, body_type, aux_type, ch_type},
type_stub_name(&snode));
}

void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) {
Expand Down Expand Up @@ -140,7 +143,7 @@ void StructCompilerLLVM::generate_child_accessors(SNode &snode) {
auto parent = snode.parent;

auto inp_type =
llvm::PointerType::get(snode_attr[parent].llvm_element_type, 0);
llvm::PointerType::get(get_llvm_element_type(module.get(), parent), 0);

auto ft =
llvm::FunctionType::get(llvm::Type::getInt8PtrTy(*llvm_ctx),
Expand Down Expand Up @@ -174,6 +177,10 @@ void StructCompilerLLVM::generate_child_accessors(SNode &snode) {
stack.pop_back();
}

std::string StructCompilerLLVM::type_stub_name(SNode *snode) {
return snode->node_type_name + "_type_stubs";
}

void StructCompilerLLVM::run(SNode &root, bool host) {
TI_AUTO_PROF;
// bottom to top
Expand All @@ -197,11 +204,44 @@ void StructCompilerLLVM::run(SNode &root, bool host) {

TI_ASSERT((int)snodes.size() <= taichi_max_num_snodes);

root_size =
tlctx->get_data_layout().getTypeAllocSize(snode_attr[root].llvm_type);
auto node_type = get_llvm_node_type(module.get(), &root);
root_size = tlctx->get_data_layout().getTypeAllocSize(node_type);

tlctx->set_struct_module(module);
tlctx->snode_attr = snode_attr;
}

llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module,
SNode *snode,
uint32 index) {
TI_ASSERT(module);
TI_ASSERT(snode);
auto stub = module->getTypeByName(type_stub_name(snode));
TI_ASSERT(stub);
TI_ASSERT(stub->getStructNumElements() == 4);
TI_ASSERT(0 <= index && index < 4);
auto type = stub->getContainedType(index);
TI_ASSERT(type);
return type;
}

llvm::Type *StructCompilerLLVM::get_llvm_node_type(llvm::Module *module,
SNode *snode) {
return get_stub(module, snode, 0);
}

llvm::Type *StructCompilerLLVM::get_llvm_body_type(llvm::Module *module,
SNode *snode) {
return get_stub(module, snode, 1);
}

llvm::Type *StructCompilerLLVM::get_llvm_aux_type(llvm::Module *module,
SNode *snode) {
return get_stub(module, snode, 2);
}

llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module,
SNode *snode) {
return get_stub(module, snode, 3);
}

std::unique_ptr<StructCompiler> StructCompiler::make(Program *prog, Arch arch) {
Expand Down
16 changes: 14 additions & 2 deletions taichi/struct/struct_llvm.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Codegen for the hierarchical data structure (LLVM)

#include "struct.h"
#include "taichi/struct/struct.h"
#include "taichi/llvm/llvm_codegen_utils.h"

TLANG_NAMESPACE_BEGIN

class StructCompilerLLVM : public StructCompiler, public ModuleBuilder {
class StructCompilerLLVM : public StructCompiler, public LLVMModuleBuilder {
public:
StructCompilerLLVM(Program *prog, Arch arch);

Expand All @@ -20,6 +20,18 @@ class StructCompilerLLVM : public StructCompiler, public ModuleBuilder {
void run(SNode &node, bool host) override;

void generate_refine_coordinates(SNode *snode);

static std::string type_stub_name(SNode *snode);

static llvm::Type *get_stub(llvm::Module *module, SNode *snode, uint32 index);

static llvm::Type *get_llvm_node_type(llvm::Module *module, SNode *snode);

static llvm::Type *get_llvm_body_type(llvm::Module *module, SNode *snode);

static llvm::Type *get_llvm_aux_type(llvm::Module *module, SNode *snode);

static llvm::Type *get_llvm_element_type(llvm::Module *module, SNode *snode);
};

TLANG_NAMESPACE_END