Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt] Eliminate redundant mod in demote_dense_struct_fors under packed mode #6709

Merged
merged 5 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions taichi/transforms/demote_dense_struct_fors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,25 @@ void convert_to_range_for(OffloadedStmt *offloaded, bool packed) {
if (packed) { // no dependence on POT
for (int i = 0; i < (int)snodes.size(); i++) {
auto snode = snodes[i];
auto extracted =
generate_mod_x_div_y(&body_header, main_loop_var, total_n,
total_n / snode->num_cells_per_container);
Stmt *extracted = main_loop_var;
if (i != 0) { // first extraction doesn't need a mod
extracted = generate_mod(&body_header, extracted, total_n);
}
total_n /= snode->num_cells_per_container;
extracted = generate_div(&body_header, extracted, total_n);
bool is_first_extraction = true;
for (int j = 0; j < (int)physical_indices.size(); j++) {
auto p = physical_indices[j];
auto ext = snode->extractors[p];
auto index = generate_mod_x_div_y(
&body_header, extracted, ext.acc_shape * ext.shape, ext.acc_shape);
if (!ext.active)
continue;
Stmt *index = extracted;
if (is_first_extraction) { // first extraction doesn't need a mod
is_first_extraction = false;
} else {
index = generate_mod(&body_header, index, ext.acc_shape * ext.shape);
}
index = generate_div(&body_header, index, ext.acc_shape);
total_shape[p] /= ext.shape;
auto multiplier =
body_header.push_back<ConstStmt>(TypedConstant(total_shape[p]));
Expand Down
15 changes: 7 additions & 8 deletions taichi/transforms/scalar_pointer_lowerer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,16 @@ void ScalarPointerLowerer::run() {
const int prev = total_shape[k];
total_shape[k] /= snode->extractors[k].shape;
const int next = total_shape[k];
// Upon first extraction on axis k, "indices_[k_]" is the user
// coordinate on axis k and "prev" is the total shape of axis k.
// Unless it is an invalid out-of-bound access, we can assume
// "indices_[k_] < prev" so we don't need a mod here.
if (is_first_extraction[k]) {
// Upon first extraction on axis k, "indices_[k_]" is the user
// coordinate on axis k and "prev" is the total shape of axis k.
// Unless it is an invalid out-of-bound access, we can assume
// "indices_[k_] < prev" so we don't need a mod here.
auto const_next = lowered_->push_back<ConstStmt>(TypedConstant(next));
extracted = lowered_->push_back<BinaryOpStmt>(
BinaryOpType::div, indices_[k_], const_next);
extracted = indices_[k_];
} else {
extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next);
extracted = generate_mod(lowered_, indices_[k_], prev);
}
extracted = generate_div(lowered_, extracted, next);
} else {
const int end = start_bits[k];
start_bits[k] -= snode->extractors[k].num_bits;
Expand Down
13 changes: 8 additions & 5 deletions taichi/transforms/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

namespace taichi::lang {

Stmt *generate_mod_x_div_y(VecStatement *stmts, Stmt *num, int x, int y) {
auto const_x = stmts->push_back<ConstStmt>(TypedConstant(x));
auto mod_x = stmts->push_back<BinaryOpStmt>(BinaryOpType::mod, num, const_x);
auto const_y = stmts->push_back<ConstStmt>(TypedConstant(y));
return stmts->push_back<BinaryOpStmt>(BinaryOpType::div, mod_x, const_y);
Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y) {
auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y));
return stmts->push_back<BinaryOpStmt>(BinaryOpType::mod, x, const_stmt);
}

Stmt *generate_div(VecStatement *stmts, Stmt *x, int y) {
auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y));
return stmts->push_back<BinaryOpStmt>(BinaryOpType::div, x, const_stmt);
}

} // namespace taichi::lang
3 changes: 2 additions & 1 deletion taichi/transforms/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace taichi::lang {

Stmt *generate_mod_x_div_y(VecStatement *stmts, Stmt *num, int x, int y);
Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y);
Stmt *generate_div(VecStatement *stmts, Stmt *x, int y);

} // namespace taichi::lang