From 2e42a07ed824b42022486363d54388a0c7dc978b Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Tue, 1 Nov 2022 13:50:06 +0800 Subject: [PATCH] [opt] Eliminate redundant BitExtractStmt for SNode access under non-packed mode (#6485) Issue: #6219 ### Brief Summary This PR adds an optimization similar to #6444 for non-packed mode so that we can conduct fair comparisons regarding performance. After this PR, the benchmark script in #6219 runs in `0.007s` on my local machine regardless of `packed=True/False`. The tests are fixed because they were invalid - their out-of-bounds accesses used to be hidden by the always-inserted `BitExtractStmt` before this PR. --- taichi/transforms/scalar_pointer_lowerer.cpp | 12 +++++++++--- tests/python/test_sparse_activate.py | 4 ++-- tests/python/test_sparse_basics.py | 5 +++-- tests/python/test_sparse_parallel.py | 5 +++-- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp index e4c27959e9aec..2e041fc10d0a5 100644 --- a/taichi/transforms/scalar_pointer_lowerer.cpp +++ b/taichi/transforms/scalar_pointer_lowerer.cpp @@ -90,7 +90,6 @@ void ScalarPointerLowerer::run() { auto const_next = lowered_->push_back(TypedConstant(next)); extracted = lowered_->push_back( BinaryOpType::div, indices_[k_], const_next); - is_first_extraction[k] = false; } else { extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next); } @@ -98,9 +97,16 @@ void ScalarPointerLowerer::run() { const int end = start_bits[k]; start_bits[k] -= snode->extractors[k].num_bits; const int begin = start_bits[k]; - extracted = - lowered_->push_back(indices_[k_], begin, end); + if (is_first_extraction[k] && begin == 0) { + // Similar optimization as above. In this case the full user + // coordinate is extracted so we don't need a BitExtractStmt. 
+ extracted = indices_[k_]; + } else { + extracted = + lowered_->push_back(indices_[k_], begin, end); + } } + is_first_extraction[k] = false; lowered_indices.push_back(extracted); strides.push_back(snode->extractors[k].shape); } diff --git a/tests/python/test_sparse_activate.py b/tests/python/test_sparse_activate.py index 3c5d498204263..7c8d96fb110fd 100644 --- a/tests/python/test_sparse_activate.py +++ b/tests/python/test_sparse_activate.py @@ -17,8 +17,8 @@ def test_pointer(): @ti.kernel def activate(): - ti.activate(ptr, 1) - ti.activate(ptr, 32) + ti.activate(ptr, ti.rescale_index(x, ptr, [1])) + ti.activate(ptr, ti.rescale_index(x, ptr, [32])) @ti.kernel def func(): diff --git a/tests/python/test_sparse_basics.py b/tests/python/test_sparse_basics.py index c7e6e5e20782e..4f51e4df11b54 100644 --- a/tests/python/test_sparse_basics.py +++ b/tests/python/test_sparse_basics.py @@ -34,13 +34,14 @@ def test_pointer_is_active(): n = 128 - ti.root.pointer(ti.i, n).dense(ti.i, n).place(x) + ptr = ti.root.pointer(ti.i, n) + ptr.dense(ti.i, n).place(x) ti.root.place(s) @ti.kernel def func(): for i in range(n * n): - s[None] += ti.is_active(x.parent().parent(), i) + s[None] += ti.is_active(ptr, ti.rescale_index(x, ptr, [i])) x[0] = 1 x[127] = 1 diff --git a/tests/python/test_sparse_parallel.py b/tests/python/test_sparse_parallel.py index 1a797c15cfeb7..c038a9e1c6a7d 100644 --- a/tests/python/test_sparse_parallel.py +++ b/tests/python/test_sparse_parallel.py @@ -58,7 +58,8 @@ def test_nested_struct_fill_and_clear(): a = ti.field(dtype=ti.f32) N = 512 - ti.root.pointer(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) + ptr = ti.root.pointer(ti.ij, [N, N]) + ptr.dense(ti.ij, [8, 8]).place(a) @ti.kernel def fill(): @@ -68,7 +69,7 @@ def fill(): @ti.kernel def clear(): for i, j in a.parent(): - ti.deactivate(a.parent().parent(), [i, j]) + ti.deactivate(ptr, ti.rescale_index(a, ptr, [i, j])) def task(): fill()