From 457bd11cc76048d0aef8e828b3dffbbb18fd9c13 Mon Sep 17 00:00:00 2001 From: Proton Date: Tue, 1 Nov 2022 15:32:52 +0800 Subject: [PATCH] [opt] Revert "Eliminate redundant BitExtractStmt for SNode access under non-packed mode (#6491) This reverts commit 7cdb97fb94d12a9a9357bdfbe87440140c3308d6. Issue: # ### Brief Summary --- taichi/transforms/scalar_pointer_lowerer.cpp | 12 +++--------- tests/python/test_sparse_activate.py | 4 ++-- tests/python/test_sparse_basics.py | 5 ++--- tests/python/test_sparse_parallel.py | 5 ++--- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp index 2e041fc10d0a5..e4c27959e9aec 100644 --- a/taichi/transforms/scalar_pointer_lowerer.cpp +++ b/taichi/transforms/scalar_pointer_lowerer.cpp @@ -90,6 +90,7 @@ void ScalarPointerLowerer::run() { auto const_next = lowered_->push_back(TypedConstant(next)); extracted = lowered_->push_back( BinaryOpType::div, indices_[k_], const_next); + is_first_extraction[k] = false; } else { extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next); } @@ -97,16 +98,9 @@ void ScalarPointerLowerer::run() { const int end = start_bits[k]; start_bits[k] -= snode->extractors[k].num_bits; const int begin = start_bits[k]; - if (is_first_extraction[k] && begin == 0) { - // Similar optimization as above. In this case the full user - // coordinate is extracted so we don't need a BitExtractStmt. - extracted = indices_[k_]; - } else { - extracted = - lowered_->push_back(indices_[k_], begin, end); - } + extracted = + lowered_->push_back(indices_[k_], begin, end); } - is_first_extraction[k] = false; lowered_indices.push_back(extracted); strides.push_back(snode->extractors[k].shape); } diff --git a/tests/python/test_sparse_activate.py b/tests/python/test_sparse_activate.py index 7c8d96fb110fd..3c5d498204263 100644 --- a/tests/python/test_sparse_activate.py +++ b/tests/python/test_sparse_activate.py @@ -17,8 +17,8 @@ def test_pointer(): @ti.kernel def activate(): - ti.activate(ptr, ti.rescale_index(x, ptr, [1])) - ti.activate(ptr, ti.rescale_index(x, ptr, [32])) + ti.activate(ptr, 1) + ti.activate(ptr, 32) @ti.kernel def func(): diff --git a/tests/python/test_sparse_basics.py b/tests/python/test_sparse_basics.py index 4f51e4df11b54..c7e6e5e20782e 100644 --- a/tests/python/test_sparse_basics.py +++ b/tests/python/test_sparse_basics.py @@ -34,14 +34,13 @@ def test_pointer_is_active(): n = 128 - ptr = ti.root.pointer(ti.i, n) - ptr.dense(ti.i, n).place(x) + ti.root.pointer(ti.i, n).dense(ti.i, n).place(x) ti.root.place(s) @ti.kernel def func(): for i in range(n * n): - s[None] += ti.is_active(ptr, ti.rescale_index(x, ptr, [i])) + s[None] += ti.is_active(x.parent().parent(), i) x[0] = 1 x[127] = 1 diff --git a/tests/python/test_sparse_parallel.py b/tests/python/test_sparse_parallel.py index c038a9e1c6a7d..1a797c15cfeb7 100644 --- a/tests/python/test_sparse_parallel.py +++ b/tests/python/test_sparse_parallel.py @@ -58,8 +58,7 @@ def test_nested_struct_fill_and_clear(): a = ti.field(dtype=ti.f32) N = 512 - ptr = ti.root.pointer(ti.ij, [N, N]) - ptr.dense(ti.ij, [8, 8]).place(a) + ti.root.pointer(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) @ti.kernel def fill(): @@ -69,7 +68,7 @@ def fill(): @ti.kernel def clear(): for i, j in a.parent(): - ti.deactivate(ptr, ti.rescale_index(a, ptr, [i, j])) + ti.deactivate(a.parent().parent(), [i, j]) def task(): fill()