From 8be56691bcb12408adcb6b1b15f8892b13ad7428 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 14 Jul 2022 10:58:52 +0800 Subject: [PATCH 1/2] [type] [refactor] Rewrite load_quant_float() without SNode --- taichi/codegen/llvm/codegen_llvm.cpp | 5 ++-- taichi/codegen/llvm/codegen_llvm.h | 11 ++++---- taichi/codegen/llvm/codegen_llvm_quant.cpp | 33 +++++++++++----------- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 6966c81f44d84..be19d2628e38f 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1435,9 +1435,10 @@ void CodeGenLLVM::create_global_load(GlobalLoadStmt *stmt, load_quant_fixed(ptr, qfxt, physical_type, should_cache_as_read_only); } else { TI_ASSERT(val_type->is()); + TI_ASSERT(get_ch->input_snode->dt->is()); llvm_val[stmt] = load_quant_float( - ptr, get_ch->output_snode, val_type->as(), - physical_type, should_cache_as_read_only); + ptr, get_ch->input_snode->dt->as(), + get_ch->output_snode->id_in_bit_struct, should_cache_as_read_only); } } else { // Byte pointer case. diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index 9cb97fd056f49..bda24ab5434b8 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -266,14 +266,13 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *reconstruct_quant_fixed(llvm::Value *digits, QuantFixedType *qfxt); - llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr, - SNode *digits_snode, - QuantFloatType *qflt, - Type *physical_type, + llvm::Value *load_quant_float(llvm::Value *digits_ptr, + BitStructType *bit_struct, + int digits_id, bool should_cache_as_read_only); - llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr, - llvm::Value *exponent_bit_ptr, + llvm::Value *load_quant_float(llvm::Value *digits_ptr, + llvm::Value *exponent_ptr, QuantFloatType *qflt, Type *physical_type, bool should_cache_as_read_only, diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp index 32a442a31eccc..cd8ac616e605a 100644 --- a/taichi/codegen/llvm/codegen_llvm_quant.cpp +++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp @@ -532,32 +532,33 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits, return builder->CreateFMul(cast, s); } -llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr, - SNode *digits_snode, - QuantFloatType *qflt, - Type *physical_type, +llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr, + BitStructType *bit_struct, + int digits_id, bool should_cache_as_read_only) { - auto exponent_snode = digits_snode->exp_snode; - // Compute the bit pointer of the exponent bits. - TI_ASSERT(digits_snode->parent == exponent_snode->parent); - auto exponent_bit_ptr = offset_bit_ptr( - digits_bit_ptr, exponent_snode->bit_offset - digits_snode->bit_offset); - return load_quant_float(digits_bit_ptr, exponent_bit_ptr, qflt, physical_type, - should_cache_as_read_only, - digits_snode->owns_shared_exponent); + auto exponent_id = bit_struct->get_member_exponent(digits_id); + auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id); + auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id); + auto bit_offset_delta = exponent_bit_offset - digits_bit_offset; + auto exponent_ptr = offset_bit_ptr(digits_ptr, bit_offset_delta); + auto qflt = bit_struct->get_member_type(digits_id)->as(); + auto physical_type = bit_struct->get_physical_type(); + auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id); + return load_quant_float(digits_ptr, exponent_ptr, qflt, physical_type, + should_cache_as_read_only, shared_exponent); } -llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr, - llvm::Value *exponent_bit_ptr, +llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr, + llvm::Value *exponent_ptr, QuantFloatType *qflt, Type *physical_type, bool should_cache_as_read_only, bool shared_exponent) { - auto digits = load_quant_int(digits_bit_ptr, + auto digits = load_quant_int(digits_ptr, qflt->get_digits_type()->as(), physical_type, should_cache_as_read_only); auto exponent_val = load_quant_int( - exponent_bit_ptr, qflt->get_exponent_type()->as(), + exponent_ptr, qflt->get_exponent_type()->as(), physical_type, should_cache_as_read_only); return reconstruct_quant_float(digits, exponent_val, qflt, shared_exponent); } From 98b9de3422d6605fc9783161d91c2b9538bc46d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Jul 2022 03:01:33 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm_quant.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp index cd8ac616e605a..0212066970fbf 100644 --- a/taichi/codegen/llvm/codegen_llvm_quant.cpp +++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp @@ -554,9 +554,9 @@ llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr, Type *physical_type, bool should_cache_as_read_only, bool shared_exponent) { - auto digits = load_quant_int(digits_ptr, - qflt->get_digits_type()->as(), - physical_type, should_cache_as_read_only); + auto digits = + load_quant_int(digits_ptr, qflt->get_digits_type()->as(), + physical_type, should_cache_as_read_only); auto exponent_val = load_quant_int( exponent_ptr, qflt->get_exponent_type()->as(), physical_type, should_cache_as_read_only);