[type] Local adder structure #2136
@@ -16,6 +16,7 @@ class BitLoopVectorize : public IRVisitor {
  bool in_struct_for_loop;
  StructForStmt *loop_stmt;
  PrimitiveType *bit_array_physical_type;
  std::unordered_map<Stmt *, std::vector<Stmt *>> transformed_atomics;

  BitLoopVectorize() {
    allow_undefined_visitor = true;

@@ -155,17 +156,180 @@ class BitLoopVectorize : public IRVisitor {
    bit_array_physical_type = nullptr;
  }

  void visit(BinaryOpStmt *stmt) override {
    // vectorize cmp_eq and bit_and between
    // vectorized data (local adder / array elems) and constants
    if (in_struct_for_loop && bit_vectorize != 1) {
      if (stmt->op_type == BinaryOpType::bit_and) {
        // if the rhs is a bit-vectorized stmt and the lhs is a const 1
        // (usually generated by a boolean expr), we simply replace
        // the stmt with its rhs
        int lhs_val = get_constant_value(stmt->lhs);
        if (lhs_val == 1) {
          if (auto rhs = stmt->rhs->cast<BinaryOpStmt>();
              rhs && rhs->is_bit_vectorized) {
            stmt->replace_with(stmt->rhs);
          }
        }
      } else if (stmt->op_type == BinaryOpType::cmp_eq) {
        if (auto lhs = stmt->lhs->cast<GlobalLoadStmt>()) {
          // case 0: lhs is a vectorized global load from the bit array
          if (auto ptr = lhs->ptr->cast<GlobalPtrStmt>();
              ptr && ptr->is_bit_vectorized) {
            int32 rhs_val = get_constant_value(stmt->rhs);
            // TODO: we only handle 1 for now; 0 should be easy to implement
            // with a bit_not on the original bit pattern
            TI_ASSERT(rhs_val == 1);
            // cmp_eq with 1 yields the bit pattern itself

            // to pass CFG analysis and mark the stmt vectorized,
            // create a dummy lhs + 0 here
            auto zero = std::make_unique<ConstStmt>(TypedConstant(0));
            auto add = std::make_unique<BinaryOpStmt>(BinaryOpType::add,
                                                      stmt->lhs, zero.get());
            add->is_bit_vectorized = true;
            // modify IR
            auto zero_p = zero.get();
            stmt->insert_before_me(std::move(zero));
            stmt->replace_with(add.get());
            zero_p->insert_after_me(std::move(add));
          }
        } else if (auto lhs = stmt->lhs->cast<LocalLoadStmt>()) {
          // case 1: lhs is a local load from a local adder structure
          auto it = transformed_atomics.find(lhs->ptr[0].var);
          if (it != transformed_atomics.end()) {
            int32 rhs_val = get_constant_value(stmt->rhs);
            // TODO: we only handle 2 and 3 for now; the other cases should be
            // implemented in a similar fashion
            TI_ASSERT(rhs_val == 2 || rhs_val == 3);
            // the bit patterns are 010 and 011 respectively
            auto &buffer_vec = it->second;
            Stmt *a = buffer_vec[0], *b = buffer_vec[1], *c = buffer_vec[2];
            // load all three buffers
            auto load_a = std::make_unique<LocalLoadStmt>(LocalAddress(a, 0));
            auto load_b = std::make_unique<LocalLoadStmt>(LocalAddress(b, 0));
            auto load_c = std::make_unique<LocalLoadStmt>(LocalAddress(c, 0));
            // compute not_a first
            auto not_a = std::make_unique<UnaryOpStmt>(UnaryOpType::bit_not,
                                                       load_a.get());
            // b should always be itself, so do nothing
            // compute not_c
            auto not_c = std::make_unique<UnaryOpStmt>(UnaryOpType::bit_not,
                                                       load_c.get());
            // bit_and all three patterns
            auto and_a_b = std::make_unique<BinaryOpStmt>(
                BinaryOpType::bit_and, not_a.get(), load_b.get());
            auto and_b_c = std::make_unique<BinaryOpStmt>(
                BinaryOpType::bit_and, and_a_b.get(),
                rhs_val == 2 ? (Stmt *)(not_c.get()) : (Stmt *)(load_c.get()));
            // mark the last stmt as vectorized
            and_b_c->is_bit_vectorized = true;
            // modify IR
            auto and_a_b_p = and_a_b.get();
            stmt->insert_before_me(std::move(load_a));

Review comment: A more elegant way to do this: use …

Reply: I find to do it the elegant way we may still need …

            stmt->insert_before_me(std::move(load_b));
            stmt->insert_before_me(std::move(load_c));
            stmt->insert_before_me(std::move(not_a));
            stmt->insert_before_me(std::move(not_c));
            stmt->insert_before_me(std::move(and_a_b));
            stmt->replace_with(and_b_c.get());
            and_a_b_p->insert_after_me(std::move(and_b_c));
          }
        }
      }
    }
  }

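In case 1, the comparison against 2 or 3 is evaluated lane-wise on whole machine words: a per-lane counter equal to 2 (binary 010) has a = 0, b = 1, c = 0, so the predicate for every lane at once is ~a & b & ~c, and ~a & b & c for 3. The standalone sketch below shows the same arithmetic on plain uint32_t bit planes; it only illustrates the pattern the pass emits, it is not Taichi code, and the function name and test values are invented for the example.

#include <cstdint>
#include <cassert>

// Illustrative only: decode "per-lane counter == rhs_val" from three bit
// planes a (highest bit), b (middle bit), c (lowest bit). Each uint32_t packs
// 32 independent one-bit lanes, mirroring the physical type of the bit array.
uint32_t lanes_equal(uint32_t a, uint32_t b, uint32_t c, int rhs_val) {
  assert(rhs_val == 2 || rhs_val == 3);  // same restriction as the TI_ASSERT
  uint32_t not_a = ~a;                   // highest bit must be 0 (counter in 0..3)
  uint32_t and_a_b = not_a & b;          // middle bit must be 1 (counter is 2 or 3)
  return rhs_val == 2 ? (and_a_b & ~c)   // 010: lowest bit 0
                      : (and_a_b & c);   // 011: lowest bit 1
}

int main() {
  // lane 0 holds the value 2 (a=0, b=1, c=0); lane 1 holds 3 (a=0, b=1, c=1)
  uint32_t a = 0b00, b = 0b11, c = 0b10;
  assert(lanes_equal(a, b, c, 2) == 0b01);
  assert(lanes_equal(a, b, c, 3) == 0b10);
  return 0;
}
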
  void visit(AtomicOpStmt *stmt) override {
    DataType dt(bit_array_physical_type);
    if (in_struct_for_loop && bit_vectorize != 1 &&
        stmt->op_type == AtomicOpType::add) {
      auto it = transformed_atomics.find(stmt->dest);
      // process a transformed atomic stmt
      if (it != transformed_atomics.end()) {
        auto &buffer_vec = it->second;
        transform_atomic_add(buffer_vec, stmt, dt);
      } else {
        // alloc three buffers a, b, c
        auto alloc_a = std::make_unique<AllocaStmt>(dt);
        auto alloc_b = std::make_unique<AllocaStmt>(dt);
        auto alloc_c = std::make_unique<AllocaStmt>(dt);
        std::vector<Stmt *> buffer_vec{alloc_a.get(), alloc_b.get(),
                                       alloc_c.get()};
        transformed_atomics[stmt->dest] = buffer_vec;
        // modify IR
        stmt->insert_before_me(std::move(alloc_a));
        stmt->insert_before_me(std::move(alloc_b));
        stmt->insert_before_me(std::move(alloc_c));
        transform_atomic_add(buffer_vec, stmt, dt);
      }
    }
  }

  static void run(IRNode *node) {
    BitLoopVectorize inst;
    node->accept(&inst);
  }

 private:
  void transform_atomic_add(const std::vector<Stmt *> &buffer_vec,
                            AtomicOpStmt *stmt,
                            DataType &dt) {
    // To transform an atomic add on a vectorized subarray of a bit array,
    // we use a local adder with three buffers (*a*, *b*, *c*) of the same
    // physical type as the original bit array. Each bit in *a* holds the
    // highest bit of the result, *b* the second bit, and *c* the lowest bit.
    // To add *d* to the subarray, we use bit_xor and bit_and to compute the
    // sum and the carry.
    Stmt *a = buffer_vec[0], *b = buffer_vec[1], *c = buffer_vec[2];
    auto load_c = std::make_unique<LocalLoadStmt>(LocalAddress(c, 0));
    auto carry_c = std::make_unique<BinaryOpStmt>(BinaryOpType::bit_and,
                                                  load_c.get(), stmt->val);
    auto sum_c =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, c, stmt->val);
    auto load_b = std::make_unique<LocalLoadStmt>(LocalAddress(b, 0));
    auto carry_b = std::make_unique<BinaryOpStmt>(BinaryOpType::bit_and,
                                                  load_b.get(), carry_c.get());
    auto sum_b =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, b, carry_c.get());
    // for a, we do not need to compute its carry
    auto sum_a =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, a, carry_b.get());
    // modify IR
    stmt->insert_before_me(std::move(load_c));
    stmt->insert_before_me(std::move(carry_c));
    stmt->insert_before_me(std::move(sum_c));
    stmt->insert_before_me(std::move(load_b));
    stmt->insert_before_me(std::move(carry_b));
    stmt->insert_before_me(std::move(sum_b));
    stmt->insert_before_me(std::move(sum_a));

Review comment on lines +299 to +305: Ditto.

    // there is no need to replace the stmt here as we
    // will replace it manually later
  }

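For reference, the same per-lane adder can be written out on plain uint32_t words instead of IR statements. This is only a sketch of the carry arithmetic mirrored by the load / bit_and / bit_xor sequence above; the struct name and test values are invented for the example.

#include <cstdint>
#include <cassert>

// Illustrative only: a 3-bit per-lane counter packed into three words.
// Bit i of (a, b, c) holds the highest, middle, and lowest bit of lane i.
struct LocalAdder {
  uint32_t a = 0, b = 0, c = 0;

  // Add d (one bit per lane) to every lane whose bit in d is set.
  void add(uint32_t d) {
    uint32_t carry_c = c & d;        // carry out of the lowest bit
    c ^= d;                          // sum of the lowest bit
    uint32_t carry_b = b & carry_c;  // carry out of the middle bit
    b ^= carry_c;                    // sum of the middle bit
    a ^= carry_b;                    // highest bit: no further carry needed
  }
};

int main() {
  LocalAdder adder;
  adder.add(0b1);  // lane 0: counter becomes 001
  adder.add(0b1);  // lane 0: counter becomes 010
  adder.add(0b1);  // lane 0: counter becomes 011
  assert(adder.a == 0b0 && adder.b == 0b1 && adder.c == 0b1);
  return 0;
}

Note that the order matters: the carry out of c has to be computed from the old value of c before the xor overwrites it, which is why the pass emits the LocalLoadStmt and the bit_and before the corresponding bit_xor.
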
  int32 get_constant_value(Stmt *stmt) {
    int32 val = -1;
    // the stmt could be a cast stmt
    if (auto cast_stmt = stmt->cast<UnaryOpStmt>();
        cast_stmt && cast_stmt->is_cast() &&
        cast_stmt->op_type == UnaryOpType::cast_value) {
      stmt = cast_stmt->operand;
    }
    if (auto constant_stmt = stmt->cast<ConstStmt>();
        constant_stmt &&
        constant_stmt->val[0].dt->is_primitive(PrimitiveTypeID::i32)) {
      val = constant_stmt->val[0].val_i32;
    }
    return val;
  }
};

namespace irpass {

void bit_loop_vectorize(IRNode *root) {
  TI_AUTO_PROF;
-  return BitLoopVectorize::run(root);
+  BitLoopVectorize::run(root);
+  die(root);
}

}  // namespace irpass

Review comment: This is a little too much intrusion into the existing system. It seems to me that is_bit_vectorized is only used in the bit_loop_vectorize pass - maybe you can use an std::unordered_map<Stmt *, bool> member variable in class BitLoopVectorize, instead of adding a new field in class BinaryOpStmt? (Just like llvm_val in the LLVM codegens.)

Reply: is_bit_vectorized is used only in the bit_loop_vectorize pass when tagged on BinaryOpStmt, and that part should be replaced with some pass-scope data structure, just as you suggested. But GlobalPtrStmt and GetChStmt need the tag to get through later passes, including lower_access and type_check.

Reply: Right, that's what I meant: we can use a pass-scope data structure just for BinaryOpStmt::is_bit_vectorized. Given we are rushing for the deadline, it's fine that we don't do it now.
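
A minimal sketch of the suggested alternative, assuming the flag is only needed while this pass runs: keep the per-statement flag in a member map of the pass object instead of as a field on BinaryOpStmt. The class and helper names below are invented for illustration, and this is a generic analogue rather than actual Taichi code; as noted above, GlobalPtrStmt and GetChStmt would keep their real is_bit_vectorized fields because lower_access and type_check still read them.

#include <unordered_map>
#include <cassert>

// Generic analogue of the suggestion: instead of an is_bit_vectorized field
// on the statement class, the pass keeps a side table keyed by statement
// pointer, so the flag never leaks outside the pass.
struct Stmt {  // stand-in for the real IR statement base class
  virtual ~Stmt() = default;
};
struct BinaryOpStmt : Stmt {};

class BitLoopVectorizeLike {
 public:
  void mark_bit_vectorized(Stmt *stmt) {
    is_bit_vectorized_[stmt] = true;
  }
  bool is_bit_vectorized(Stmt *stmt) const {
    auto it = is_bit_vectorized_.find(stmt);
    return it != is_bit_vectorized_.end() && it->second;
  }

 private:
  // pass-scope flag, discarded when the pass object goes away
  std::unordered_map<Stmt *, bool> is_bit_vectorized_;
};

int main() {
  BinaryOpStmt add, cmp;
  BitLoopVectorizeLike pass;
  pass.mark_bit_vectorized(&add);
  assert(pass.is_bit_vectorized(&add));
  assert(!pass.is_bit_vectorized(&cmp));
  return 0;
}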