From 2dda83e5ca8d512db9dc882d310de59fafe8f3aa Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Tue, 26 Jul 2022 15:21:34 +0800 Subject: [PATCH] [type] [refactor] Decouple quant from SNode 9/n: Remove exponent handling from SNode (#5510) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/analysis/offline_cache_util.cpp | 16 ---- taichi/codegen/llvm/codegen_llvm.cpp | 4 +- taichi/codegen/metal/codegen_metal.cpp | 2 +- taichi/ir/snode.cpp | 14 +--- taichi/ir/snode.h | 9 +-- taichi/ir/type_utils.h | 75 +++++++++++++------ taichi/program/snode_expr_utils.cpp | 40 +--------- .../transforms/optimize_bit_struct_stores.cpp | 7 +- tests/python/test_bit_struct.py | 32 ++++++++ 9 files changed, 97 insertions(+), 102 deletions(-) diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 66b344b184c30..a6c6d26ad02ff 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -119,22 +119,6 @@ static void get_offline_cache_key_of_snode_impl( get_offline_cache_key_of_snode_impl(dual_snode, serializer, visited); } } - if (snode->exp_snode) { - get_offline_cache_key_of_snode_impl(snode->exp_snode, serializer, visited); - } - serializer(snode->bit_offset); - serializer(snode->placing_shared_exp); - serializer(snode->owns_shared_exponent); - for (auto s : snode->exponent_users) { - get_offline_cache_key_of_snode_impl(s, serializer, visited); - } - if (snode->currently_placing_exp_snode) { - get_offline_cache_key_of_snode_impl(snode->currently_placing_exp_snode, - serializer, visited); - } - if (snode->currently_placing_exp_snode_dtype) { - serializer(snode->currently_placing_exp_snode_dtype->to_string()); - } serializer(snode->is_bit_level); serializer(snode->is_path_all_dense); serializer(snode->node_type_name); diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 429bfae82d38d..ee651d61deacf 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1691,8 +1691,8 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) { llvm_val[stmt] = llvm_val[stmt->input_ptr]; } else if (stmt->ret_type->as()->is_bit_pointer()) { auto bit_struct = stmt->input_snode->dt->cast(); - auto bit_offset = bit_struct->get_member_bit_offset( - stmt->input_snode->child_id(stmt->output_snode)); + auto bit_offset = + bit_struct->get_member_bit_offset(stmt->output_snode->id_in_bit_struct); auto offset = tlctx->get_constant(bit_offset); llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset); } else { diff --git a/taichi/codegen/metal/codegen_metal.cpp b/taichi/codegen/metal/codegen_metal.cpp index 22917e5e0df52..eaf321e7ff52c 100644 --- a/taichi/codegen/metal/codegen_metal.cpp +++ b/taichi/codegen/metal/codegen_metal.cpp @@ -355,7 +355,7 @@ class KernelCodegenImpl : public IRVisitor { TI_ASSERT(stmt->ret_type->as()->is_bit_pointer()); const auto *bit_struct_ty = in_snode->dt->cast(); const auto bit_offset = - bit_struct_ty->get_member_bit_offset(in_snode->child_id(out_snode)); + bit_struct_ty->get_member_bit_offset(out_snode->id_in_bit_struct); // stmt->input_ptr is the "base" member in the generated SNode struct. emit("SNodeBitPointer {}({}, /*offset=*/{});", stmt->raw_name(), stmt->input_ptr->raw_name(), bit_offset); diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index dd6065ed240ac..cf768cc42179e 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -265,9 +265,6 @@ void SNode::print() { fmt::print(" "); } fmt::print("{}", get_node_type_name_hinted()); - if (exp_snode) { - fmt::print(" exp={}", exp_snode->get_node_type_name()); - } fmt::print("\n"); for (auto &c : ch) { c->print(); @@ -289,16 +286,13 @@ bool SNode::need_activation() const { } void SNode::begin_shared_exp_placement() { - TI_ASSERT(!placing_shared_exp); - TI_ASSERT(currently_placing_exp_snode == nullptr); - placing_shared_exp = true; + TI_ASSERT(bit_struct_type_builder); + bit_struct_type_builder->begin_placing_shared_exponent(); } void SNode::end_shared_exp_placement() { - TI_ASSERT(placing_shared_exp); - TI_ASSERT(currently_placing_exp_snode != nullptr); - currently_placing_exp_snode = nullptr; - placing_shared_exp = false; + TI_ASSERT(bit_struct_type_builder); + bit_struct_type_builder->end_placing_shared_exponent(); } bool SNode::is_primal() const { diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index 5f5ad6c3831a1..2cf060869375d 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -137,14 +137,7 @@ class SNode { std::unique_ptr grad_info{nullptr}; std::unique_ptr bit_struct_type_builder{nullptr}; - SNode *exp_snode{nullptr}; // for QuantFloatType - int bit_offset{0}; // for children of bit_struct only - int id_in_bit_struct{0}; // for children of bit_struct only - bool placing_shared_exp{false}; - SNode *currently_placing_exp_snode{nullptr}; - Type *currently_placing_exp_snode_dtype{nullptr}; - bool owns_shared_exponent{false}; - std::vector exponent_users; + int id_in_bit_struct{0}; // for children of bit_struct only // is_bit_level=false: the SNode is not bitpacked // is_bit_level=true: the SNode is bitpacked (i.e., strictly inside bit_struct diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index 32898f16c9382..c389cb3ad8810 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -185,7 +185,55 @@ class BitStructTypeBuilder { : physical_type_(physical_type) { } - std::tuple add_member(Type *member_type) { + int add_member(Type *member_type) { + if (auto qflt = member_type->cast()) { + auto exponent_type = qflt->get_exponent_type(); + auto exponent_id = -1; + if (is_placing_shared_exponent_ && current_shared_exponent_ != -1) { + // Reuse existing exponent + TI_ASSERT_INFO(member_types_[current_shared_exponent_] == exponent_type, + "QuantFloatTypes with shared exponents must have " + "exactly the same exponent type."); + exponent_id = current_shared_exponent_; + } else { + exponent_id = add_member_impl(exponent_type); + if (is_placing_shared_exponent_) { + current_shared_exponent_ = exponent_id; + } + } + auto digits_id = add_member_impl(member_type); + if (is_placing_shared_exponent_) { + member_owns_shared_exponents_[digits_id] = true; + } + member_exponents_[digits_id] = exponent_id; + member_exponent_users_[exponent_id].push_back(digits_id); + return digits_id; + } + return add_member_impl(member_type); + } + + void begin_placing_shared_exponent() { + TI_ASSERT(!is_placing_shared_exponent_); + TI_ASSERT(current_shared_exponent_ == -1); + is_placing_shared_exponent_ = true; + } + + void end_placing_shared_exponent() { + TI_ASSERT(is_placing_shared_exponent_); + TI_ASSERT(current_shared_exponent_ != -1); + current_shared_exponent_ = -1; + is_placing_shared_exponent_ = false; + } + + Type *build() const { + return TypeFactory::get_instance().get_bit_struct_type( + physical_type_, member_types_, member_bit_offsets_, + member_owns_shared_exponents_, member_exponents_, + member_exponent_users_); + } + + private: + int add_member_impl(Type *member_type) { int old_num_members = member_types_.size(); member_types_.push_back(member_type); member_bit_offsets_.push_back(member_total_bits_); @@ -202,35 +250,14 @@ class BitStructTypeBuilder { } else { TI_ERROR("Only a QuantType can be a member of a BitStructType."); } - auto old_member_total_bits = member_total_bits_; member_total_bits_ += member_qit->get_num_bits(); auto physical_bits = data_type_bits(physical_type_); TI_ERROR_IF(member_total_bits_ > physical_bits, "BitStructType overflows: {} bits used out of {}.", member_total_bits_, physical_bits); - return std::make_tuple(old_num_members, old_member_total_bits); + return old_num_members; } - void set_member_owns_shared_exponent(int id) { - member_owns_shared_exponents_[id] = true; - } - - void set_member_exponent(int id, int exponent_id) { - member_exponents_[id] = exponent_id; - } - - void add_member_exponent_user(int id, int user_id) { - member_exponent_users_[id].push_back(user_id); - } - - Type *build() const { - return TypeFactory::get_instance().get_bit_struct_type( - physical_type_, member_types_, member_bit_offsets_, - member_owns_shared_exponents_, member_exponents_, - member_exponent_users_); - } - - private: PrimitiveType *physical_type_{nullptr}; std::vector member_types_; std::vector member_bit_offsets_; @@ -238,6 +265,8 @@ class BitStructTypeBuilder { std::vector member_owns_shared_exponents_; std::vector member_exponents_; std::vector> member_exponent_users_; + bool is_placing_shared_exponent_{false}; + int current_shared_exponent_{-1}; }; } // namespace lang diff --git a/taichi/program/snode_expr_utils.cpp b/taichi/program/snode_expr_utils.cpp index f2c1b60f0c91b..87710a9346d79 100644 --- a/taichi/program/snode_expr_utils.cpp +++ b/taichi/program/snode_expr_utils.cpp @@ -51,31 +51,6 @@ void place_child(Expr *expr_arg, auto glb_var_expr = expr_arg->cast(); TI_ERROR_IF(glb_var_expr->snode != nullptr, "This variable has been placed."); - SNode *new_exp_snode = nullptr; - if (auto qflt = glb_var_expr->dt->cast()) { - auto exp = qflt->get_exponent_type(); - // Non-empty exponent type. First create a place SNode for the - // exponent value. - if (parent->placing_shared_exp && - parent->currently_placing_exp_snode != nullptr) { - // Reuse existing exponent - TI_ASSERT_INFO(parent->currently_placing_exp_snode_dtype == exp, - "QuantFloatTypes with shared exponents must have " - "exactly the same exponent type."); - new_exp_snode = parent->currently_placing_exp_snode; - } else { - auto &exp_node = parent->insert_children(SNodeType::place); - exp_node.dt = exp; - std::tie(exp_node.id_in_bit_struct, exp_node.bit_offset) = - parent->bit_struct_type_builder->add_member(exp); - exp_node.name = glb_var_expr->ident.raw_name() + "_exp"; - new_exp_snode = &exp_node; - if (parent->placing_shared_exp) { - parent->currently_placing_exp_snode = new_exp_snode; - parent->currently_placing_exp_snode_dtype = exp; - } - } - } auto &child = parent->insert_children(SNodeType::place); glb_var_expr->set_snode(&child); if (glb_var_expr->name == "") { @@ -92,21 +67,8 @@ void place_child(Expr *expr_arg, (*snode_to_exprs)[glb_var_expr->snode] = glb_var_expr; child.dt = glb_var_expr->dt; if (parent->bit_struct_type_builder) { - std::tie(child.id_in_bit_struct, child.bit_offset) = + child.id_in_bit_struct = parent->bit_struct_type_builder->add_member(child.dt); - if (parent->placing_shared_exp) { - child.owns_shared_exponent = true; - parent->bit_struct_type_builder->set_member_owns_shared_exponent( - child.id_in_bit_struct); - } - if (new_exp_snode) { - child.exp_snode = new_exp_snode; - parent->bit_struct_type_builder->set_member_exponent( - child.id_in_bit_struct, new_exp_snode->id_in_bit_struct); - new_exp_snode->exponent_users.push_back(&child); - parent->bit_struct_type_builder->add_member_exponent_user( - new_exp_snode->id_in_bit_struct, child.id_in_bit_struct); - } } if (!offset.empty()) child.set_index_offsets(offset); diff --git a/taichi/transforms/optimize_bit_struct_stores.cpp b/taichi/transforms/optimize_bit_struct_stores.cpp index 1993298717129..e6302e05ae5f6 100644 --- a/taichi/transforms/optimize_bit_struct_stores.cpp +++ b/taichi/transforms/optimize_bit_struct_stores.cpp @@ -32,9 +32,10 @@ class CreateBitStructStores : public BasicStmtVisitor { // We only handle bit_struct pointers here. - auto s = Stmt::make(get_ch->input_ptr, - std::vector{get_ch->chid}, - std::vector{stmt->val}); + auto s = Stmt::make( + get_ch->input_ptr, + std::vector{get_ch->output_snode->id_in_bit_struct}, + std::vector{stmt->val}); stmt->replace_with(VecStatement(std::move(s))); } }; diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 04403cd150106..58f17f784ac50 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -169,3 +169,35 @@ def assign(): assert x[i] == approx(i, abs=1e-3) else: assert x[i] == 0 + + +@test_utils.test(require=ti.extension.quant_basic, debug=True) +def test_multiple_types(): + f15 = ti.types.quant.float(exp=5, frac=10) + f18 = ti.types.quant.float(exp=5, frac=13) + u4 = ti.types.quant.int(bits=4, signed=False) + + p = ti.field(dtype=f15) + q = ti.field(dtype=f18) + r = ti.field(dtype=u4) + + cell = ti.root.dense(ti.i, 12).bit_struct(num_bits=32) + cell.place(p, q, shared_exponent=True) + cell.place(r) + + @ti.kernel + def set_val(): + for i in p: + p[i] = i * 3 + q[i] = i * 2 + r[i] = i + + @ti.kernel + def verify_val(): + for i in p: + assert p[i] == i * 3 + assert q[i] == i * 2 + assert r[i] == i + + set_val() + verify_val()