Skip to content

Commit

Permalink
[type] [refactor] Decouple quant from SNode 9/n: Remove exponent hand…
Browse files Browse the repository at this point in the history
…ling from SNode (#5510)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Jul 26, 2022
1 parent 8835e2f commit 2dda83e
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 102 deletions.
16 changes: 0 additions & 16 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,6 @@ static void get_offline_cache_key_of_snode_impl(
get_offline_cache_key_of_snode_impl(dual_snode, serializer, visited);
}
}
if (snode->exp_snode) {
get_offline_cache_key_of_snode_impl(snode->exp_snode, serializer, visited);
}
serializer(snode->bit_offset);
serializer(snode->placing_shared_exp);
serializer(snode->owns_shared_exponent);
for (auto s : snode->exponent_users) {
get_offline_cache_key_of_snode_impl(s, serializer, visited);
}
if (snode->currently_placing_exp_snode) {
get_offline_cache_key_of_snode_impl(snode->currently_placing_exp_snode,
serializer, visited);
}
if (snode->currently_placing_exp_snode_dtype) {
serializer(snode->currently_placing_exp_snode_dtype->to_string());
}
serializer(snode->is_bit_level);
serializer(snode->is_path_all_dense);
serializer(snode->node_type_name);
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1691,8 +1691,8 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) {
llvm_val[stmt] = llvm_val[stmt->input_ptr];
} else if (stmt->ret_type->as<PointerType>()->is_bit_pointer()) {
auto bit_struct = stmt->input_snode->dt->cast<BitStructType>();
auto bit_offset = bit_struct->get_member_bit_offset(
stmt->input_snode->child_id(stmt->output_snode));
auto bit_offset =
bit_struct->get_member_bit_offset(stmt->output_snode->id_in_bit_struct);
auto offset = tlctx->get_constant(bit_offset);
llvm_val[stmt] = create_bit_ptr(llvm_val[stmt->input_ptr], offset);
} else {
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ class KernelCodegenImpl : public IRVisitor {
TI_ASSERT(stmt->ret_type->as<PointerType>()->is_bit_pointer());
const auto *bit_struct_ty = in_snode->dt->cast<BitStructType>();
const auto bit_offset =
bit_struct_ty->get_member_bit_offset(in_snode->child_id(out_snode));
bit_struct_ty->get_member_bit_offset(out_snode->id_in_bit_struct);
// stmt->input_ptr is the "base" member in the generated SNode struct.
emit("SNodeBitPointer {}({}, /*offset=*/{});", stmt->raw_name(),
stmt->input_ptr->raw_name(), bit_offset);
Expand Down
14 changes: 4 additions & 10 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,6 @@ void SNode::print() {
fmt::print(" ");
}
fmt::print("{}", get_node_type_name_hinted());
if (exp_snode) {
fmt::print(" exp={}", exp_snode->get_node_type_name());
}
fmt::print("\n");
for (auto &c : ch) {
c->print();
Expand All @@ -289,16 +286,13 @@ bool SNode::need_activation() const {
}

void SNode::begin_shared_exp_placement() {
TI_ASSERT(!placing_shared_exp);
TI_ASSERT(currently_placing_exp_snode == nullptr);
placing_shared_exp = true;
TI_ASSERT(bit_struct_type_builder);
bit_struct_type_builder->begin_placing_shared_exponent();
}

void SNode::end_shared_exp_placement() {
TI_ASSERT(placing_shared_exp);
TI_ASSERT(currently_placing_exp_snode != nullptr);
currently_placing_exp_snode = nullptr;
placing_shared_exp = false;
TI_ASSERT(bit_struct_type_builder);
bit_struct_type_builder->end_placing_shared_exponent();
}

bool SNode::is_primal() const {
Expand Down
9 changes: 1 addition & 8 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,7 @@ class SNode {
std::unique_ptr<GradInfoProvider> grad_info{nullptr};

std::unique_ptr<BitStructTypeBuilder> bit_struct_type_builder{nullptr};
SNode *exp_snode{nullptr}; // for QuantFloatType
int bit_offset{0}; // for children of bit_struct only
int id_in_bit_struct{0}; // for children of bit_struct only
bool placing_shared_exp{false};
SNode *currently_placing_exp_snode{nullptr};
Type *currently_placing_exp_snode_dtype{nullptr};
bool owns_shared_exponent{false};
std::vector<SNode *> exponent_users;
int id_in_bit_struct{0}; // for children of bit_struct only

// is_bit_level=false: the SNode is not bitpacked
// is_bit_level=true: the SNode is bitpacked (i.e., strictly inside bit_struct
Expand Down
75 changes: 52 additions & 23 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,55 @@ class BitStructTypeBuilder {
: physical_type_(physical_type) {
}

std::tuple<int, int> add_member(Type *member_type) {
int add_member(Type *member_type) {
if (auto qflt = member_type->cast<QuantFloatType>()) {
auto exponent_type = qflt->get_exponent_type();
auto exponent_id = -1;
if (is_placing_shared_exponent_ && current_shared_exponent_ != -1) {
// Reuse existing exponent
TI_ASSERT_INFO(member_types_[current_shared_exponent_] == exponent_type,
"QuantFloatTypes with shared exponents must have "
"exactly the same exponent type.");
exponent_id = current_shared_exponent_;
} else {
exponent_id = add_member_impl(exponent_type);
if (is_placing_shared_exponent_) {
current_shared_exponent_ = exponent_id;
}
}
auto digits_id = add_member_impl(member_type);
if (is_placing_shared_exponent_) {
member_owns_shared_exponents_[digits_id] = true;
}
member_exponents_[digits_id] = exponent_id;
member_exponent_users_[exponent_id].push_back(digits_id);
return digits_id;
}
return add_member_impl(member_type);
}

void begin_placing_shared_exponent() {
TI_ASSERT(!is_placing_shared_exponent_);
TI_ASSERT(current_shared_exponent_ == -1);
is_placing_shared_exponent_ = true;
}

void end_placing_shared_exponent() {
TI_ASSERT(is_placing_shared_exponent_);
TI_ASSERT(current_shared_exponent_ != -1);
current_shared_exponent_ = -1;
is_placing_shared_exponent_ = false;
}

Type *build() const {
return TypeFactory::get_instance().get_bit_struct_type(
physical_type_, member_types_, member_bit_offsets_,
member_owns_shared_exponents_, member_exponents_,
member_exponent_users_);
}

private:
int add_member_impl(Type *member_type) {
int old_num_members = member_types_.size();
member_types_.push_back(member_type);
member_bit_offsets_.push_back(member_total_bits_);
Expand All @@ -202,42 +250,23 @@ class BitStructTypeBuilder {
} else {
TI_ERROR("Only a QuantType can be a member of a BitStructType.");
}
auto old_member_total_bits = member_total_bits_;
member_total_bits_ += member_qit->get_num_bits();
auto physical_bits = data_type_bits(physical_type_);
TI_ERROR_IF(member_total_bits_ > physical_bits,
"BitStructType overflows: {} bits used out of {}.",
member_total_bits_, physical_bits);
return std::make_tuple(old_num_members, old_member_total_bits);
return old_num_members;
}

void set_member_owns_shared_exponent(int id) {
member_owns_shared_exponents_[id] = true;
}

void set_member_exponent(int id, int exponent_id) {
member_exponents_[id] = exponent_id;
}

void add_member_exponent_user(int id, int user_id) {
member_exponent_users_[id].push_back(user_id);
}

Type *build() const {
return TypeFactory::get_instance().get_bit_struct_type(
physical_type_, member_types_, member_bit_offsets_,
member_owns_shared_exponents_, member_exponents_,
member_exponent_users_);
}

private:
PrimitiveType *physical_type_{nullptr};
std::vector<Type *> member_types_;
std::vector<int> member_bit_offsets_;
int member_total_bits_{0};
std::vector<bool> member_owns_shared_exponents_;
std::vector<int> member_exponents_;
std::vector<std::vector<int>> member_exponent_users_;
bool is_placing_shared_exponent_{false};
int current_shared_exponent_{-1};
};

} // namespace lang
Expand Down
40 changes: 1 addition & 39 deletions taichi/program/snode_expr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,6 @@ void place_child(Expr *expr_arg,
auto glb_var_expr = expr_arg->cast<GlobalVariableExpression>();
TI_ERROR_IF(glb_var_expr->snode != nullptr,
"This variable has been placed.");
SNode *new_exp_snode = nullptr;
if (auto qflt = glb_var_expr->dt->cast<QuantFloatType>()) {
auto exp = qflt->get_exponent_type();
// Non-empty exponent type. First create a place SNode for the
// exponent value.
if (parent->placing_shared_exp &&
parent->currently_placing_exp_snode != nullptr) {
// Reuse existing exponent
TI_ASSERT_INFO(parent->currently_placing_exp_snode_dtype == exp,
"QuantFloatTypes with shared exponents must have "
"exactly the same exponent type.");
new_exp_snode = parent->currently_placing_exp_snode;
} else {
auto &exp_node = parent->insert_children(SNodeType::place);
exp_node.dt = exp;
std::tie(exp_node.id_in_bit_struct, exp_node.bit_offset) =
parent->bit_struct_type_builder->add_member(exp);
exp_node.name = glb_var_expr->ident.raw_name() + "_exp";
new_exp_snode = &exp_node;
if (parent->placing_shared_exp) {
parent->currently_placing_exp_snode = new_exp_snode;
parent->currently_placing_exp_snode_dtype = exp;
}
}
}
auto &child = parent->insert_children(SNodeType::place);
glb_var_expr->set_snode(&child);
if (glb_var_expr->name == "") {
Expand All @@ -92,21 +67,8 @@ void place_child(Expr *expr_arg,
(*snode_to_exprs)[glb_var_expr->snode] = glb_var_expr;
child.dt = glb_var_expr->dt;
if (parent->bit_struct_type_builder) {
std::tie(child.id_in_bit_struct, child.bit_offset) =
child.id_in_bit_struct =
parent->bit_struct_type_builder->add_member(child.dt);
if (parent->placing_shared_exp) {
child.owns_shared_exponent = true;
parent->bit_struct_type_builder->set_member_owns_shared_exponent(
child.id_in_bit_struct);
}
if (new_exp_snode) {
child.exp_snode = new_exp_snode;
parent->bit_struct_type_builder->set_member_exponent(
child.id_in_bit_struct, new_exp_snode->id_in_bit_struct);
new_exp_snode->exponent_users.push_back(&child);
parent->bit_struct_type_builder->add_member_exponent_user(
new_exp_snode->id_in_bit_struct, child.id_in_bit_struct);
}
}
if (!offset.empty())
child.set_index_offsets(offset);
Expand Down
7 changes: 4 additions & 3 deletions taichi/transforms/optimize_bit_struct_stores.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ class CreateBitStructStores : public BasicStmtVisitor {

// We only handle bit_struct pointers here.

auto s = Stmt::make<BitStructStoreStmt>(get_ch->input_ptr,
std::vector<int>{get_ch->chid},
std::vector<Stmt *>{stmt->val});
auto s = Stmt::make<BitStructStoreStmt>(
get_ch->input_ptr,
std::vector<int>{get_ch->output_snode->id_in_bit_struct},
std::vector<Stmt *>{stmt->val});
stmt->replace_with(VecStatement(std::move(s)));
}
};
Expand Down
32 changes: 32 additions & 0 deletions tests/python/test_bit_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,35 @@ def assign():
assert x[i] == approx(i, abs=1e-3)
else:
assert x[i] == 0


@test_utils.test(require=ti.extension.quant_basic, debug=True)
def test_multiple_types():
f15 = ti.types.quant.float(exp=5, frac=10)
f18 = ti.types.quant.float(exp=5, frac=13)
u4 = ti.types.quant.int(bits=4, signed=False)

p = ti.field(dtype=f15)
q = ti.field(dtype=f18)
r = ti.field(dtype=u4)

cell = ti.root.dense(ti.i, 12).bit_struct(num_bits=32)
cell.place(p, q, shared_exponent=True)
cell.place(r)

@ti.kernel
def set_val():
for i in p:
p[i] = i * 3
q[i] = i * 2
r[i] = i

@ti.kernel
def verify_val():
for i in p:
assert p[i] == i * 3
assert q[i] == i * 2
assert r[i] == i

set_val()
verify_val()

0 comments on commit 2dda83e

Please sign in to comment.