Skip to content

Commit

Permalink
[opt] Turn mod/div into bit_and/bit_shr if possible in packed mode (t…
Browse files Browse the repository at this point in the history
…aichi-dev#6718)

Issue: taichi-dev#6660

### Brief Summary

After this PR, the two main goals in taichi-dev#6660,

> - Make the runtime overhead of SNode access always the same for both
modes if the limitation is obeyed.
> - Make the runtime overhead of struct for on a SNode whose path to
root contains no non-power-of-two division always the same for both
modes.

are both achieved.

See comments in the code for implementation details.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 7b3f2b3 commit ad7e166
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 5 deletions.
4 changes: 3 additions & 1 deletion taichi/codegen/llvm/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) {
auto prev = tlctx_->get_constant(snode->extractors[i].acc_shape *
snode->extractors[i].shape);
auto next = tlctx_->get_constant(snode->extractors[i].acc_shape);
addition = builder.CreateSDiv(builder.CreateSRem(l, prev), next);
// Use UDiv/URem instead of SDiv/SRem so that LLVM can optimize them
// into bitwise operations when the divisor is a power of two.
addition = builder.CreateUDiv(builder.CreateURem(l, prev), next);
}
auto in = call(&builder, "PhysicalCoordinates_get_val", inp_coords,
tlctx_->get_constant(i));
Expand Down
9 changes: 9 additions & 0 deletions taichi/transforms/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@
namespace taichi::lang {

Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y) {
if (bit::is_power_of_two(y)) {
auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y - 1));
return stmts->push_back<BinaryOpStmt>(BinaryOpType::bit_and, x, const_stmt);
}
auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y));
return stmts->push_back<BinaryOpStmt>(BinaryOpType::mod, x, const_stmt);
}

Stmt *generate_div(VecStatement *stmts, Stmt *x, int y) {
if (bit::is_power_of_two(y)) {
auto const_stmt =
stmts->push_back<ConstStmt>(TypedConstant(bit::log2int(y)));
return stmts->push_back<BinaryOpStmt>(BinaryOpType::bit_shr, x, const_stmt);
}
auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y));
return stmts->push_back<BinaryOpStmt>(BinaryOpType::div, x, const_stmt);
}
Expand Down
3 changes: 3 additions & 0 deletions taichi/transforms/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

namespace taichi::lang {

// These two helper functions are targeting cases where x is assumed
// non-negative but with a signed type so no automatic transformation to bitwise
// operations can be applied in other compiler passes.
Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y);
Stmt *generate_div(VecStatement *stmts, Stmt *x, int y);

Expand Down
11 changes: 7 additions & 4 deletions tests/cpp/transforms/scalar_pointer_lowerer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,16 @@ TEST_F(ScalarPointerLowererTest, Basic) {
}
}

TEST(ScalarPointerLowerer, EliminateMod) {
TEST(ScalarPointerLowerer, EliminateModDiv) {
const bool kPacked = true;
IRBuilder builder;
VecStatement lowered;
Stmt *index = builder.get_int32(2);
auto root = std::make_unique<SNode>(/*depth=*/0, SNodeType::root);
SNode *dense_1 = &(root->dense({Axis{2}, Axis{1}}, /*size=*/7, kPacked, ""));
SNode *dense_2 = &(root->dense({Axis{1}}, /*size=*/3, kPacked, ""));
SNode *dense_3 = &(dense_2->dense({Axis{0}}, /*size=*/5, kPacked, ""));
SNode *dense_3 =
&(dense_2->dense({Axis{0}, Axis{1}}, /*size=*/{5, 8}, kPacked, ""));
SNode *leaf_1 = &(dense_1->insert_children(SNodeType::place));
SNode *leaf_2 = &(dense_3->insert_children(SNodeType::place));
LowererImpl lowerer_1{leaf_1,
Expand All @@ -129,8 +130,10 @@ TEST(ScalarPointerLowerer, EliminateMod) {
kPacked};
lowerer_2.run();
for (int i = 0; i < lowered.size(); i++) {
ASSERT_FALSE(lowered[i]->is<BinaryOpStmt>() &&
lowered[i]->as<BinaryOpStmt>()->op_type == BinaryOpType::mod);
ASSERT_FALSE(
lowered[i]->is<BinaryOpStmt>() &&
(lowered[i]->as<BinaryOpStmt>()->op_type == BinaryOpType::mod ||
lowered[i]->as<BinaryOpStmt>()->op_type == BinaryOpType::div));
}
}
} // namespace
Expand Down

0 comments on commit ad7e166

Please sign in to comment.