-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[type] [refactor] Add compute_type for CustomIntType #2047
Changes from 8 commits
41b9965
0b8fd8b
b1b5e56
100ea35
3399981
e98f450
9ff8476
04cb2f3
15ffdfa
a44bd21
6d82d39
fab966e
7bfbcb5
1f9e511
9b109f5
10cd722
45dabf8
9d99341
f045ec1
e344ca0
a20a0e8
18a6738
b8c4d49
a7acc57
704bbe7
5022a10
c1bf0e6
cc8878b
a2d4196
c9993cf
29f68ad
985f35e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -328,7 +328,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto from_size = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (from->is<CustomIntType>()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: replace 32 with a customizable type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from_size = 32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from_size = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type_size(from->cast<CustomIntType>()->get_compute_type()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from_size = data_type_size(from); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
yuanming-hu marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -618,6 +619,31 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
return nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm::Type *CodeGenLLVM::llvm_ptr_type(DataType dt) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dt->is_primitive(PrimitiveTypeID::i8) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dt->is_primitive(PrimitiveTypeID::u8)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return llvm::Type::getInt8PtrTy(*llvm_context); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if (dt->is_primitive(PrimitiveTypeID::i16) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dt->is_primitive(PrimitiveTypeID::u16)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return llvm::Type::getInt16PtrTy(*llvm_context); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if (dt->is_primitive(PrimitiveTypeID::i32) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dt->is_primitive(PrimitiveTypeID::u32)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return llvm::Type::getInt32PtrTy(*llvm_context); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if (dt->is_primitive(PrimitiveTypeID::i64) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dt->is_primitive(PrimitiveTypeID::u64)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return llvm::Type::getInt64PtrTy(*llvm_context); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if (dt->is_primitive(PrimitiveTypeID::u1)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return llvm::Type::getInt1PtrTy(*llvm_context); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if (dt->is_primitive(PrimitiveTypeID::f32)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return llvm::Type::getFloatPtrTy(*llvm_context); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if (dt->is_primitive(PrimitiveTypeID::f64)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return llvm::Type::getDoublePtrTy(*llvm_context); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TI_NOT_IMPLEMENTED; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void CodeGenLLVM::visit(TernaryOpStmt *stmt) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TI_ASSERT(stmt->op_type == TernaryOpType::select); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm_val[stmt] = builder->CreateSelect( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -907,8 +933,8 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
TI_NOT_IMPLEMENTED | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto intermediate_bits = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (stmt->value->ret_type->is<CustomIntType>()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
intermediate_bits = 32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (auto cit = stmt->value->ret_type->cast<CustomIntType>()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
intermediate_bits = data_type_bits(cit->get_compute_type()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
intermediate_bits = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1105,12 +1131,16 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto cit = ptr_type->get_pointee_type()->as<CustomIntType>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto runtime_func_name = fmt::format( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"set_partial_bits_b{}", data_type_bits(cit->get_compute_type())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
builder->CreateCall( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
get_runtime_function("set_partial_bits_b32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
get_runtime_function(runtime_func_name), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{builder->CreateBitCast(byte_ptr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm::Type::getInt32PtrTy(*llvm_context)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm_ptr_type(cit->get_compute_type())), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
bit_offset, tlctx->get_constant(cit->get_num_bits()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm_val[stmt->data]}); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
builder->CreateIntCast(llvm_val[stmt->data], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm_type(cit->get_compute_type()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cit->get_is_signed())}); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1125,15 +1155,21 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm::Value *byte_ptr, *bit_offset; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context))); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
byte_ptr, llvm_ptr_type(cit->get_compute_type()))); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// 2. bit shifting | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// first left shift `32 - (offset + num_bits)` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// then right shift `32 - num_bits` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// first left shift `compute_type_size(like 32, 64, ...) - (offset + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// num_bits)` then right shift `compute_type_size - num_bits` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto compute_type_size = data_type_bits(cit->get_compute_type()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto bit_end = builder->CreateAdd(bit_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tlctx->get_constant(cit->get_num_bits())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto left = builder->CreateSub(tlctx->get_constant(32), bit_end); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto right = builder->CreateAdd(tlctx->get_constant(32), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tlctx->get_constant(-cit->get_num_bits())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto left = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
builder->CreateSub(tlctx->get_constant(compute_type_size), bit_end); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto right = builder->CreateSub(tlctx->get_constant(compute_type_size), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tlctx->get_constant(cit->get_num_bits())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
left = builder->CreateIntCast(left, llvm_type(cit->get_compute_type()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cit->get_is_signed()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
right = builder->CreateIntCast(right, llvm_type(cit->get_compute_type()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After testing the case of 16-bits and 64-bits, I found that we should use the num of |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cit->get_is_signed()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto step1 = builder->CreateShl(bit_level_container, left); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llvm::Value *step2 = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (cit->get_is_signed()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove TODO here