Skip to content

Commit

Permalink
[type] [refactor] Decouple quant from SNode 5/n: Rewrite load_quant_f…
Browse files Browse the repository at this point in the history
…loat() without SNode (#5422)

* [type] [refactor] Rewrite load_quant_float() without SNode

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Jul 15, 2022
1 parent e175791 commit d3184b4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 26 deletions.
5 changes: 3 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,9 +1437,10 @@ void CodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
load_quant_fixed(ptr, qfxt, physical_type, should_cache_as_read_only);
} else {
TI_ASSERT(val_type->is<QuantFloatType>());
TI_ASSERT(get_ch->input_snode->dt->is<BitStructType>());
llvm_val[stmt] = load_quant_float(
ptr, get_ch->output_snode, val_type->as<QuantFloatType>(),
physical_type, should_cache_as_read_only);
ptr, get_ch->input_snode->dt->as<BitStructType>(),
get_ch->output_snode->id_in_bit_struct, should_cache_as_read_only);
}
} else {
// Byte pointer case.
Expand Down
11 changes: 5 additions & 6 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,13 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *reconstruct_quant_fixed(llvm::Value *digits,
QuantFixedType *qfxt);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr,
SNode *digits_snode,
QuantFloatType *qflt,
Type *physical_type,
llvm::Value *load_quant_float(llvm::Value *digits_ptr,
BitStructType *bit_struct,
int digits_id,
bool should_cache_as_read_only);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
llvm::Value *load_quant_float(llvm::Value *digits_ptr,
llvm::Value *exponent_ptr,
QuantFloatType *qflt,
Type *physical_type,
bool should_cache_as_read_only,
Expand Down
37 changes: 19 additions & 18 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,32 +561,33 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
return builder->CreateFMul(cast, s);
}

llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr,
SNode *digits_snode,
QuantFloatType *qflt,
Type *physical_type,
llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr,
BitStructType *bit_struct,
int digits_id,
bool should_cache_as_read_only) {
auto exponent_snode = digits_snode->exp_snode;
// Compute the bit pointer of the exponent bits.
TI_ASSERT(digits_snode->parent == exponent_snode->parent);
auto exponent_bit_ptr = offset_bit_ptr(
digits_bit_ptr, exponent_snode->bit_offset - digits_snode->bit_offset);
return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, physical_type,
should_cache_as_read_only,
digits_snode->owns_shared_exponent);
auto exponent_id = bit_struct->get_member_exponent(digits_id);
auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id);
auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id);
auto bit_offset_delta = exponent_bit_offset - digits_bit_offset;
auto exponent_ptr = offset_bit_ptr(digits_ptr, bit_offset_delta);
auto qflt = bit_struct->get_member_type(digits_id)->as<QuantFloatType>();
auto physical_type = bit_struct->get_physical_type();
auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id);
return load_quant_float(digits_ptr, exponent_ptr, qflt, physical_type,
should_cache_as_read_only, shared_exponent);
}

llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr,
llvm::Value *exponent_ptr,
QuantFloatType *qflt,
Type *physical_type,
bool should_cache_as_read_only,
bool shared_exponent) {
auto digits = load_quant_int(digits_bit_ptr,
qflt->get_digits_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
auto digits =
load_quant_int(digits_ptr, qflt->get_digits_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
auto exponent_val = load_quant_int(
exponent_bit_ptr, qflt->get_exponent_type()->as<QuantIntType>(),
exponent_ptr, qflt->get_exponent_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
return reconstruct_quant_float(digits, exponent_val, qflt, shared_exponent);
}
Expand Down

0 comments on commit d3184b4

Please sign in to comment.