From 7c1b8507e061748275731a3e304b188f037e9d32 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Wed, 23 Nov 2022 21:03:15 +0800 Subject: [PATCH] [opt] Eliminate redundant mod in demote_dense_struct_fors under packed mode (#6709) Issue: #6660 ### Brief Summary This PR applies the same optimization in #6444 to the `demote_dense_struct_fors` pass. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../transforms/demote_dense_struct_fors.cpp | 20 ++++++++++++++----- taichi/transforms/scalar_pointer_lowerer.cpp | 15 +++++++------- taichi/transforms/utils.cpp | 13 +++++++----- taichi/transforms/utils.h | 3 ++- 4 files changed, 32 insertions(+), 19 deletions(-) diff --git a/taichi/transforms/demote_dense_struct_fors.cpp b/taichi/transforms/demote_dense_struct_fors.cpp index aa08a5c790396..3bf7dc2f73a82 100644 --- a/taichi/transforms/demote_dense_struct_fors.cpp +++ b/taichi/transforms/demote_dense_struct_fors.cpp @@ -68,15 +68,25 @@ void convert_to_range_for(OffloadedStmt *offloaded, bool packed) { if (packed) { // no dependence on POT for (int i = 0; i < (int)snodes.size(); i++) { auto snode = snodes[i]; - auto extracted = - generate_mod_x_div_y(&body_header, main_loop_var, total_n, - total_n / snode->num_cells_per_container); + Stmt *extracted = main_loop_var; + if (i != 0) { // first extraction doesn't need a mod + extracted = generate_mod(&body_header, extracted, total_n); + } total_n /= snode->num_cells_per_container; + extracted = generate_div(&body_header, extracted, total_n); + bool is_first_extraction = true; for (int j = 0; j < (int)physical_indices.size(); j++) { auto p = physical_indices[j]; auto ext = snode->extractors[p]; - auto index = generate_mod_x_div_y( - &body_header, extracted, ext.acc_shape * ext.shape, ext.acc_shape); + if (!ext.active) + continue; + Stmt *index = extracted; + if (is_first_extraction) { // first extraction doesn't need a mod + is_first_extraction = false; + } else { + index = generate_mod(&body_header, index, ext.acc_shape * ext.shape); + } + index = generate_div(&body_header, index, ext.acc_shape); total_shape[p] /= ext.shape; auto multiplier = body_header.push_back(TypedConstant(total_shape[p])); diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp index 2e041fc10d0a5..978279196e75a 100644 --- a/taichi/transforms/scalar_pointer_lowerer.cpp +++ b/taichi/transforms/scalar_pointer_lowerer.cpp @@ -82,17 +82,16 @@ void ScalarPointerLowerer::run() { const int prev = total_shape[k]; total_shape[k] /= snode->extractors[k].shape; const int next = total_shape[k]; + // Upon first extraction on axis k, "indices_[k_]" is the user + // coordinate on axis k and "prev" is the total shape of axis k. + // Unless it is an invalid out-of-bound access, we can assume + // "indices_[k_] < prev" so we don't need a mod here. if (is_first_extraction[k]) { - // Upon first extraction on axis k, "indices_[k_]" is the user - // coordinate on axis k and "prev" is the total shape of axis k. - // Unless it is an invalid out-of-bound access, we can assume - // "indices_[k_] < prev" so we don't need a mod here. - auto const_next = lowered_->push_back(TypedConstant(next)); - extracted = lowered_->push_back( - BinaryOpType::div, indices_[k_], const_next); + extracted = indices_[k_]; } else { - extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next); + extracted = generate_mod(lowered_, indices_[k_], prev); } + extracted = generate_div(lowered_, extracted, next); } else { const int end = start_bits[k]; start_bits[k] -= snode->extractors[k].num_bits; diff --git a/taichi/transforms/utils.cpp b/taichi/transforms/utils.cpp index 85cc9ba9142ee..da231e856698f 100644 --- a/taichi/transforms/utils.cpp +++ b/taichi/transforms/utils.cpp @@ -2,11 +2,14 @@ namespace taichi::lang { -Stmt *generate_mod_x_div_y(VecStatement *stmts, Stmt *num, int x, int y) { - auto const_x = stmts->push_back(TypedConstant(x)); - auto mod_x = stmts->push_back(BinaryOpType::mod, num, const_x); - auto const_y = stmts->push_back(TypedConstant(y)); - return stmts->push_back(BinaryOpType::div, mod_x, const_y); +Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y) { + auto const_stmt = stmts->push_back(TypedConstant(y)); + return stmts->push_back(BinaryOpType::mod, x, const_stmt); +} + +Stmt *generate_div(VecStatement *stmts, Stmt *x, int y) { + auto const_stmt = stmts->push_back(TypedConstant(y)); + return stmts->push_back(BinaryOpType::div, x, const_stmt); } } // namespace taichi::lang diff --git a/taichi/transforms/utils.h b/taichi/transforms/utils.h index 7ae4535b58c2e..6787364b9bf8d 100644 --- a/taichi/transforms/utils.h +++ b/taichi/transforms/utils.h @@ -2,6 +2,7 @@ namespace taichi::lang { -Stmt *generate_mod_x_div_y(VecStatement *stmts, Stmt *num, int x, int y); +Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y); +Stmt *generate_div(VecStatement *stmts, Stmt *x, int y); } // namespace taichi::lang