[opt] Eliminate redundant mod for SNode access under packed mode (#6444)
Issue: #6219

### Brief Summary
 
For the following example,
```python
ti.root.dense(ti.i, 10).dense(ti.i, 30).place(x)
```
Under packed mode, accessing `x[105]` computes `105 mod (10 * 30) div 30 = 3`
as the coordinate in the first dense SNode and `105 mod 30 = 15` as the
coordinate in the second dense SNode. The `mod (10 * 30)` is unnecessary
because the user coordinate (`105`) is always less than the total shape
(`10 * 30`) of the axis, unless the access is out of bounds. This PR
eliminates this redundant `mod` on the first coordinate extraction on each
axis.
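
The arithmetic above can be checked with a small plain-Python sketch. It does not use Taichi; `extract_with_mod` and `extract_skip_first_mod` are hypothetical helper names that only model the `mod`/`div` formula for a single axis split across several dense SNodes, not the actual lowering pass.

```python
def extract_with_mod(index, shapes):
    """Per-level coordinates using `index mod prev div next` at every level."""
    total = 1
    for s in shapes:
        total *= s
    coords = []
    for s in shapes:
        prev = total      # total shape still to be split at this level
        total //= s       # total shape left for the inner levels ("next")
        coords.append(index % prev // total)
    return coords


def extract_skip_first_mod(index, shapes):
    """Same as above, but the first level uses only `index div next`."""
    total = 1
    for s in shapes:
        total *= s
    coords = []
    is_first = True
    for s in shapes:
        prev = total
        total //= s
        if is_first:
            # Valid only when index < prev, i.e. the access is in bounds.
            coords.append(index // total)
            is_first = False
        else:
            coords.append(index % prev // total)
    return coords


shapes = [10, 30]  # models ti.root.dense(ti.i, 10).dense(ti.i, 30)
print(extract_with_mod(105, shapes))
print(extract_skip_first_mod(105, shapes))
```

For every in-bounds index the two functions agree, which is why the first `mod` can be dropped. They differ only for out-of-bounds indices such as `305`, where the access is already invalid.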

On my local machine, the benchmark script in #6219 takes:
- `0.030s` for `packed=False`,
- `0.039s` for `packed=True` before this PR,
- `0.007s` for `packed=True` after this PR (even faster than
  `packed=False`, because `packed=False` still generates a
  `BitExtractStmt`).

This optimization ensures that no `mod` will be generated for accessing
`x[i, j]` in common use cases like
```python
x = ti.field(ti.i32, shape=(100, 200), order='ji')
# or equivalently
x = ti.field(ti.i32)
ti.root.dense(ti.j, 200).dense(ti.i, 100).place(x)
```
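
To see why these layouts end up with no `mod`, here is a simplified plain-Python model. It assumes, as a simplification, that every extraction on an axis other than the first one emits exactly one `mod`; the real lowerer may simplify further. `count_mods` is a hypothetical helper, not a Taichi API.

```python
def count_mods(levels, skip_first_mod):
    """Count modelled `mod` operations for one element access.

    levels: list of (axis, size) pairs, outermost dense SNode first.
    skip_first_mod: True models the behaviour after this PR.
    """
    seen_axes = set()
    mods = 0
    for axis, _size in levels:
        is_first = axis not in seen_axes
        seen_axes.add(axis)
        if skip_first_mod and is_first:
            continue  # first extraction on this axis: div only
        mods += 1     # any other extraction: mod followed by div
    return mods


# ti.root.dense(ti.j, 200).dense(ti.i, 100): each axis is split once.
ji_layout = [("j", 200), ("i", 100)]
# ti.root.dense(ti.i, 10).dense(ti.i, 30): axis i is split twice.
nested_i = [("i", 10), ("i", 30)]

print(count_mods(ji_layout, skip_first_mod=False))
print(count_mods(ji_layout, skip_first_mod=True))
print(count_mods(nested_i, skip_first_mod=True))
```

Under this model, a layout in which each axis appears in only one dense SNode needs no `mod` at all, while an axis split across two SNodes still needs one for the inner level.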

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
strongoier and pre-commit-ci[bot] authored Oct 27, 2022
1 parent 9f96e51 commit 13b78ab
Showing 2 changed files with 44 additions and 1 deletion.
15 changes: 14 additions & 1 deletion taichi/transforms/scalar_pointer_lowerer.cpp
@@ -55,6 +55,8 @@ void ScalarPointerLowerer::run() {
       total_shape[j] *= s->extractors[j].shape;
     }
   }
+  std::array<bool, taichi_max_num_indices> is_first_extraction;
+  is_first_extraction.fill(true);

   if (path_length_ == 0)
     return;
@@ -80,7 +82,18 @@ void ScalarPointerLowerer::run() {
         const int prev = total_shape[k];
         total_shape[k] /= snode->extractors[k].shape;
         const int next = total_shape[k];
-        extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next);
+        if (is_first_extraction[k]) {
+          // Upon first extraction on axis k, "indices_[k_]" is the user
+          // coordinate on axis k and "prev" is the total shape of axis k.
+          // Unless it is an invalid out-of-bound access, we can assume
+          // "indices_[k_] < prev" so we don't need a mod here.
+          auto const_next = lowered_->push_back<ConstStmt>(TypedConstant(next));
+          extracted = lowered_->push_back<BinaryOpStmt>(
+              BinaryOpType::div, indices_[k_], const_next);
+          is_first_extraction[k] = false;
+        } else {
+          extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next);
+        }
       } else {
         const int end = start_bits[k];
         start_bits[k] -= snode->extractors[k].num_bits;
30 changes: 30 additions & 0 deletions tests/cpp/transforms/scalar_pointer_lowerer_test.cpp
@@ -103,5 +103,35 @@ TEST_F(ScalarPointerLowererTest, Basic) {
   }
 }

+TEST(ScalarPointerLowerer, EliminateMod) {
+  const bool kPacked = true;
+  IRBuilder builder;
+  VecStatement lowered;
+  Stmt *index = builder.get_int32(2);
+  auto root = std::make_unique<SNode>(/*depth=*/0, SNodeType::root);
+  SNode *dense_1 = &(root->dense({Axis{2}, Axis{1}}, /*size=*/7, kPacked));
+  SNode *dense_2 = &(root->dense({Axis{1}}, /*size=*/3, kPacked));
+  SNode *dense_3 = &(dense_2->dense({Axis{0}}, /*size=*/5, kPacked));
+  SNode *leaf_1 = &(dense_1->insert_children(SNodeType::place));
+  SNode *leaf_2 = &(dense_3->insert_children(SNodeType::place));
+  LowererImpl lowerer_1{leaf_1,
+                        {index, index},
+                        SNodeOpType::undefined,
+                        /*is_bit_vectorized=*/false,
+                        &lowered,
+                        kPacked};
+  lowerer_1.run();
+  LowererImpl lowerer_2{leaf_2,
+                        {index},
+                        SNodeOpType::undefined,
+                        /*is_bit_vectorized=*/false,
+                        &lowered,
+                        kPacked};
+  lowerer_2.run();
+  for (int i = 0; i < lowered.size(); i++) {
+    ASSERT_FALSE(lowered[i]->is<BinaryOpStmt>() &&
+                 lowered[i]->as<BinaryOpStmt>()->op_type == BinaryOpType::mod);
+  }
+}
 }  // namespace
 }  // namespace taichi::lang
