Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] [refactor] Add compute_type for CustomIntType #2047

Merged
merged 32 commits into from
Nov 20, 2020
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
41b9965
rebase code from upstream
Hanke98 Nov 11, 2020
0b8fd8b
add llvm_ptr_type and data type bits
Hanke98 Nov 11, 2020
b1b5e56
add some comments
Hanke98 Nov 13, 2020
100ea35
[skip ci] enforce code format
taichi-gardener Nov 13, 2020
3399981
use type factory
Hanke98 Nov 13, 2020
e98f450
[skip ci] enforce code format
taichi-gardener Nov 13, 2020
9ff8476
update runtime
Hanke98 Nov 14, 2020
04cb2f3
[skip ci] enforce code format
taichi-gardener Nov 14, 2020
15ffdfa
modify bit_ptr struct and rebase
Hanke98 Nov 15, 2020
a44bd21
[skip ci] enforce code format
taichi-gardener Nov 15, 2020
6d82d39
fix bit_array
Hanke98 Nov 16, 2020
fab966e
use runtime to do global loading
Hanke98 Nov 16, 2020
7bfbcb5
[skip ci] enforce code format
taichi-gardener Nov 16, 2020
1f9e511
add tests
Hanke98 Nov 16, 2020
9b109f5
[skip ci] enforce code format
taichi-gardener Nov 16, 2020
10cd722
add more tetst cases and rebase
Hanke98 Nov 17, 2020
45dabf8
[skip ci] enforce code format
taichi-gardener Nov 17, 2020
9d99341
fix bit struct
Hanke98 Nov 18, 2020
f045ec1
set physical type for bit_array
Hanke98 Nov 19, 2020
e344ca0
[skip ci] enforce code format
taichi-gardener Nov 19, 2020
a20a0e8
modify test cases
Hanke98 Nov 19, 2020
18a6738
[skip ci] enforce code format
taichi-gardener Nov 19, 2020
b8c4d49
Apply suggestions from code review
Hanke98 Nov 19, 2020
a7acc57
fix type factory
Hanke98 Nov 19, 2020
704bbe7
change funtion name
Hanke98 Nov 19, 2020
5022a10
[skip ci] enforce code format
taichi-gardener Nov 19, 2020
c1bf0e6
Apply suggestions from code review
Hanke98 Nov 20, 2020
cc8878b
[skip ci] enforce code format
taichi-gardener Nov 20, 2020
a2d4196
modify APIs
Hanke98 Nov 20, 2020
c9993cf
[skip ci] enforce code format
taichi-gardener Nov 20, 2020
29f68ad
update
Hanke98 Nov 20, 2020
985f35e
remove todo
Hanke98 Nov 20, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 61 additions & 29 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
auto from_size = 0;
if (from->is<CustomIntType>()) {
// TODO: replace 32 with a customizable type
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove TODO here

from_size = 32;
from_size =
data_type_size(from->cast<CustomIntType>()->get_compute_type());
} else {
from_size = data_type_size(from);
}
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -618,6 +619,31 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) {
return nullptr;
}

llvm::Type *CodeGenLLVM::llvm_ptr_type(DataType dt) {
if (dt->is_primitive(PrimitiveTypeID::i8) ||
dt->is_primitive(PrimitiveTypeID::u8)) {
return llvm::Type::getInt8PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::u16)) {
return llvm::Type::getInt16PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::u32)) {
return llvm::Type::getInt32PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u64)) {
return llvm::Type::getInt64PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return llvm::Type::getInt1PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f32)) {
return llvm::Type::getFloatPtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return llvm::Type::getDoublePtrTy(*llvm_context);
} else {
TI_NOT_IMPLEMENTED;
}
return nullptr;
}
Copy link
Member

@yuanming-hu yuanming-hu Nov 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::Type *CodeGenLLVM::llvm_ptr_type(DataType dt) {
if (dt->is_primitive(PrimitiveTypeID::i8) ||
dt->is_primitive(PrimitiveTypeID::u8)) {
return llvm::Type::getInt8PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::u16)) {
return llvm::Type::getInt16PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::u32)) {
return llvm::Type::getInt32PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u64)) {
return llvm::Type::getInt64PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return llvm::Type::getInt1PtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f32)) {
return llvm::Type::getFloatPtrTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return llvm::Type::getDoublePtrTy(*llvm_context);
} else {
TI_NOT_IMPLEMENTED;
}
return nullptr;
}
llvm::Type *CodeGenLLVM::llvm_ptr_type(DataType dt) {
return llvm::PointerType::get(llvm_type(dt), 0);
}


void CodeGenLLVM::visit(TernaryOpStmt *stmt) {
TI_ASSERT(stmt->op_type == TernaryOpType::select);
llvm_val[stmt] = builder->CreateSelect(
Expand Down Expand Up @@ -907,8 +933,8 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits = 0;
if (stmt->value->ret_type->is<CustomIntType>()) {
intermediate_bits = 32;
if (auto cit = stmt->value->ret_type->cast<CustomIntType>()) {
intermediate_bits = data_type_bits(cit->get_compute_type());
} else {
intermediate_bits =
tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits();
Expand Down Expand Up @@ -1103,14 +1129,16 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
auto ptr_type = stmt->ptr->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
auto cit = ptr_type->get_pointee_type()->as<CustomIntType>();
llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);
llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr,
*physical_type_size = nullptr;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset,
physical_type_size);
builder->CreateCall(
get_runtime_function("set_partial_bits_b32"),
get_runtime_function("set_partial_bits"),
{builder->CreateBitCast(byte_ptr,
llvm::Type::getInt32PtrTy(*llvm_context)),
bit_offset, tlctx->get_constant(cit->get_num_bits()),
llvm_val[stmt->data]});
llvm_val[stmt->data], physical_type_size});
} else {
builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]);
}
Expand All @@ -1122,25 +1150,16 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
if (stmt->ptr->ret_type->as<PointerType>()->is_bit_pointer()) {
auto cit = stmt->ret_type->as<CustomIntType>();
// 1. load bit pointer
llvm::Value *byte_ptr, *bit_offset;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);
auto bit_level_container = builder->CreateLoad(builder->CreateBitCast(
byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context)));
// 2. bit shifting
// first left shift `32 - (offset + num_bits)`
// then right shift `32 - num_bits`
auto bit_end = builder->CreateAdd(bit_offset,
tlctx->get_constant(cit->get_num_bits()));
auto left = builder->CreateSub(tlctx->get_constant(32), bit_end);
auto right = builder->CreateAdd(tlctx->get_constant(32),
tlctx->get_constant(-cit->get_num_bits()));
auto step1 = builder->CreateShl(bit_level_container, left);
llvm::Value *step2 = nullptr;
if (cit->get_is_signed())
step2 = builder->CreateAShr(step1, right);
else
step2 = builder->CreateLShr(step1, right);
llvm_val[stmt] = step2;
llvm::Value *byte_ptr, *bit_offset, *physical_type_size;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset,
physical_type_size);
auto tmp = builder->CreateCall(
get_runtime_function("load_partial_bits"),
{byte_ptr, bit_offset, tlctx->get_constant(cit->get_num_bits()),
physical_type_size,
tlctx->get_constant((uint32)cit->get_is_signed())});
llvm_val[stmt] = builder->CreateIntCast(
tmp, llvm_type(cit->get_compute_type()), cit->get_is_signed());
} else {
llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type),
llvm_val[stmt->ptr]);
Expand Down Expand Up @@ -1236,14 +1255,17 @@ void CodeGenLLVM::visit(LinearizeStmt *stmt) {
void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED}

llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base,
llvm::Value *bit_offset) {
llvm::Value *bit_offset,
int num_bits) {
// 1. create a bit pointer struct
// struct bit_pointer {
// i8* byte_ptr;
// i32 offset;
// i32 physical_type_size;
// };
auto struct_type = llvm::StructType::get(
*llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context)});
// 2. alloca the bit pointer struct
auto bit_ptr_struct = create_entry_block_alloca(struct_type);
Expand All @@ -1257,6 +1279,11 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base,
builder->CreateStore(
bit_offset, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0),
tlctx->get_constant(1)}));
// 5. store `physical_type` in `bit_ptr_struct`
builder->CreateStore(
tlctx->get_constant(num_bits),
builder->CreateGEP(bit_ptr_struct,
{tlctx->get_constant(0), tlctx->get_constant(2)}));
return bit_ptr_struct;
}

Expand Down Expand Up @@ -1284,7 +1311,10 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
snode->dt.get_ptr()->as<BitArrayType>()->get_element_num_bits();
auto offset = tlctx->get_constant(element_num_bits);
offset = builder->CreateMul(offset, llvm_val[stmt->input_index]);
llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_snode], offset);
llvm_val[stmt] = create_bit_ptr_struct(
llvm_val[stmt->input_snode], offset,
data_type_bits(
snode->dt.get_ptr()->as<BitArrayType>()->get_physical_type()));
} else {
TI_INFO(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED
Expand All @@ -1299,7 +1329,9 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
auto bit_offset = bit_struct->get_member_bit_offset(
stmt->input_snode->child_id(stmt->output_snode));
auto offset = tlctx->get_constant(bit_offset);
llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset);
llvm_val[stmt] =
create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset,
data_type_bits(bit_struct->get_physical_type()));
} else {
auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(),
{builder->CreateBitCast(
Expand Down
5 changes: 4 additions & 1 deletion taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Type *llvm_type(DataType dt);

llvm::Type *llvm_ptr_type(DataType dt);

void visit(Block *stmt_list) override;

void visit(AllocaStmt *stmt) override;
Expand Down Expand Up @@ -207,7 +209,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
void visit(IntegerOffsetStmt *stmt) override;

llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base,
llvm::Value *bit_offset);
llvm::Value *bit_offset,
int num_bits = 32);

void visit(SNodeLookupStmt *stmt) override;

Expand Down
52 changes: 52 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,58 @@ std::string CustomIntType::to_string() const {
return fmt::format("c{}{}", is_signed_ ? 'i' : 'u', num_bits_);
}

CustomIntType::CustomIntType(int num_bits, bool is_signed)
: compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) {
// TODO(type): support customizable compute_type
// and should we expose it to users?
Hanke98 marked this conversation as resolved.
Show resolved Hide resolved
TI_ASSERT(num_bits <= 32);
if (is_signed) {
compute_type =
TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32);
} else {
compute_type =
TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u32);
}
}

#define SET_COMPUTE_TYPE(n, N) \
else if (n == N) { \
if (is_signed) \
type_id = PrimitiveTypeID::i##N; \
else \
type_id = PrimitiveTypeID::u##N; \
}

CustomIntType::CustomIntType(int compute_type_bits,
int num_bits,
bool is_signed)
: compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) {
auto type_id = PrimitiveTypeID::unknown;
if (false) {
}
SET_COMPUTE_TYPE(compute_type_bits, 64)
SET_COMPUTE_TYPE(compute_type_bits, 32)
SET_COMPUTE_TYPE(compute_type_bits, 16)
SET_COMPUTE_TYPE(compute_type_bits, 8)
else {TI_NOT_IMPLEMENTED} compute_type =
TypeFactory::get_instance().get_primitive_type(type_id);
}
Hanke98 marked this conversation as resolved.
Show resolved Hide resolved

BitStructType::BitStructType(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets)
: physical_type_(physical_type),
member_types_(member_types),
member_bit_offsets_(member_bit_offsets) {
TI_ASSERT(member_types_.size() == member_bit_offsets_.size());
int physical_type_bits = data_type_bits(physical_type);
for (auto i = 0; i < member_types_.size(); ++i) {
auto bits_end = member_types_[i]->as<CustomIntType>()->get_num_bits() +
member_bit_offsets_[i];
TI_ASSERT(physical_type_bits >= bits_end)
}
}

std::string BitStructType::to_string() const {
std::string str = "bs(";
int num_members = (int)member_bit_offsets_.size();
Expand Down
34 changes: 24 additions & 10 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,20 @@ class VectorType : public Type {

class CustomIntType : public Type {
public:
CustomIntType(int num_bits, bool is_signed)
: num_bits_(num_bits), is_signed_(is_signed) {
CustomIntType(int num_bits, bool is_signed);

CustomIntType(int compute_type_bits, int numBits, bool isSigned);

~CustomIntType() override {
delete compute_type;
}

std::string to_string() const override;

Type *get_compute_type() {
return compute_type;
}

int get_num_bits() const {
return num_bits_;
}
Expand All @@ -183,20 +191,26 @@ class CustomIntType : public Type {
private:
// TODO(type): for now we can uniformly use i32 as the "compute_type". It may
// be a good idea to make "compute_type" also customizable.
int num_bits_;
bool is_signed_;
Type *compute_type{nullptr};
int num_bits_{32};
bool is_signed_{true};
};

class BitStructType : public Type {
public:
BitStructType(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets)
: physical_type_(physical_type),
member_types_(member_types),
member_bit_offsets_(member_bit_offsets) {
TI_ASSERT(member_types_.size() == member_bit_offsets_.size());
}
std::vector<int> member_bit_offsets);
// : physical_type_(physical_type),
// member_types_(member_types),
// member_bit_offsets_(member_bit_offsets) {
// TI_ASSERT(member_types_.size() == member_bit_offsets_.size());
// int physical_type_bits = data_type_bits(physical_type);
// for (auto i = 0; i < member_types_.size(); ++i) {
// auto bits_end = member_types_[i]->as<CustomIntType>()->get_num_bits()
// + member_bit_offsets_[i]; TI_ASSERT(physical_type_bits >= bits_end)
// }
// }

std::string to_string() const override;

Expand Down
12 changes: 12 additions & 0 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed) {
return custom_int_types_[key].get();
}

Type *TypeFactory::_get_custom_int_type(int compute_type_bits,
int num_bits,
bool is_signed) {
auto key = std::make_pair(num_bits, is_signed);
if (custom_int_types_with_compute_types_.find(key) ==
custom_int_types_with_compute_types_.end()) {
custom_int_types_with_compute_types_[key] =
std::make_unique<CustomIntType>(compute_type_bits, num_bits, is_signed);
}
return custom_int_types_with_compute_types_[key].get();
}

Type *TypeFactory::get_bit_struct_type(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets) {
Expand Down
7 changes: 7 additions & 0 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class TypeFactory {

Type *get_custom_int_type(int num_bits, bool is_signed);

Type *_get_custom_int_type(int compute_type_bits,
int num_bits,
bool is_signed);

Type *get_bit_struct_type(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets);
Expand All @@ -49,6 +53,9 @@ class TypeFactory {
// TODO: use unordered map
std::map<std::pair<int, bool>, std::unique_ptr<Type>> custom_int_types_;

std::map<std::pair<int, bool>, std::unique_ptr<Type>>
custom_int_types_with_compute_types_;

// TODO: avoid duplication
std::vector<std::unique_ptr<Type>> bit_struct_types_;

Expand Down
4 changes: 4 additions & 0 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ int data_type_size(DataType t) {
}
}

int data_type_bits(DataType t) {
return data_type_size(t) * 8;
}

std::string data_type_short_name(DataType t) {
if (!t->is<PrimitiveType>()) {
return t->to_string();
Expand Down
1 change: 1 addition & 0 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ std::string make_list(const std::vector<T> &data,
}

int data_type_size(DataType t);
int data_type_bits(DataType t);
DataType promoted_type(DataType a, DataType b);

extern std::string compiled_lib_dir;
Expand Down
8 changes: 7 additions & 1 deletion taichi/llvm/llvm_codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ class LLVMModuleBuilder {

void read_bit_pointer(llvm::Value *ptr,
llvm::Value *&byte_ptr,
llvm::Value *&bit_offset) {
llvm::Value *&bit_offset,
llvm::Value *&physical_type) {
// 1. load byte pointer
auto byte_ptr_in_bit_struct = builder->CreateGEP(
ptr, {tlctx->get_constant(0), tlctx->get_constant(0)});
Expand All @@ -140,6 +141,11 @@ class LLVMModuleBuilder {
ptr, {tlctx->get_constant(0), tlctx->get_constant(1)});
bit_offset = builder->CreateLoad(bit_offset_in_bit_struct);
TI_ASSERT(bit_offset->getType()->isIntegerTy(32));

auto physical_type_bit_struct = builder->CreateGEP(
ptr, {tlctx->get_constant(0), tlctx->get_constant(2)});
physical_type = builder->CreateLoad(physical_type_bit_struct);
TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
}
};

Expand Down
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ void export_lang(py::module &m) {
// TypeFactory on Python-scope pointer destruction.
py::class_<TypeFactory>(m, "TypeFactory")
.def("get_custom_int_type", &TypeFactory::get_custom_int_type,
py::return_value_policy::reference)
.def("_get_custom_int_type", &TypeFactory::_get_custom_int_type,
py::return_value_policy::reference);

m.def("get_type_factory_instance", TypeFactory::get_instance,
Expand Down
Loading