From 9b4a451bf8cd0b0ca9b3fb5e2806202c2d51dfa2 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Mon, 11 Jul 2022 14:04:09 +0800 Subject: [PATCH 1/2] [Lang] [type] Support placing QuantFixedType under quant_array --- taichi/codegen/cuda/codegen_cuda.cpp | 4 +-- taichi/codegen/llvm/codegen_llvm.cpp | 29 +++++++++++----------- taichi/codegen/llvm/codegen_llvm.h | 11 +++++--- taichi/codegen/llvm/codegen_llvm_quant.cpp | 22 ++++++++-------- taichi/ir/type.h | 10 +++++--- tests/python/test_quant_array.py | 24 ++++++++++++++++++ 6 files changed, 66 insertions(+), 34 deletions(-) diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index b35017ef795c3..199c59ae79850 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -560,9 +560,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { if (auto get_ch = stmt->src->cast()) { bool should_cache_as_read_only = current_offload->mem_access_opt.has_flag( get_ch->output_snode, SNodeAccessFlag::read_only); - global_load(stmt, should_cache_as_read_only); + create_global_load(stmt, should_cache_as_read_only); } else { - global_load(stmt, false); + create_global_load(stmt, false); } } diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 8fe7e6518f554..1b1b41c1ca3f4 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1394,19 +1394,18 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto ptr_type = stmt->dest->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto pointee_type = ptr_type->get_pointee_type(); - if (!pointee_type->is()) { - if (stmt->dest->as()->input_snode->type == - SNodeType::bit_struct) { - TI_ERROR( - "Bit struct stores with type {} should have been " - "handled by BitStructStoreStmt.", - pointee_type->to_string()); - } else { - TI_ERROR("Quant array only supports quant int type."); - } + if (stmt->dest->as()->input_snode->type == SNodeType::bit_struct) { + TI_ERROR( + "Bit struct stores with type {} should have been handled by BitStructStoreStmt.", + pointee_type->to_string()); + } + if (auto qit = pointee_type->cast()) { + store_quant_int(llvm_val[stmt->dest], qit, llvm_val[stmt->val], true); + } else if (auto qfxt = pointee_type->cast()) { + store_quant_fixed(llvm_val[stmt->dest], qfxt, llvm_val[stmt->val], true); + } else { + TI_NOT_IMPLEMENTED; } - store_quant_int(llvm_val[stmt->dest], pointee_type->as(), - llvm_val[stmt->val], true); } else { builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]); } @@ -1417,8 +1416,8 @@ llvm::Value *CodeGenLLVM::create_intrinsic_load(const DataType &dtype, TI_NOT_IMPLEMENTED; } -void CodeGenLLVM::global_load(GlobalLoadStmt *stmt, - bool should_cache_as_read_only) { +void CodeGenLLVM::create_global_load(GlobalLoadStmt *stmt, + bool should_cache_as_read_only) { auto ptr = llvm_val[stmt->src]; auto ptr_type = stmt->src->ret_type->as(); if (ptr_type->is_bit_pointer()) { @@ -1449,7 +1448,7 @@ void CodeGenLLVM::global_load(GlobalLoadStmt *stmt, } void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { - global_load(stmt, false); + create_global_load(stmt, false); } void CodeGenLLVM::visit(ElementShuffleStmt *stmt){ diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index cb8241f60a51a..bb3a6efb34247 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -229,9 +229,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, QuantIntType *qit); - llvm::Value *quant_fixed_to_quant_int(QuantFixedType *qfxt, - QuantIntType *qit, - llvm::Value *real); + llvm::Value *to_quant_fixed(llvm::Value *real, QuantFixedType *qfxt); virtual llvm::Value *optimized_reduction(AtomicOpStmt *stmt); @@ -257,6 +255,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *value, bool atomic); + void store_quant_fixed(llvm::Value *bit_ptr, + QuantFixedType *qfxt, + llvm::Value *value, + bool atomic); + void store_masked(llvm::Value *byte_ptr, uint64 mask, llvm::Value *value, @@ -313,7 +316,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { QuantFloatType *qflt, bool shared_exponent); - void global_load(GlobalLoadStmt *stmt, bool should_cache_as_read_only); + void create_global_load(GlobalLoadStmt *stmt, bool should_cache_as_read_only); void visit(GlobalLoadStmt *stmt) override; diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp index 5b920a7c169dc..5ff0dc0f32650 100644 --- a/taichi/codegen/llvm/codegen_llvm_quant.cpp +++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp @@ -34,7 +34,7 @@ llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt, auto [byte_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->dest]); auto physical_type = byte_ptr->getType()->getPointerElementType(); auto qit = qfxt->get_digits_type()->as(); - auto val_store = quant_fixed_to_quant_int(qfxt, qit, llvm_val[stmt->val]); + auto val_store = to_quant_fixed(llvm_val[stmt->val], qfxt); val_store = builder->CreateSExt(val_store, physical_type); return create_call(fmt::format("atomic_add_partial_bits_b{}", physical_type->getIntegerBitWidth()), @@ -42,16 +42,10 @@ llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt, tlctx->get_constant(qit->get_num_bits()), val_store}); } -llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt, - QuantIntType *qit, - llvm::Value *real) { - llvm::Value *s = nullptr; - +llvm::Value *CodeGenLLVM::to_quant_fixed(llvm::Value *real, QuantFixedType *qfxt) { // Compute int(real * (1.0 / scale) + 0.5) - auto s_numeric = 1.0 / qfxt->get_scale(); auto compute_type = qfxt->get_compute_type(); - s = builder->CreateFPCast(tlctx->get_constant(s_numeric), - llvm_type(compute_type)); + auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()), llvm_type(compute_type)); auto input_real = builder->CreateFPCast(real, llvm_type(compute_type)); auto scaled = builder->CreateFMul(input_real, s); @@ -60,6 +54,7 @@ llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt, fmt::format("rounding_prepare_f{}", data_type_bits(compute_type)), {scaled}); + auto qit = qfxt->get_digits_type()->as(); if (qit->get_is_signed()) { return builder->CreateFPToSI(scaled, llvm_type(qit->get_compute_type())); } else { @@ -81,6 +76,13 @@ void CodeGenLLVM::store_quant_int(llvm::Value *bit_ptr, builder->CreateIntCast(value, physical_type, false)}); } +void CodeGenLLVM::store_quant_fixed(llvm::Value *bit_ptr, + QuantFixedType *qfxt, + llvm::Value *value, + bool atomic) { + store_quant_int(bit_ptr, qfxt->get_digits_type()->as(), to_quant_fixed(value, qfxt), atomic); +} + void CodeGenLLVM::store_masked(llvm::Value *byte_ptr, uint64 mask, llvm::Value *value, @@ -120,7 +122,7 @@ llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val, QuantIntType *qit = nullptr; if (auto qfxt = input_type->cast()) { qit = qfxt->get_digits_type()->as(); - val = quant_fixed_to_quant_int(qfxt, qit, val); + val = to_quant_fixed(val, qfxt); } else { qit = input_type->as(); } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 9e4eb5b44f644..9ce1b03591bdd 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -306,9 +306,13 @@ class QuantArrayType : public Type { : physical_type_(physical_type), element_type_(element_type_), num_elements_(num_elements_) { - // TODO: avoid assertion? - TI_ASSERT(element_type_->is()); - element_num_bits_ = element_type_->as()->get_num_bits(); + if (auto qit = element_type_->cast()) { + element_num_bits_ = qit->get_num_bits(); + } else if (auto qfxt = element_type_->cast()) { + element_num_bits_ = qfxt->get_digits_type()->as()->get_num_bits(); + } else { + TI_ERROR("Quant array only supports quant int/fixed type for now."); + } } std::string to_string() const override; diff --git a/tests/python/test_quant_array.py b/tests/python/test_quant_array.py index cb075827588d7..75486153e1d51 100644 --- a/tests/python/test_quant_array.py +++ b/tests/python/test_quant_array.py @@ -43,6 +43,30 @@ def assign(): assign() +@test_utils.test(require=ti.extension.quant, debug=True) +def test_1D_quant_array_fixed(): + qfxt = ti.types.quant.fixed(frac=8, range=2) + + x = ti.field(dtype=qfxt) + + N = 4 + + ti.root.quant_array(ti.i, N, num_bits=32).place(x) + + @ti.kernel + def set_val(): + for i in range(N): + x[i] = i * 0.5 + + @ti.kernel + def verify_val(): + for i in range(N): + assert x[i] == i * 0.5 + + set_val() + verify_val() + + @test_utils.test(require=ti.extension.quant, debug=True) def test_2D_quant_array(): qu1 = ti.types.quant.int(1, False) From 1e8b0f32b016d010e34fb4bbfc5c257266404b28 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Jul 2022 06:07:02 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm.cpp | 6 ++++-- taichi/codegen/llvm/codegen_llvm_quant.cpp | 9 ++++++--- taichi/ir/type.h | 3 ++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 1b1b41c1ca3f4..9f822e201f7c6 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1394,9 +1394,11 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto ptr_type = stmt->dest->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto pointee_type = ptr_type->get_pointee_type(); - if (stmt->dest->as()->input_snode->type == SNodeType::bit_struct) { + if (stmt->dest->as()->input_snode->type == + SNodeType::bit_struct) { TI_ERROR( - "Bit struct stores with type {} should have been handled by BitStructStoreStmt.", + "Bit struct stores with type {} should have been handled by " + "BitStructStoreStmt.", pointee_type->to_string()); } if (auto qit = pointee_type->cast()) { diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp index 5ff0dc0f32650..b916c4f243f94 100644 --- a/taichi/codegen/llvm/codegen_llvm_quant.cpp +++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp @@ -42,10 +42,12 @@ llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt, tlctx->get_constant(qit->get_num_bits()), val_store}); } -llvm::Value *CodeGenLLVM::to_quant_fixed(llvm::Value *real, QuantFixedType *qfxt) { +llvm::Value *CodeGenLLVM::to_quant_fixed(llvm::Value *real, + QuantFixedType *qfxt) { // Compute int(real * (1.0 / scale) + 0.5) auto compute_type = qfxt->get_compute_type(); - auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()), llvm_type(compute_type)); + auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()), + llvm_type(compute_type)); auto input_real = builder->CreateFPCast(real, llvm_type(compute_type)); auto scaled = builder->CreateFMul(input_real, s); @@ -80,7 +82,8 @@ void CodeGenLLVM::store_quant_fixed(llvm::Value *bit_ptr, QuantFixedType *qfxt, llvm::Value *value, bool atomic) { - store_quant_int(bit_ptr, qfxt->get_digits_type()->as(), to_quant_fixed(value, qfxt), atomic); + store_quant_int(bit_ptr, qfxt->get_digits_type()->as(), + to_quant_fixed(value, qfxt), atomic); } void CodeGenLLVM::store_masked(llvm::Value *byte_ptr, diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 9ce1b03591bdd..3863b0fee6e65 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -309,7 +309,8 @@ class QuantArrayType : public Type { if (auto qit = element_type_->cast()) { element_num_bits_ = qit->get_num_bits(); } else if (auto qfxt = element_type_->cast()) { - element_num_bits_ = qfxt->get_digits_type()->as()->get_num_bits(); + element_num_bits_ = + qfxt->get_digits_type()->as()->get_num_bits(); } else { TI_ERROR("Quant array only supports quant int/fixed type for now."); }