From ad7e1661433216c851bcaf0200ef74475b61913f Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 24 Nov 2022 09:33:15 +0800 Subject: [PATCH] [opt] Turn mod/div into bit_and/bit_shr if possible in packed mode (#6718) Issue: #6660 ### Brief Summary After this PR, the two main goals in #6660, > - Make the runtime overhead of SNode access always the same for both modes if the limitation is obeyed. > - Make the runtime overhead of struct for on a SNode whose path to root contains no non-power-of-two division always the same for both modes. are both achieved. See comments in the code for implementation details. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/codegen/llvm/struct_llvm.cpp | 4 +++- taichi/transforms/utils.cpp | 9 +++++++++ taichi/transforms/utils.h | 3 +++ tests/cpp/transforms/scalar_pointer_lowerer_test.cpp | 11 +++++++---- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/taichi/codegen/llvm/struct_llvm.cpp b/taichi/codegen/llvm/struct_llvm.cpp index 27caf4ca65ba6..6a7657c375848 100644 --- a/taichi/codegen/llvm/struct_llvm.cpp +++ b/taichi/codegen/llvm/struct_llvm.cpp @@ -175,7 +175,9 @@ void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) { auto prev = tlctx_->get_constant(snode->extractors[i].acc_shape * snode->extractors[i].shape); auto next = tlctx_->get_constant(snode->extractors[i].acc_shape); - addition = builder.CreateSDiv(builder.CreateSRem(l, prev), next); + // Use UDiv/URem instead of SDiv/SRem so that LLVM can optimize them + // into bitwise operations when the divisor is a power of two. 
+ addition = builder.CreateUDiv(builder.CreateURem(l, prev), next); } auto in = call(&builder, "PhysicalCoordinates_get_val", inp_coords, tlctx_->get_constant(i)); diff --git a/taichi/transforms/utils.cpp b/taichi/transforms/utils.cpp index da231e856698f..13fd326155f4b 100644 --- a/taichi/transforms/utils.cpp +++ b/taichi/transforms/utils.cpp @@ -3,11 +3,20 @@ namespace taichi::lang { Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y) { + if (bit::is_power_of_two(y)) { + auto const_stmt = stmts->push_back(TypedConstant(y - 1)); + return stmts->push_back(BinaryOpType::bit_and, x, const_stmt); + } auto const_stmt = stmts->push_back(TypedConstant(y)); return stmts->push_back(BinaryOpType::mod, x, const_stmt); } Stmt *generate_div(VecStatement *stmts, Stmt *x, int y) { + if (bit::is_power_of_two(y)) { + auto const_stmt = + stmts->push_back(TypedConstant(bit::log2int(y))); + return stmts->push_back(BinaryOpType::bit_shr, x, const_stmt); + } auto const_stmt = stmts->push_back(TypedConstant(y)); return stmts->push_back(BinaryOpType::div, x, const_stmt); } diff --git a/taichi/transforms/utils.h b/taichi/transforms/utils.h index 6787364b9bf8d..470c0de53d8fb 100644 --- a/taichi/transforms/utils.h +++ b/taichi/transforms/utils.h @@ -2,6 +2,9 @@ namespace taichi::lang { +// These two helper functions are targeting cases where x is assumed +// non-negative but with a signed type so no automatic transformation to bitwise +// operations can be applied in other compiler passes. 
Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y); Stmt *generate_div(VecStatement *stmts, Stmt *x, int y); diff --git a/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp b/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp index ea6b20c197e6e..4d5935d0a408d 100644 --- a/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp +++ b/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp @@ -103,7 +103,7 @@ TEST_F(ScalarPointerLowererTest, Basic) { } } -TEST(ScalarPointerLowerer, EliminateMod) { +TEST(ScalarPointerLowerer, EliminateModDiv) { const bool kPacked = true; IRBuilder builder; VecStatement lowered; @@ -111,7 +111,8 @@ TEST(ScalarPointerLowerer, EliminateMod) { auto root = std::make_unique<SNode>(/*depth=*/0, SNodeType::root); SNode *dense_1 = &(root->dense({Axis{2}, Axis{1}}, /*size=*/7, kPacked, "")); SNode *dense_2 = &(root->dense({Axis{1}}, /*size=*/3, kPacked, "")); - SNode *dense_3 = &(dense_2->dense({Axis{0}}, /*size=*/5, kPacked, "")); + SNode *dense_3 = + &(dense_2->dense({Axis{0}, Axis{1}}, /*size=*/{5, 8}, kPacked, "")); SNode *leaf_1 = &(dense_1->insert_children(SNodeType::place)); SNode *leaf_2 = &(dense_3->insert_children(SNodeType::place)); LowererImpl lowerer_1{leaf_1, @@ -129,8 +130,10 @@ TEST(ScalarPointerLowerer, EliminateMod) { kPacked}; lowerer_2.run(); for (int i = 0; i < lowered.size(); i++) { - ASSERT_FALSE(lowered[i]->is<BinaryOpStmt>() && - lowered[i]->as<BinaryOpStmt>()->op_type == BinaryOpType::mod); + ASSERT_FALSE( + lowered[i]->is<BinaryOpStmt>() && + (lowered[i]->as<BinaryOpStmt>()->op_type == BinaryOpType::mod || + lowered[i]->as<BinaryOpStmt>()->op_type == BinaryOpType::div)); } } } // namespace