Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 22, 2022
1 parent bb1985b commit def861f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 24 deletions.
29 changes: 20 additions & 9 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1474,32 +1474,43 @@ void CodeGenLLVM::visit(LinearizeStmt *stmt) {

void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED}

llvm::Value *CodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr, llvm::Value *bit_offset) {
llvm::Value *CodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr,
llvm::Value *bit_offset) {
// 1. define the bit pointer struct (X=8/16/32/64)
// struct bit_pointer_X {
// iX* byte_ptr;
// i32 bit_offset;
// };
TI_ASSERT(bit_offset->getType()->isIntegerTy(32));
auto struct_type = llvm::StructType::get(*llvm_context, {byte_ptr->getType(), bit_offset->getType()});
auto struct_type = llvm::StructType::get(
*llvm_context, {byte_ptr->getType(), bit_offset->getType()});
// 2. allocate the bit pointer struct
auto bit_ptr = create_entry_block_alloca(struct_type);
// 3. store `byte_ptr`
builder->CreateStore(byte_ptr, builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
builder->CreateStore(
byte_ptr, builder->CreateGEP(
bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
// 4. store `bit_offset
builder->CreateStore(bit_offset,builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}));
builder->CreateStore(bit_offset,
builder->CreateGEP(bit_ptr, {tlctx->get_constant(0),
tlctx->get_constant(1)}));
return bit_ptr;
}

std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::load_bit_ptr(llvm::Value *bit_ptr) {
auto byte_ptr = builder->CreateLoad(builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
auto bit_offset = builder->CreateLoad(builder->CreateGEP(bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}));
std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::load_bit_ptr(
llvm::Value *bit_ptr) {
auto byte_ptr = builder->CreateLoad(builder->CreateGEP(
bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}));
auto bit_offset = builder->CreateLoad(builder->CreateGEP(
bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}));
return std::make_tuple(byte_ptr, bit_offset);
}

llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr, int bit_offset_delta) {
llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr,
int bit_offset_delta) {
auto [byte_ptr, bit_offset] = load_bit_ptr(bit_ptr);
auto new_bit_offset = builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta));
auto new_bit_offset =
builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta));
return create_bit_ptr(byte_ptr, new_bit_offset);
}

Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *extract_quant_float(llvm::Value *local_bit_struct,
SNode *digits_snode);

llvm::Value *load_quant_int(llvm::Value *ptr,
QuantIntType *qit);
llvm::Value *load_quant_int(llvm::Value *ptr, QuantIntType *qit);

llvm::Value *extract_quant_int(llvm::Value *physical_value,
llvm::Value *bit_offset,
Expand Down
35 changes: 22 additions & 13 deletions taichi/codegen/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ llvm::Value *CodeGenLLVM::atomic_add_quant_int(AtomicOpStmt *stmt,
QuantIntType *qit) {
auto [byte_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->dest]);
auto physical_type = byte_ptr->getType()->getPointerElementType();
return create_call(fmt::format("atomic_add_partial_bits_b{}", physical_type->getIntegerBitWidth()),
{byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), builder->CreateIntCast(llvm_val[stmt->val], physical_type, is_signed(stmt->val->ret_type))});
return create_call(
fmt::format("atomic_add_partial_bits_b{}",
physical_type->getIntegerBitWidth()),
{byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()),
builder->CreateIntCast(llvm_val[stmt->val], physical_type,
is_signed(stmt->val->ret_type))});
}

llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt,
Expand All @@ -32,8 +36,10 @@ llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt,
auto qit = qfxt->get_digits_type()->as<QuantIntType>();
auto val_store = quant_fixed_to_quant_int(qfxt, qit, llvm_val[stmt->val]);
val_store = builder->CreateSExt(val_store, physical_type);
return create_call(fmt::format("atomic_add_partial_bits_b{}", physical_type->getIntegerBitWidth()),
{byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()), val_store});
return create_call(fmt::format("atomic_add_partial_bits_b{}",
physical_type->getIntegerBitWidth()),
{byte_ptr, bit_offset,
tlctx->get_constant(qit->get_num_bits()), val_store});
}

llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt,
Expand Down Expand Up @@ -69,7 +75,8 @@ void CodeGenLLVM::store_quant_int(llvm::Value *bit_ptr,
auto physical_type = byte_ptr->getType()->getPointerElementType();
// TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers.
// Try to support 8/16-bit physical types.
create_call(fmt::format("{}set_partial_bits_b{}", atomic ? "atomic_" : "", physical_type->getIntegerBitWidth()),
create_call(fmt::format("{}set_partial_bits_b{}", atomic ? "atomic_" : "",
physical_type->getIntegerBitWidth()),
{byte_ptr, bit_offset, tlctx->get_constant(qit->get_num_bits()),
builder->CreateIntCast(value, physical_type, false)});
}
Expand All @@ -89,8 +96,10 @@ void CodeGenLLVM::store_masked(llvm::Value *byte_ptr,
builder->CreateStore(value, byte_ptr);
return;
}
create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "", physical_type->getIntegerBitWidth()),
{byte_ptr, tlctx->get_constant(mask), builder->CreateIntCast(value, physical_type, false)});
create_call(fmt::format("{}set_mask_b{}", atomic ? "atomic_" : "",
physical_type->getIntegerBitWidth()),
{byte_ptr, tlctx->get_constant(mask),
builder->CreateIntCast(value, physical_type, false)});
}

llvm::Value *CodeGenLLVM::get_exponent_offset(llvm::Value *exponent,
Expand Down Expand Up @@ -449,8 +458,7 @@ llvm::Value *CodeGenLLVM::extract_quant_float(llvm::Value *local_bit_struct,
digits_snode->owns_shared_exponent);
}

llvm::Value *CodeGenLLVM::load_quant_int(llvm::Value *ptr,
QuantIntType *qit) {
llvm::Value *CodeGenLLVM::load_quant_int(llvm::Value *ptr, QuantIntType *qit) {
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
auto physical_value = builder->CreateLoad(byte_ptr);
return extract_quant_int(physical_value, bit_offset, qit);
Expand All @@ -467,9 +475,9 @@ llvm::Value *CodeGenLLVM::extract_quant_int(llvm::Value *physical_value,
builder->CreateAdd(bit_offset, tlctx->get_constant(qit->get_num_bits()));
auto left = builder->CreateSub(
tlctx->get_constant(physical_type->getIntegerBitWidth()), bit_end);
auto right =
builder->CreateSub(tlctx->get_constant(physical_type->getIntegerBitWidth()),
tlctx->get_constant(qit->get_num_bits()));
auto right = builder->CreateSub(
tlctx->get_constant(physical_type->getIntegerBitWidth()),
tlctx->get_constant(qit->get_num_bits()));
left = builder->CreateIntCast(left, physical_type, false);
right = builder->CreateIntCast(right, physical_type, false);
auto step1 = builder->CreateShl(physical_value, left);
Expand Down Expand Up @@ -621,7 +629,8 @@ llvm::Value *CodeGenLLVM::load_quant_fixed_or_quant_float(Stmt *ptr_stmt) {
TI_ASSERT(digits_snode->parent == exponent_snode->parent);
auto exponent_bit_ptr = offset_bit_ptr(
digits_bit_ptr, exponent_snode->bit_offset - digits_snode->bit_offset);
return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, digits_snode->owns_shared_exponent);
return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt,
digits_snode->owns_shared_exponent);
} else {
auto qfxt = load_type->as<QuantFixedType>();
auto digits = load_quant_int(llvm_val[ptr],
Expand Down

0 comments on commit def861f

Please sign in to comment.