Skip to content

Commit

Permalink
[opt] Remove redundant mod for SNode access under packed mode
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Oct 26, 2022
1 parent c273503 commit c60c70a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
14 changes: 13 additions & 1 deletion taichi/transforms/scalar_pointer_lowerer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ void ScalarPointerLowerer::run() {
total_shape[j] *= s->extractors[j].shape;
}
}
std::array<bool, taichi_max_num_indices> is_first_extraction;
is_first_extraction.fill(true);

if (path_length_ == 0)
return;
Expand All @@ -80,7 +82,17 @@ void ScalarPointerLowerer::run() {
const int prev = total_shape[k];
total_shape[k] /= snode->extractors[k].shape;
const int next = total_shape[k];
extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next);
if (is_first_extraction[k]) {
// Upon first extraction on axis k, "indices_[k_]" is the user
// coordinate on axis k and "prev" is the total shape of axis k.
// Unless it is an invalid out-of-bound access, we can assume
// "indices_[k_] < prev" so we don't need a mod here.
auto const_next = lowered_->push_back<ConstStmt>(TypedConstant(next));
extracted = lowered_->push_back<BinaryOpStmt>(BinaryOpType::div, indices_[k_], const_next);
is_first_extraction[k] = false;
} else {
extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next);
}
} else {
const int end = start_bits[k];
start_bits[k] -= snode->extractors[k].num_bits;
Expand Down
22 changes: 22 additions & 0 deletions tests/cpp/transforms/scalar_pointer_lowerer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,27 @@ TEST_F(ScalarPointerLowererTest, Basic) {
}
}

// Lowers two SNode accesses in packed mode and checks that none of the
// statements emitted into the shared output list is a BinaryOpStmt whose
// op_type is BinaryOpType::mod.
TEST(ScalarPointerLowerer, EliminateMod) {
  constexpr bool kPacked = true;
  IRBuilder builder;
  VecStatement lowered;
  // A single constant is reused as the coordinate on every indexed axis.
  Stmt *coord = builder.get_int32(2);

  // SNode tree under test:
  //   root -> dense(axes 2 and 1, size 7) -> place
  //   root -> dense(axis 1, size 3) -> dense(axis 0, size 5) -> place
  auto root = std::make_unique<SNode>(/*depth=*/0, SNodeType::root);
  SNode &two_axis_dense = root->dense({Axis{2}, Axis{1}}, /*size=*/7, kPacked);
  SNode &outer_dense = root->dense({Axis{1}}, /*size=*/3, kPacked);
  SNode &inner_dense = outer_dense.dense({Axis{0}}, /*size=*/5, kPacked);
  SNode &shallow_leaf = two_axis_dense.insert_children(SNodeType::place);
  SNode &nested_leaf = inner_dense.insert_children(SNodeType::place);

  // Both lowerers append into the same `lowered` list.
  LowererImpl shallow_lowerer{&shallow_leaf, {coord, coord},
                              SNodeOpType::undefined,
                              /*is_bit_vectorized=*/false, &lowered, kPacked};
  shallow_lowerer.run();
  LowererImpl nested_lowerer{&nested_leaf, {coord}, SNodeOpType::undefined,
                             /*is_bit_vectorized=*/false, &lowered, kPacked};
  nested_lowerer.run();

  // The list is not modified while scanning, so the count can be read once.
  const int num_lowered = static_cast<int>(lowered.size());
  for (int pos = 0; pos < num_lowered; ++pos) {
    const auto &stmt = lowered[pos];
    const bool is_mod =
        stmt->is<BinaryOpStmt>() &&
        stmt->as<BinaryOpStmt>()->op_type == BinaryOpType::mod;
    ASSERT_FALSE(is_mod);
  }
}
} // namespace
} // namespace taichi::lang

0 comments on commit c60c70a

Please sign in to comment.