Skip to content

Commit

Permalink
[type] [refactor] Decouple quant from SNode 3/n: Extend bit pointers (#…
Browse files Browse the repository at this point in the history
…5232)

* [type] [refactor] Decouple quant from SNode 3/n: Extend bit pointers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Jun 23, 2022
1 parent 46b5632 commit 6357cf0
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 153 deletions.
4 changes: 2 additions & 2 deletions taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,10 +540,10 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
auto val_type = ptr_type->get_pointee_type();
if (auto qit = val_type->cast<QuantIntType>()) {
dtype = get_ch->input_snode->physical_type;
auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->src]);
auto [data_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->src]);
data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype));
auto data = create_intrinsic_load(dtype, data_ptr);
llvm_val[stmt] = extract_quant_int(data, bit_offset, qit, dtype);
llvm_val[stmt] = extract_quant_int(data, bit_offset, qit);
} else {
// TODO: support __ldg
TI_ASSERT(val_type->is<QuantFixedType>() ||
Expand Down
94 changes: 36 additions & 58 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1195,11 +1195,9 @@ llvm::Value *CodeGenLLVM::quant_type_atomic(AtomicOpStmt *stmt) {

auto dst_type = stmt->dest->ret_type->as<PointerType>()->get_pointee_type();
if (auto qit = dst_type->cast<QuantIntType>()) {
return atomic_add_quant_int(
stmt, qit, stmt->dest->as<GetChStmt>()->input_snode->physical_type);
return atomic_add_quant_int(stmt, qit);
} else if (auto qfxt = dst_type->cast<QuantFixedType>()) {
return atomic_add_quant_fixed(
stmt, qfxt, stmt->dest->as<GetChStmt>()->input_snode->physical_type);
return atomic_add_quant_fixed(stmt, qfxt);
} else {
return nullptr;
}
Expand Down Expand Up @@ -1354,7 +1352,6 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
}
}
store_quant_int(llvm_val[stmt->dest], pointee_type->as<QuantIntType>(),
stmt->dest->as<GetChStmt>()->input_snode->physical_type,
llvm_val[stmt->val], true);
} else {
builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
Expand All @@ -1368,9 +1365,7 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
if (auto qit = val_type->cast<QuantIntType>()) {
llvm_val[stmt] = load_quant_int(
llvm_val[stmt->src], qit,
stmt->src->as<GetChStmt>()->input_snode->physical_type);
llvm_val[stmt] = load_quant_int(llvm_val[stmt->src], qit);
} else {
TI_ASSERT(val_type->is<QuantFixedType>() ||
val_type->is<QuantFloatType>());
Expand Down Expand Up @@ -1479,61 +1474,44 @@ void CodeGenLLVM::visit(LinearizeStmt *stmt) {

void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED}

llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base,
llvm::Value *bit_offset) {
// 1. get the bit pointer LLVM struct
// struct bit_pointer {
// i8* byte_ptr;
// i32 offset;
llvm::Value *CodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr,
llvm::Value *bit_offset) {
// 1. define the bit pointer struct (X=8/16/32/64)
// struct bit_pointer_X {
// iX* byte_ptr;
// i32 bit_offset;
// };
TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
auto struct_type = llvm::StructType::get(
*llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context)});
*llvm_context, {byte_ptr->getType(), bit_offset->getType()});
// 2. allocate the bit pointer struct
auto bit_ptr_struct = create_entry_block_alloca(struct_type);
// 3. store `byte_ptr_base` into `bit_ptr_struct` (if provided)
if (byte_ptr_base) {
auto byte_ptr = builder->CreateBitCast(
byte_ptr_base, llvm::PointerType::getInt8PtrTy(*llvm_context));
builder->CreateStore(
byte_ptr, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0),
tlctx->get_constant(0)}));
}
// 4. store `offset` in `bit_ptr_struct` (if provided)
if (bit_offset) {
builder->CreateStore(
bit_offset,
builder->CreateGEP(bit_ptr_struct,
{tlctx->get_constant(0), tlctx->get_constant(1)}));
}
return bit_ptr_struct;
auto bit_ptr = create_entry_block_alloca(struct_type);
// 3. store `byte_ptr`
builder->CreateStore(
byte_ptr, builder->CreateGEP(
bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
// 4. store `bit_offset`
builder->CreateStore(bit_offset,
builder->CreateGEP(bit_ptr, {tlctx->get_constant(0),
tlctx->get_constant(1)}));
return bit_ptr;
}

// Reads back both fields of a bit pointer struct.
//
// `bit_ptr` is expected to point to a two-field struct such as the one built
// by `create_bit_ptr`: field 0 holds the byte pointer and field 1 holds the
// bit offset. Emits one GEP + load per field through `builder` and returns
// the loaded values as a (byte_ptr, bit_offset) tuple.
std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::load_bit_ptr(
llvm::Value *bit_ptr) {
// Field 0: the byte pointer.
auto byte_ptr = builder->CreateLoad(builder->CreateGEP(
bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
// Field 1: the bit offset.
auto bit_offset = builder->CreateLoad(builder->CreateGEP(
bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}));
return std::make_tuple(byte_ptr, bit_offset);
}

llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *input_bit_ptr,
llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr,
int bit_offset_delta) {
auto byte_ptr_base = builder->CreateLoad(builder->CreateGEP(
input_bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
auto input_offset = builder->CreateLoad(builder->CreateGEP(
input_bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}));
auto [byte_ptr, bit_offset] = load_bit_ptr(bit_ptr);
auto new_bit_offset =
builder->CreateAdd(input_offset, tlctx->get_constant(bit_offset_delta));
return create_bit_ptr_struct(byte_ptr_base, new_bit_offset);
}

std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::load_bit_pointer(
llvm::Value *ptr) {
// 1. load byte pointer
auto byte_ptr_in_bit_struct =
builder->CreateGEP(ptr, {tlctx->get_constant(0), tlctx->get_constant(0)});
auto byte_ptr = builder->CreateLoad(byte_ptr_in_bit_struct);
TI_ASSERT(byte_ptr->getType()->getPointerElementType()->isIntegerTy(8));

// 2. load bit offset
auto bit_offset_in_bit_struct =
builder->CreateGEP(ptr, {tlctx->get_constant(0), tlctx->get_constant(1)});
auto bit_offset = builder->CreateLoad(bit_offset_in_bit_struct);
TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
return std::make_tuple(byte_ptr, bit_offset);
builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta));
return create_bit_ptr(byte_ptr, new_bit_offset);
}

void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
Expand All @@ -1560,7 +1538,7 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
snode->dt->as<BitArrayType>()->get_element_num_bits();
auto offset = tlctx->get_constant(element_num_bits);
offset = builder->CreateMul(offset, llvm_val[stmt->input_index]);
llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_snode], offset);
llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_snode], offset);
} else {
TI_INFO(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED
Expand All @@ -1575,7 +1553,7 @@ void CodeGenLLVM::visit(GetChStmt *stmt) {
auto bit_offset = bit_struct->get_member_bit_offset(
stmt->input_snode->child_id(stmt->output_snode));
auto offset = tlctx->get_constant(bit_offset);
llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset);
llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset);
} else {
auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(),
{builder->CreateBitCast(
Expand Down
32 changes: 7 additions & 25 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(SNodeOpStmt *stmt) override;

llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt,
QuantFixedType *qfxt,
Type *physical_type);
llvm::Value *atomic_add_quant_fixed(AtomicOpStmt *stmt, QuantFixedType *qfxt);

llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt,
QuantIntType *qit,
Type *physical_type);
llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, QuantIntType *qit);

llvm::Value *quant_fixed_to_quant_int(QuantFixedType *qfxt,
QuantIntType *qit,
Expand All @@ -252,20 +248,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void store_quant_int(llvm::Value *bit_ptr,
QuantIntType *qit,
Type *physical_type,
llvm::Value *value,
bool atomic);

void store_quant_int(llvm::Value *byte_ptr,
llvm::Value *bit_offset,
QuantIntType *qit,
Type *physical_type,
llvm::Value *value,
bool atomic);

void store_masked(llvm::Value *byte_ptr,
uint64 mask,
Type *physical_type,
llvm::Value *value,
bool atomic);

Expand All @@ -282,22 +269,18 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *extract_quant_float(llvm::Value *local_bit_struct,
SNode *digits_snode);

llvm::Value *load_quant_int(llvm::Value *ptr,
QuantIntType *qit,
Type *physical_type);
llvm::Value *load_quant_int(llvm::Value *ptr, QuantIntType *qit);

llvm::Value *extract_quant_int(llvm::Value *physical_value,
llvm::Value *bit_offset,
QuantIntType *qit,
Type *physical_type);
QuantIntType *qit);

llvm::Value *reconstruct_quant_fixed(llvm::Value *digits,
QuantFixedType *qfxt);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
QuantFloatType *qflt,
Type *physical_type,
bool shared_exponent);

llvm::Value *reconstruct_quant_float(llvm::Value *input_digits,
Expand All @@ -319,12 +302,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(IntegerOffsetStmt *stmt) override;

llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base = nullptr,
llvm::Value *bit_offset = nullptr);
llvm::Value *create_bit_ptr(llvm::Value *byte_ptr, llvm::Value *bit_offset);

llvm::Value *offset_bit_ptr(llvm::Value *input_bit_ptr, int bit_offset_delta);
std::tuple<llvm::Value *, llvm::Value *> load_bit_ptr(llvm::Value *bit_ptr);

std::tuple<llvm::Value *, llvm::Value *> load_bit_pointer(llvm::Value *ptr);
llvm::Value *offset_bit_ptr(llvm::Value *bit_ptr, int bit_offset_delta);

void visit(SNodeLookupStmt *stmt) override;

Expand Down
Loading

0 comments on commit 6357cf0

Please sign in to comment.