Skip to content

Commit

Permalink
[type] [refactor] Add compute_type for CustomIntType (#2047)
Browse files Browse the repository at this point in the history
Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: Yuanming Hu <[email protected]>
  • Loading branch information
3 people authored Nov 20, 2020
1 parent 1f30800 commit 2b9b0a2
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 48 deletions.
44 changes: 30 additions & 14 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
TI_ASSERT(!to->is<CustomIntType>());
auto from_size = 0;
if (from->is<CustomIntType>()) {
// TODO: replace 32 with a customizable type
from_size = 32;
from_size =
data_type_size(from->cast<CustomIntType>()->get_compute_type());
} else {
from_size = data_type_size(from);
}
Expand Down Expand Up @@ -618,6 +618,10 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) {
return nullptr;
}

llvm::Type *CodeGenLLVM::llvm_ptr_type(DataType dt) {
return llvm::PointerType::get(llvm_type(dt), 0);
}

void CodeGenLLVM::visit(TernaryOpStmt *stmt) {
TI_ASSERT(stmt->op_type == TernaryOpType::select);
llvm_val[stmt] = builder->CreateSelect(
Expand Down Expand Up @@ -907,8 +911,8 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits = 0;
if (stmt->value->ret_type->is<CustomIntType>()) {
intermediate_bits = 32;
if (auto cit = stmt->value->ret_type->cast<CustomIntType>()) {
intermediate_bits = data_type_bits(cit->get_compute_type());
} else {
intermediate_bits =
tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits();
Expand Down Expand Up @@ -1105,12 +1109,15 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
auto cit = ptr_type->get_pointee_type()->as<CustomIntType>();
llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);
auto func_name = fmt::format("set_partial_bits_b{}",
data_type_bits(cit->get_physical_type()));
builder->CreateCall(
get_runtime_function("set_partial_bits_b32"),
get_runtime_function(func_name),
{builder->CreateBitCast(byte_ptr,
llvm::Type::getInt32PtrTy(*llvm_context)),
llvm_ptr_type(cit->get_physical_type())),
bit_offset, tlctx->get_constant(cit->get_num_bits()),
llvm_val[stmt->data]});
builder->CreateIntCast(llvm_val[stmt->data],
llvm_type(cit->get_physical_type()), false)});
} else {
builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]);
}
Expand All @@ -1124,23 +1131,31 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
// 1. load bit pointer
llvm::Value *byte_ptr, *bit_offset;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);

auto bit_level_container = builder->CreateLoad(builder->CreateBitCast(
byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context)));
byte_ptr, llvm_ptr_type(cit->get_physical_type())));
// 2. bit shifting
// first left shift `32 - (offset + num_bits)`
// then right shift `32 - num_bits`
// first left shift `physical_type - (offset + num_bits)`
// then right shift `physical_type - num_bits`
auto bit_end = builder->CreateAdd(bit_offset,
tlctx->get_constant(cit->get_num_bits()));
auto left = builder->CreateSub(tlctx->get_constant(32), bit_end);
auto right = builder->CreateAdd(tlctx->get_constant(32),
tlctx->get_constant(-cit->get_num_bits()));
auto left = builder->CreateSub(
tlctx->get_constant(data_type_bits(cit->get_physical_type())), bit_end);
auto right = builder->CreateSub(
tlctx->get_constant(data_type_bits(cit->get_physical_type())),
tlctx->get_constant(cit->get_num_bits()));
left = builder->CreateIntCast(left, bit_level_container->getType(), false);
right =
builder->CreateIntCast(right, bit_level_container->getType(), false);
auto step1 = builder->CreateShl(bit_level_container, left);
llvm::Value *step2 = nullptr;
if (cit->get_is_signed())
step2 = builder->CreateAShr(step1, right);
else
step2 = builder->CreateLShr(step1, right);
llvm_val[stmt] = step2;

llvm_val[stmt] = builder->CreateIntCast(
step2, llvm_type(cit->get_compute_type()), cit->get_is_signed());
} else {
llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type),
llvm_val[stmt->ptr]);
Expand Down Expand Up @@ -1244,6 +1259,7 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base,
// };
auto struct_type = llvm::StructType::get(
*llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context)});
// 2. alloca the bit pointer struct
auto bit_ptr_struct = create_entry_block_alloca(struct_type);
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Type *llvm_type(DataType dt);

llvm::Type *llvm_ptr_type(DataType dt);

void visit(Block *stmt_list) override;

void visit(AllocaStmt *stmt) override;
Expand Down
30 changes: 30 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,36 @@ std::string CustomIntType::to_string() const {
return fmt::format("c{}{}", is_signed_ ? 'i' : 'u', num_bits_);
}

CustomIntType::CustomIntType(int num_bits,
bool is_signed,
Type *compute_type,
Type *physical_type)
: compute_type(compute_type),
physical_type(physical_type),
num_bits_(num_bits),
is_signed_(is_signed) {
if (compute_type == nullptr) {
auto type_id = is_signed ? PrimitiveTypeID::i32 : PrimitiveTypeID::u32;
this->compute_type =
TypeFactory::get_instance().get_primitive_type(type_id);
}
}

BitStructType::BitStructType(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets)
: physical_type_(physical_type),
member_types_(member_types),
member_bit_offsets_(member_bit_offsets) {
TI_ASSERT(member_types_.size() == member_bit_offsets_.size());
int physical_type_bits = data_type_bits(physical_type);
for (auto i = 0; i < member_types_.size(); ++i) {
auto bits_end = member_types_[i]->as<CustomIntType>()->get_num_bits() +
member_bit_offsets_[i];
TI_ASSERT(physical_type_bits >= bits_end)
}
}

std::string BitStructType::to_string() const {
std::string str = "bs(";
int num_members = (int)member_bit_offsets_.size();
Expand Down
34 changes: 24 additions & 10 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,29 @@ class VectorType : public Type {

class CustomIntType : public Type {
public:
CustomIntType(int num_bits, bool is_signed)
: num_bits_(num_bits), is_signed_(is_signed) {
CustomIntType(int num_bits,
bool is_signed,
Type *compute_type = nullptr,
Type *physical_type = nullptr);

~CustomIntType() override {
delete compute_type;
}

std::string to_string() const override;

void set_physical_type(Type *physical_type_) {
this->physical_type = physical_type_;
}

Type *get_physical_type() {
return physical_type;
}

Type *get_compute_type() {
return compute_type;
}

int get_num_bits() const {
return num_bits_;
}
Expand All @@ -183,20 +200,17 @@ class CustomIntType : public Type {
private:
// TODO(type): for now we can uniformly use i32 as the "compute_type". It may
// be a good idea to make "compute_type" also customizable.
int num_bits_;
bool is_signed_;
Type *compute_type{nullptr};
Type *physical_type{nullptr};
int num_bits_{32};
bool is_signed_{true};
};

class BitStructType : public Type {
public:
BitStructType(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets)
: physical_type_(physical_type),
member_types_(member_types),
member_bit_offsets_(member_bit_offsets) {
TI_ASSERT(member_types_.size() == member_bit_offsets_.size());
}
std::vector<int> member_bit_offsets);

std::string to_string() const override;

Expand Down
17 changes: 11 additions & 6 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,20 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
return pointer_types_[key].get();
}

Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed) {
auto key = std::make_pair(num_bits, is_signed);
if (custom_int_types_.find(key) == custom_int_types_.end()) {
custom_int_types_[key] =
std::make_unique<CustomIntType>(num_bits, is_signed);
Type *TypeFactory::get_custom_int_type(int num_bits,
bool is_signed,
int compute_type_bits) {
auto key = std::make_tuple(compute_type_bits, num_bits, is_signed);
if (custom_int_types.find(key) == custom_int_types.end()) {
custom_int_types[key] = std::make_unique<CustomIntType>(
num_bits, is_signed,
get_primitive_int_type(compute_type_bits, is_signed));
}
return custom_int_types_[key].get();
return custom_int_types[key].get();
}

#undef SET_COMPUTE_TYPE

Type *TypeFactory::get_bit_struct_type(PrimitiveType *physical_type,
std::vector<Type *> member_types,
std::vector<int> member_bit_offsets) {
Expand Down
6 changes: 4 additions & 2 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class TypeFactory {

Type *get_pointer_type(Type *element, bool is_bit_pointer = false);

Type *get_custom_int_type(int num_bits, bool is_signed);
Type *get_custom_int_type(int num_bits,
bool is_signed,
int compute_type_bits = 32);

Type *get_bit_struct_type(PrimitiveType *physical_type,
std::vector<Type *> member_types,
Expand All @@ -47,7 +49,7 @@ class TypeFactory {
std::map<std::pair<Type *, bool>, std::unique_ptr<Type>> pointer_types_;

// TODO: use unordered map
std::map<std::pair<int, bool>, std::unique_ptr<Type>> custom_int_types_;
std::map<std::tuple<int, int, bool>, std::unique_ptr<Type>> custom_int_types;

// TODO: avoid duplication
std::vector<std::unique_ptr<Type>> bit_struct_types_;
Expand Down
4 changes: 4 additions & 0 deletions taichi/lang_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ int data_type_size(DataType t) {
}
}

int data_type_bits(DataType t) {
return data_type_size(t) * 8;
}

std::string data_type_short_name(DataType t) {
if (!t->is<PrimitiveType>()) {
return t->to_string();
Expand Down
1 change: 1 addition & 0 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ std::string make_list(const std::vector<T> &data,
}

int data_type_size(DataType t);
int data_type_bits(DataType t);
DataType promoted_type(DataType a, DataType b);

extern std::string compiled_lib_dir;
Expand Down
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ void export_lang(py::module &m) {
// TypeFactory on Python-scope pointer destruction.
py::class_<TypeFactory>(m, "TypeFactory")
.def("get_custom_int_type", &TypeFactory::get_custom_int_type,
py::arg("num_bits"), py::arg("is_signed"),
py::arg("compute_type_bits") = 32,
py::return_value_policy::reference);

m.def("get_type_factory_instance", TypeFactory::get_instance,
Expand Down
31 changes: 20 additions & 11 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ using float32 = float;
using float64 = double;

using i8 = int8;
using i16 = int16;
using i32 = int32;
using i64 = int64;
using u8 = uint8;
using u16 = uint16;
using u32 = uint32;
using u64 = uint64;
using f32 = float32;
Expand Down Expand Up @@ -1551,17 +1553,24 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {

#include "internal_functions.h"

void set_partial_bits_b32(u32 *ptr, u32 offset, u32 bits, u32 value) {
u32 mask = ((((u32)1 << bits) - 1) << offset);
u32 new_value = 0;
u32 old_value = *ptr;
do {
old_value = *ptr;
new_value = (old_value & (~mask)) | (value << offset);
} while (!__atomic_compare_exchange(ptr, &old_value, &new_value, true,
std::memory_order::memory_order_seq_cst,
std::memory_order::memory_order_seq_cst));
}
#define DEFINE_SET_PARTIAL_BITS(N) \
void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \
u##N mask = ((((u##N)1 << bits) - 1) << offset); \
u##N new_value = 0; \
u##N old_value = *ptr; \
do { \
old_value = *ptr; \
new_value = (old_value & (~mask)) | (value << offset); \
} while ( \
!__atomic_compare_exchange(ptr, &old_value, &new_value, true, \
std::memory_order::memory_order_seq_cst, \
std::memory_order::memory_order_seq_cst)); \
}

DEFINE_SET_PARTIAL_BITS(8);
DEFINE_SET_PARTIAL_BITS(16);
DEFINE_SET_PARTIAL_BITS(32);
DEFINE_SET_PARTIAL_BITS(64);
}

#endif
2 changes: 2 additions & 0 deletions taichi/struct/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
ch_types.push_back(ch->dt.get_ptr());
ch_offsets.push_back(total_offset);
total_offset += ch->dt->as<CustomIntType>()->get_num_bits();
ch->dt->as<CustomIntType>()->set_physical_type(snode.physical_type);
}

snode.dt = TypeFactory::get_instance().get_bit_struct_type(
Expand All @@ -80,6 +81,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
TI_ASSERT(snode.ch.size() == 1);
auto &ch = snode.ch[0];
Type *ch_type = ch->dt.get_ptr();
ch->dt->as<CustomIntType>()->set_physical_type(snode.physical_type);
snode.dt = TypeFactory::get_instance().get_bit_array_type(
snode.physical_type, ch_type, snode.n);

Expand Down
7 changes: 2 additions & 5 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,8 @@ class TypeCheck : public IRVisitor {

if (stmt->lhs->ret_type != stmt->rhs->ret_type) {
auto promote_custom_int_type = [&](Stmt *stmt, Stmt *hs) {
if (hs->ret_type->is<CustomIntType>()) {
if (hs->ret_type->cast<CustomIntType>()->get_is_signed())
return insert_type_cast_before(stmt, hs, get_data_type<int32>());
else
return insert_type_cast_before(stmt, hs, get_data_type<uint32>());
if (auto cit = hs->ret_type->cast<CustomIntType>()) {
return insert_type_cast_before(stmt, hs, cit->get_compute_type());
}
return hs;
};
Expand Down
Loading

0 comments on commit 2b9b0a2

Please sign in to comment.