Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] [refactor] Add compute_type for CustomIntType #2047

Merged
merged 32 commits into from
Nov 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
41b9965
rebase code from upstream
Hanke98 Nov 11, 2020
0b8fd8b
add llvm_ptr_type and data type bits
Hanke98 Nov 11, 2020
b1b5e56
add some comments
Hanke98 Nov 13, 2020
100ea35
[skip ci] enforce code format
taichi-gardener Nov 13, 2020
3399981
use type factory
Hanke98 Nov 13, 2020
e98f450
[skip ci] enforce code format
taichi-gardener Nov 13, 2020
9ff8476
update runtime
Hanke98 Nov 14, 2020
04cb2f3
[skip ci] enforce code format
taichi-gardener Nov 14, 2020
15ffdfa
modify bit_ptr struct and rebase
Hanke98 Nov 15, 2020
a44bd21
[skip ci] enforce code format
taichi-gardener Nov 15, 2020
6d82d39
fix bit_array
Hanke98 Nov 16, 2020
fab966e
use runtime to do global loading
Hanke98 Nov 16, 2020
7bfbcb5
[skip ci] enforce code format
taichi-gardener Nov 16, 2020
1f9e511
add tests
Hanke98 Nov 16, 2020
9b109f5
[skip ci] enforce code format
taichi-gardener Nov 16, 2020
10cd722
add more tetst cases and rebase
Hanke98 Nov 17, 2020
45dabf8
[skip ci] enforce code format
taichi-gardener Nov 17, 2020
9d99341
fix bit struct
Hanke98 Nov 18, 2020
f045ec1
set physical type for bit_array
Hanke98 Nov 19, 2020
e344ca0
[skip ci] enforce code format
taichi-gardener Nov 19, 2020
a20a0e8
modify test cases
Hanke98 Nov 19, 2020
18a6738
[skip ci] enforce code format
taichi-gardener Nov 19, 2020
b8c4d49
Apply suggestions from code review
Hanke98 Nov 19, 2020
a7acc57
fix type factory
Hanke98 Nov 19, 2020
704bbe7
change funtion name
Hanke98 Nov 19, 2020
5022a10
[skip ci] enforce code format
taichi-gardener Nov 19, 2020
c1bf0e6
Apply suggestions from code review
Hanke98 Nov 20, 2020
cc8878b
[skip ci] enforce code format
taichi-gardener Nov 20, 2020
a2d4196
modify APIs
Hanke98 Nov 20, 2020
c9993cf
[skip ci] enforce code format
taichi-gardener Nov 20, 2020
29f68ad
update
Hanke98 Nov 20, 2020
985f35e
remove todo
Hanke98 Nov 20, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
TI_ASSERT(!to->is<CustomIntType>());
auto from_size = 0;
if (from->is<CustomIntType>()) {
// TODO: replace 32 with a customizable type
from_size = 32;
from_size =
data_type_size(from->cast<CustomIntType>()->get_compute_type());
} else {
from_size = data_type_size(from);
}
Expand Down Expand Up @@ -618,6 +618,10 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) {
return nullptr;
}

llvm::Type *CodeGenLLVM::llvm_ptr_type(DataType dt) {
return llvm::PointerType::get(llvm_type(dt), 0);
}

void CodeGenLLVM::visit(TernaryOpStmt *stmt) {
TI_ASSERT(stmt->op_type == TernaryOpType::select);
llvm_val[stmt] = builder->CreateSelect(
Expand Down Expand Up @@ -907,8 +911,8 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits = 0;
if (stmt->value->ret_type->is<CustomIntType>()) {
intermediate_bits = 32;
if (auto cit = stmt->value->ret_type->cast<CustomIntType>()) {
intermediate_bits = data_type_bits(cit->get_compute_type());
} else {
intermediate_bits =
tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits();
Expand Down Expand Up @@ -1105,12 +1109,15 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
auto cit = ptr_type->get_pointee_type()->as<CustomIntType>();
llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);
auto func_name = fmt::format("set_partial_bits_b{}",
data_type_bits(cit->get_physical_type()));
builder->CreateCall(
get_runtime_function("set_partial_bits_b32"),
get_runtime_function(func_name),
{builder->CreateBitCast(byte_ptr,
llvm::Type::getInt32PtrTy(*llvm_context)),
llvm_ptr_type(cit->get_physical_type())),
bit_offset, tlctx->get_constant(cit->get_num_bits()),
llvm_val[stmt->data]});
builder->CreateIntCast(llvm_val[stmt->data],
llvm_type(cit->get_physical_type()), false)});
} else {
builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]);
}
Expand All @@ -1124,23 +1131,31 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
// 1. load bit pointer
llvm::Value *byte_ptr, *bit_offset;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);

auto bit_level_container = builder->CreateLoad(builder->CreateBitCast(
byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context)));
byte_ptr, llvm_ptr_type(cit->get_physical_type())));
// 2. bit shifting
// first left shift `32 - (offset + num_bits)`
// then right shift `32 - num_bits`
// first left shift `physical_type - (offset + num_bits)`
// then right shift `physical_type - num_bits`
auto bit_end = builder->CreateAdd(bit_offset,
tlctx->get_constant(cit->get_num_bits()));
auto left = builder->CreateSub(tlctx->get_constant(32), bit_end);
auto right = builder->CreateAdd(tlctx->get_constant(32),
tlctx->get_constant(-cit->get_num_bits()));
auto left = builder->CreateSub(
tlctx->get_constant(data_type_bits(cit->get_physical_type())), bit_end);
auto right = builder->CreateSub(
tlctx->get_constant(data_type_bits(cit->get_physical_type())),
tlctx->get_constant(cit->get_num_bits()));
left = builder->CreateIntCast(left, bit_level_container->getType(), false);
right =
builder->CreateIntCast(right, bit_level_container->getType(), false);
auto step1 = builder->CreateShl(bit_level_container, left);
llvm::Value *step2 = nullptr;
if (cit->get_is_signed())
step2 = builder->CreateAShr(step1, right);
else
step2 = builder->CreateLShr(step1, right);
llvm_val[stmt] = step2;

llvm_val[stmt] = builder->CreateIntCast(
step2, llvm_type(cit->get_compute_type()), cit->get_is_signed());
} else {
llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type),
llvm_val[stmt->ptr]);
Expand Down Expand Up @@ -1244,6 +1259,7 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base,
// };
auto struct_type = llvm::StructType::get(
*llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context)});
// 2. alloca the bit pointer struct
auto bit_ptr_struct = create_entry_block_alloca(struct_type);
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Type *llvm_type(DataType dt);

llvm::Type *llvm_ptr_type(DataType dt);

void visit(Block *stmt_list) override;

void visit(AllocaStmt *stmt) override;
Expand Down
30 changes: 30 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,36 @@ std::string CustomIntType::to_string() const {
return fmt::format("c{}{}", is_signed_ ? 'i' : 'u', num_bits_);
}

CustomIntType::CustomIntType(int num_bits,
bool is_signed,
Type *compute_type,
Type *physical_type)
: compute_type(compute_type),
physical_type(physical_type),
num_bits_(num_bits),
is_signed_(is_signed) {
if (compute_type == nullptr) {
auto type_id = is_signed ? PrimitiveTypeID::i32 : PrimitiveTypeID::u32;
this->compute_type =
TypeFactory::get_instance().get_primitive_type(type_id);
}
}
Hanke98 marked this conversation as resolved.
Show resolved Hide resolved

BitStructType::BitStructType(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets)
: physical_type_(physical_type),
member_types_(member_types),
member_bit_offsets_(member_bit_offsets) {
TI_ASSERT(member_types_.size() == member_bit_offsets_.size());
int physical_type_bits = data_type_bits(physical_type);
for (auto i = 0; i < member_types_.size(); ++i) {
auto bits_end = member_types_[i]->as<CustomIntType>()->get_num_bits() +
member_bit_offsets_[i];
TI_ASSERT(physical_type_bits >= bits_end)
}
}

std::string BitStructType::to_string() const {
std::string str = "bs(";
int num_members = (int)member_bit_offsets_.size();
Expand Down
34 changes: 24 additions & 10 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,29 @@ class VectorType : public Type {

class CustomIntType : public Type {
public:
CustomIntType(int num_bits, bool is_signed)
: num_bits_(num_bits), is_signed_(is_signed) {
CustomIntType(int num_bits,
bool is_signed,
Type *compute_type = nullptr,
Type *physical_type = nullptr);

~CustomIntType() override {
delete compute_type;
}

std::string to_string() const override;

void set_physical_type(Type *physical_type_) {
this->physical_type = physical_type_;
}

Type *get_physical_type() {
return physical_type;
}

Type *get_compute_type() {
return compute_type;
}

int get_num_bits() const {
return num_bits_;
}
Expand All @@ -183,20 +200,17 @@ class CustomIntType : public Type {
private:
// TODO(type): for now we can uniformly use i32 as the "compute_type". It may
// be a good idea to make "compute_type" also customizable.
int num_bits_;
bool is_signed_;
Type *compute_type{nullptr};
Type *physical_type{nullptr};
int num_bits_{32};
bool is_signed_{true};
};

class BitStructType : public Type {
public:
BitStructType(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets)
: physical_type_(physical_type),
member_types_(member_types),
member_bit_offsets_(member_bit_offsets) {
TI_ASSERT(member_types_.size() == member_bit_offsets_.size());
}
std::vector<int> member_bit_offsets);

std::string to_string() const override;

Expand Down
17 changes: 11 additions & 6 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,20 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
return pointer_types_[key].get();
}

Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed) {
auto key = std::make_pair(num_bits, is_signed);
if (custom_int_types_.find(key) == custom_int_types_.end()) {
custom_int_types_[key] =
std::make_unique<CustomIntType>(num_bits, is_signed);
Type *TypeFactory::get_custom_int_type(int num_bits,
bool is_signed,
int compute_type_bits) {
auto key = std::make_tuple(compute_type_bits, num_bits, is_signed);
if (custom_int_types.find(key) == custom_int_types.end()) {
custom_int_types[key] = std::make_unique<CustomIntType>(
num_bits, is_signed,
get_primitive_int_type(compute_type_bits, is_signed));
}
return custom_int_types_[key].get();
return custom_int_types[key].get();
}

#undef SET_COMPUTE_TYPE

Type *TypeFactory::get_bit_struct_type(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets) {
Expand Down
6 changes: 4 additions & 2 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class TypeFactory {

Type *get_pointer_type(Type *element, bool is_bit_pointer = false);

Type *get_custom_int_type(int num_bits, bool is_signed);
Type *get_custom_int_type(int num_bits,
bool is_signed,
int compute_type_bits = 32);

Type *get_bit_struct_type(PrimitiveType *physical_type,
std::vector<Type *> member_types,
Expand All @@ -47,7 +49,7 @@ class TypeFactory {
std::map<std::pair<Type *, bool>, std::unique_ptr<Type>> pointer_types_;

// TODO: use unordered map
std::map<std::pair<int, bool>, std::unique_ptr<Type>> custom_int_types_;
std::map<std::tuple<int, int, bool>, std::unique_ptr<Type>> custom_int_types;

// TODO: avoid duplication
std::vector<std::unique_ptr<Type>> bit_struct_types_;
Expand Down
4 changes: 4 additions & 0 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ int data_type_size(DataType t) {
}
}

int data_type_bits(DataType t) {
return data_type_size(t) * 8;
}

std::string data_type_short_name(DataType t) {
if (!t->is<PrimitiveType>()) {
return t->to_string();
Expand Down
1 change: 1 addition & 0 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ std::string make_list(const std::vector<T> &data,
}

int data_type_size(DataType t);
int data_type_bits(DataType t);
DataType promoted_type(DataType a, DataType b);

extern std::string compiled_lib_dir;
Expand Down
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ void export_lang(py::module &m) {
// TypeFactory on Python-scope pointer destruction.
py::class_<TypeFactory>(m, "TypeFactory")
.def("get_custom_int_type", &TypeFactory::get_custom_int_type,
py::arg("num_bits"), py::arg("is_signed"),
py::arg("compute_type_bits") = 32,
py::return_value_policy::reference);

m.def("get_type_factory_instance", TypeFactory::get_instance,
Expand Down
31 changes: 20 additions & 11 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ using float32 = float;
using float64 = double;

using i8 = int8;
using i16 = int16;
using i32 = int32;
using i64 = int64;
using u8 = uint8;
using u16 = uint16;
using u32 = uint32;
using u64 = uint64;
using f32 = float32;
Expand Down Expand Up @@ -1551,17 +1553,24 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {

#include "internal_functions.h"

void set_partial_bits_b32(u32 *ptr, u32 offset, u32 bits, u32 value) {
u32 mask = ((((u32)1 << bits) - 1) << offset);
u32 new_value = 0;
u32 old_value = *ptr;
do {
old_value = *ptr;
new_value = (old_value & (~mask)) | (value << offset);
} while (!__atomic_compare_exchange(ptr, &old_value, &new_value, true,
std::memory_order::memory_order_seq_cst,
std::memory_order::memory_order_seq_cst));
}
#define DEFINE_SET_PARTIAL_BITS(N) \
void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \
u##N mask = ((((u##N)1 << bits) - 1) << offset); \
u##N new_value = 0; \
u##N old_value = *ptr; \
do { \
old_value = *ptr; \
new_value = (old_value & (~mask)) | (value << offset); \
} while ( \
!__atomic_compare_exchange(ptr, &old_value, &new_value, true, \
std::memory_order::memory_order_seq_cst, \
std::memory_order::memory_order_seq_cst)); \
}

DEFINE_SET_PARTIAL_BITS(8);
DEFINE_SET_PARTIAL_BITS(16);
DEFINE_SET_PARTIAL_BITS(32);
DEFINE_SET_PARTIAL_BITS(64);
}

#endif
2 changes: 2 additions & 0 deletions taichi/struct/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
ch_types.push_back(ch->dt.get_ptr());
ch_offsets.push_back(total_offset);
total_offset += ch->dt->as<CustomIntType>()->get_num_bits();
ch->dt->as<CustomIntType>()->set_physical_type(snode.physical_type);
}

snode.dt = TypeFactory::get_instance().get_bit_struct_type(
Expand All @@ -80,6 +81,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
TI_ASSERT(snode.ch.size() == 1);
auto &ch = snode.ch[0];
Type *ch_type = ch->dt.get_ptr();
ch->dt->as<CustomIntType>()->set_physical_type(snode.physical_type);
snode.dt = TypeFactory::get_instance().get_bit_array_type(
snode.physical_type, ch_type, snode.n);

Expand Down
7 changes: 2 additions & 5 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,8 @@ class TypeCheck : public IRVisitor {

if (stmt->lhs->ret_type != stmt->rhs->ret_type) {
auto promote_custom_int_type = [&](Stmt *stmt, Stmt *hs) {
if (hs->ret_type->is<CustomIntType>()) {
if (hs->ret_type->cast<CustomIntType>()->get_is_signed())
return insert_type_cast_before(stmt, hs, get_data_type<int32>());
else
return insert_type_cast_before(stmt, hs, get_data_type<uint32>());
if (auto cit = hs->ret_type->cast<CustomIntType>()) {
return insert_type_cast_before(stmt, hs, cit->get_compute_type());
}
return hs;
};
Expand Down
Loading