Skip to content

Commit

Permalink
[opt] Eliminate redundant BitExtractStmt for SNode access under non-p…
Browse files Browse the repository at this point in the history
…acked mode (taichi-dev#6485)

Issue: taichi-dev#6219

### Brief Summary

This PR adds optimization similar to taichi-dev#6444 for non-packed mode so that
we can conduct fair comparisons regarding performance. After this PR,
the benchmark script in taichi-dev#6219 runs `0.007s` on my local machine no
matter `packed=True/False`.

The tests are fixed because they are invalid - the out-of-bound access
used to be hidden by the always inserted `BitExtractStmt` before this
PR.
  • Loading branch information
strongoier authored and quadpixels committed May 13, 2023
1 parent 98f6605 commit 2e42a07
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
12 changes: 9 additions & 3 deletions taichi/transforms/scalar_pointer_lowerer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,23 @@ void ScalarPointerLowerer::run() {
auto const_next = lowered_->push_back<ConstStmt>(TypedConstant(next));
extracted = lowered_->push_back<BinaryOpStmt>(
BinaryOpType::div, indices_[k_], const_next);
is_first_extraction[k] = false;
} else {
extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next);
}
} else {
const int end = start_bits[k];
start_bits[k] -= snode->extractors[k].num_bits;
const int begin = start_bits[k];
extracted =
lowered_->push_back<BitExtractStmt>(indices_[k_], begin, end);
if (is_first_extraction[k] && begin == 0) {
// Similar optimization as above. In this case the full user
// coordinate is extracted so we don't need a BitExtractStmt.
extracted = indices_[k_];
} else {
extracted =
lowered_->push_back<BitExtractStmt>(indices_[k_], begin, end);
}
}
is_first_extraction[k] = false;
lowered_indices.push_back(extracted);
strides.push_back(snode->extractors[k].shape);
}
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_sparse_activate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def test_pointer():

@ti.kernel
def activate():
ti.activate(ptr, 1)
ti.activate(ptr, 32)
ti.activate(ptr, ti.rescale_index(x, ptr, [1]))
ti.activate(ptr, ti.rescale_index(x, ptr, [32]))

@ti.kernel
def func():
Expand Down
5 changes: 3 additions & 2 deletions tests/python/test_sparse_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ def test_pointer_is_active():

n = 128

ti.root.pointer(ti.i, n).dense(ti.i, n).place(x)
ptr = ti.root.pointer(ti.i, n)
ptr.dense(ti.i, n).place(x)
ti.root.place(s)

@ti.kernel
def func():
for i in range(n * n):
s[None] += ti.is_active(x.parent().parent(), i)
s[None] += ti.is_active(ptr, ti.rescale_index(x, ptr, [i]))

x[0] = 1
x[127] = 1
Expand Down
5 changes: 3 additions & 2 deletions tests/python/test_sparse_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def test_nested_struct_fill_and_clear():
a = ti.field(dtype=ti.f32)
N = 512

ti.root.pointer(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a)
ptr = ti.root.pointer(ti.ij, [N, N])
ptr.dense(ti.ij, [8, 8]).place(a)

@ti.kernel
def fill():
Expand All @@ -68,7 +69,7 @@ def fill():
@ti.kernel
def clear():
for i, j in a.parent():
ti.deactivate(a.parent().parent(), [i, j])
ti.deactivate(ptr, ti.rescale_index(a, ptr, [i, j]))

def task():
fill()
Expand Down

0 comments on commit 2e42a07

Please sign in to comment.