// Patch summary (reviewer preamble; this text precedes the first "diff --git" header).
//
// Files touched:
//   * taichi/transforms/scalar_pointer_lowerer.cpp
//   * tests/cpp/transforms/scalar_pointer_lowerer_test.cpp
//
// ScalarPointerLowerer::run():
//   * Adds a per-axis flag array `is_first_extraction`, filled with `true`.
//   * In the branch that previously always called
//     generate_mod_x_div_y(lowered_, indices_[k_], prev, next), the first
//     extraction on axis k now emits only a constant `next` and a
//     BinaryOpType::div of indices_[k_] by that constant, then clears the flag.
//     Later extractions on the same axis still use generate_mod_x_div_y.
//   * Rationale given in the patch's own comment: on first extraction the index
//     is the user coordinate and `prev` is the total shape of the axis, so the
//     mod is redundant unless the access is out of bounds.
//
// New test EliminateMod:
//   * Builds a root with two dense sub-trees, runs two LowererImpl instances
//     into the same `lowered` statement list, and asserts that no statement in
//     it has op_type == BinaryOpType::mod.
//
// NOTE(review): this copy of the patch has its line breaks collapsed, and some
//   template argument lists appear to be missing (`std::array is_first_extraction;`,
//   `std::make_unique(`, `->is()`, `->as()`, `lowered_->push_back(`). As
//   written, `std::array` without arguments or an initializer and
//   `std::make_unique(` without a type cannot be deduced by the compiler.
//   Confirm against the upstream commit before applying.
// NOTE(review): with the mod removed, an out-of-range index on the first
//   extraction is no longer wrapped into range; confirm that bounds are
//   validated elsewhere for the cases where this matters.
// NOTE(review): the test loop compares `int i` with `lowered.size()`; if
//   size() returns an unsigned type this is a signed/unsigned comparison.
//   TODO confirm and consider matching the index type.
diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp index f4e41bd787df6..e4c27959e9aec 100644 --- a/taichi/transforms/scalar_pointer_lowerer.cpp +++ b/taichi/transforms/scalar_pointer_lowerer.cpp @@ -55,6 +55,8 @@ void ScalarPointerLowerer::run() { total_shape[j] *= s->extractors[j].shape; } } + std::array is_first_extraction; + is_first_extraction.fill(true); if (path_length_ == 0) return; @@ -80,7 +82,18 @@ void ScalarPointerLowerer::run() { const int prev = total_shape[k]; total_shape[k] /= snode->extractors[k].shape; const int next = total_shape[k]; - extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next); + if (is_first_extraction[k]) { + // Upon first extraction on axis k, "indices_[k_]" is the user + // coordinate on axis k and "prev" is the total shape of axis k. + // Unless it is an invalid out-of-bound access, we can assume + // "indices_[k_] < prev" so we don't need a mod here. + auto const_next = lowered_->push_back(TypedConstant(next)); + extracted = lowered_->push_back( + BinaryOpType::div, indices_[k_], const_next); + is_first_extraction[k] = false; + } else { + extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next); + } } else { const int end = start_bits[k]; start_bits[k] -= snode->extractors[k].num_bits; diff --git a/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp b/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp index 939c4ffb34b42..dbfeabfe0c22c 100644 --- a/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp +++ b/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp @@ -103,5 +103,35 @@ TEST_F(ScalarPointerLowererTest, Basic) { } } +TEST(ScalarPointerLowerer, EliminateMod) { + const bool kPacked = true; + IRBuilder builder; + VecStatement lowered; + Stmt *index = builder.get_int32(2); + auto root = std::make_unique(/*depth=*/0, SNodeType::root); + SNode *dense_1 = &(root->dense({Axis{2}, Axis{1}}, /*size=*/7, kPacked)); + SNode *dense_2 = 
&(root->dense({Axis{1}}, /*size=*/3, kPacked)); + SNode *dense_3 = &(dense_2->dense({Axis{0}}, /*size=*/5, kPacked)); + SNode *leaf_1 = &(dense_1->insert_children(SNodeType::place)); + SNode *leaf_2 = &(dense_3->insert_children(SNodeType::place)); + LowererImpl lowerer_1{leaf_1, + {index, index}, + SNodeOpType::undefined, + /*is_bit_vectorized=*/false, + &lowered, + kPacked}; + lowerer_1.run(); + LowererImpl lowerer_2{leaf_2, + {index}, + SNodeOpType::undefined, + /*is_bit_vectorized=*/false, + &lowered, + kPacked}; + lowerer_2.run(); + for (int i = 0; i < lowered.size(); i++) { + ASSERT_FALSE(lowered[i]->is() && + lowered[i]->as()->op_type == BinaryOpType::mod); + } +} } // namespace } // namespace taichi::lang