diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp index e4c27959e9aec..2e041fc10d0a5 100644 --- a/taichi/transforms/scalar_pointer_lowerer.cpp +++ b/taichi/transforms/scalar_pointer_lowerer.cpp @@ -90,7 +90,6 @@ void ScalarPointerLowerer::run() { auto const_next = lowered_->push_back(TypedConstant(next)); extracted = lowered_->push_back( BinaryOpType::div, indices_[k_], const_next); - is_first_extraction[k] = false; } else { extracted = generate_mod_x_div_y(lowered_, indices_[k_], prev, next); } @@ -98,9 +97,16 @@ void ScalarPointerLowerer::run() { const int end = start_bits[k]; start_bits[k] -= snode->extractors[k].num_bits; const int begin = start_bits[k]; - extracted = - lowered_->push_back(indices_[k_], begin, end); + if (is_first_extraction[k] && begin == 0) { + // Similar optimization as above. In this case the full user + // coordinate is extracted so we don't need a BitExtractStmt. + extracted = indices_[k_]; + } else { + extracted = + lowered_->push_back(indices_[k_], begin, end); + } } + is_first_extraction[k] = false; lowered_indices.push_back(extracted); strides.push_back(snode->extractors[k].shape); } diff --git a/tests/python/test_sparse_activate.py b/tests/python/test_sparse_activate.py index 3c5d498204263..7c8d96fb110fd 100644 --- a/tests/python/test_sparse_activate.py +++ b/tests/python/test_sparse_activate.py @@ -17,8 +17,8 @@ def test_pointer(): @ti.kernel def activate(): - ti.activate(ptr, 1) - ti.activate(ptr, 32) + ti.activate(ptr, ti.rescale_index(x, ptr, [1])) + ti.activate(ptr, ti.rescale_index(x, ptr, [32])) @ti.kernel def func(): diff --git a/tests/python/test_sparse_basics.py b/tests/python/test_sparse_basics.py index c7e6e5e20782e..4f51e4df11b54 100644 --- a/tests/python/test_sparse_basics.py +++ b/tests/python/test_sparse_basics.py @@ -34,13 +34,14 @@ def test_pointer_is_active(): n = 128 - ti.root.pointer(ti.i, n).dense(ti.i, n).place(x) + ptr = ti.root.pointer(ti.i, n) + ptr.dense(ti.i, n).place(x) ti.root.place(s) @ti.kernel def func(): for i in range(n * n): - s[None] += ti.is_active(x.parent().parent(), i) + s[None] += ti.is_active(ptr, ti.rescale_index(x, ptr, [i])) x[0] = 1 x[127] = 1 diff --git a/tests/python/test_sparse_parallel.py b/tests/python/test_sparse_parallel.py index 1a797c15cfeb7..c038a9e1c6a7d 100644 --- a/tests/python/test_sparse_parallel.py +++ b/tests/python/test_sparse_parallel.py @@ -58,7 +58,8 @@ def test_nested_struct_fill_and_clear(): a = ti.field(dtype=ti.f32) N = 512 - ti.root.pointer(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) + ptr = ti.root.pointer(ti.ij, [N, N]) + ptr.dense(ti.ij, [8, 8]).place(a) @ti.kernel def fill(): @@ -68,7 +69,7 @@ def fill(): @ti.kernel def clear(): for i, j in a.parent(): - ti.deactivate(a.parent().parent(), [i, j]) + ti.deactivate(ptr, ti.rescale_index(a, ptr, [i, j])) def task(): fill()