diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index 1668bce569588..36a9676eacff7 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -8,6 +8,5 @@ PER_EXTENSION(adstack) // For keeping the history of mutable local variables PER_EXTENSION(bls) // Block-local storage PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels PER_EXTENSION(extfunc) // Invoke external functions or backend source -PER_EXTENSION(packed) // Shape will not be padded to a power of two PER_EXTENSION( dynamic_index) // Dynamic index support for both global and local tensors diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index b38f76a8e6be3..760d6653686f7 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -9,7 +9,7 @@ CompileConfig::CompileConfig() { simd_width = default_simd_width(arch); opt_level = 1; external_optimization_level = 3; - packed = false; + packed = true; print_ir = false; print_preprocessed_ir = false; print_accessor_ir = false; diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index 3de3de85978ee..1c7b86315ffb5 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -10,17 +10,15 @@ bool is_extension_supported(Arch arch, Extension ext) { {Arch::x64, {Extension::sparse, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, Extension::assertion, - Extension::extfunc, Extension::packed, Extension::dynamic_index, - Extension::mesh}}, + Extension::extfunc, Extension::dynamic_index, Extension::mesh}}, {Arch::arm64, {Extension::sparse, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, Extension::assertion, - Extension::packed, Extension::dynamic_index, Extension::mesh}}, + Extension::dynamic_index, Extension::mesh}}, {Arch::cuda, {Extension::sparse, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, Extension::bls, - Extension::assertion, Extension::packed, Extension::dynamic_index, - Extension::mesh}}, + Extension::assertion, Extension::dynamic_index, Extension::mesh}}, // TODO: supporting quant in metal(tests randomly crashed) {Arch::metal, {Extension::adstack, Extension::assertion, Extension::sparse}}, diff --git a/taichi/runtime/metal/kernel_manager.cpp b/taichi/runtime/metal/kernel_manager.cpp index f30dd8e477a4b..b40ec187fd9cf 100644 --- a/taichi/runtime/metal/kernel_manager.cpp +++ b/taichi/runtime/metal/kernel_manager.cpp @@ -949,14 +949,15 @@ class KernelManager::Impl { const auto *sn = iter->second.snode; SNodeExtractors *rtm_ext = reinterpret_cast(addr) + i; TI_DEBUG("SNodeExtractors snode={}", i); + rtm_ext->packed = config_->packed; for (int j = 0; j < taichi_max_num_indices; ++j) { const auto &ext = sn->extractors[j]; rtm_ext->extractors[j].num_bits = ext.num_bits; rtm_ext->extractors[j].acc_offset = ext.acc_offset; - rtm_ext->extractors[j].num_elements_from_root = - ext.num_elements_from_root; - TI_DEBUG(" [{}] num_bits={} acc_offset={} num_elements_from_root={}", - j, ext.num_bits, ext.acc_offset, ext.num_elements_from_root); + rtm_ext->extractors[j].shape = ext.shape; + rtm_ext->extractors[j].acc_shape = ext.acc_shape; + TI_DEBUG(" [{}] num_bits={} acc_offset={} shape={} acc_shape={}", j, + ext.num_bits, ext.acc_offset, ext.shape, ext.acc_shape); } TI_DEBUG(""); } diff --git a/taichi/runtime/metal/shaders/runtime_structs.metal.h b/taichi/runtime/metal/shaders/runtime_structs.metal.h index 4e6483628ea4c..4989c13ef8094 100644 --- a/taichi/runtime/metal/shaders/runtime_structs.metal.h +++ b/taichi/runtime/metal/shaders/runtime_structs.metal.h @@ -114,13 +114,13 @@ STR( struct SNodeExtractors { struct Extractor { - int32_t start = 0; int32_t num_bits = 0; int32_t acc_offset = 0; - int32_t num_elements_from_root = 0; + int32_t shape = 0; + int32_t acc_shape = 0; }; - Extractor extractors[kTaichiMaxNumIndices]; + bool packed; }; struct ElementCoords { int32_t at[kTaichiMaxNumIndices]; }; diff --git a/taichi/runtime/metal/shaders/runtime_utils.metal.h b/taichi/runtime/metal/shaders/runtime_utils.metal.h index f75dd26929610..ec65f1949615b 100644 --- a/taichi/runtime/metal/shaders/runtime_utils.metal.h +++ b/taichi/runtime/metal/shaders/runtime_utils.metal.h @@ -422,11 +422,19 @@ STR( device const SNodeExtractors &child_extrators, int l, thread ElementCoords *child) { - for (int i = 0; i < kTaichiMaxNumIndices; ++i) { - device const auto &ex = child_extrators.extractors[i]; - const int mask = ((1 << ex.num_bits) - 1); - const int addition = ((l >> ex.acc_offset) & mask); - child->at[i] = ((parent.at[i] << ex.num_bits) | addition); + if (child_extrators.packed) { + for (int i = 0; i < kTaichiMaxNumIndices; ++i) { + device const auto &ex = child_extrators.extractors[i]; + const int addition = l % (ex.acc_shape * ex.shape) / ex.acc_shape; + child->at[i] = parent.at[i] * ex.shape + addition; + } + } else { + for (int i = 0; i < kTaichiMaxNumIndices; ++i) { + device const auto &ex = child_extrators.extractors[i]; + const int mask = ((1 << ex.num_bits) - 1); + const int addition = ((l >> ex.acc_offset) & mask); + child->at[i] = ((parent.at[i] << ex.num_bits) | addition); + } } } diff --git a/taichi/transforms/utils.cpp b/taichi/transforms/utils.cpp index 13fd326155f4b..9c5e7bf5eeace 100644 --- a/taichi/transforms/utils.cpp +++ b/taichi/transforms/utils.cpp @@ -13,8 +13,8 @@ Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y) { Stmt *generate_div(VecStatement *stmts, Stmt *x, int y) { if (bit::is_power_of_two(y)) { - auto const_stmt = - stmts->push_back(TypedConstant(bit::log2int(y))); + auto const_stmt = stmts->push_back( + TypedConstant(PrimitiveType::i32, bit::log2int(y))); return stmts->push_back(BinaryOpType::bit_shr, x, const_stmt); } auto const_stmt = stmts->push_back(TypedConstant(y)); diff --git a/tests/cpp/aot/gfx_utils.cpp b/tests/cpp/aot/gfx_utils.cpp index 4e735cb08ef12..284c124c0315a 100644 --- a/tests/cpp/aot/gfx_utils.cpp +++ b/tests/cpp/aot/gfx_utils.cpp @@ -74,7 +74,7 @@ void run_dense_field_kernel(Arch arch, taichi::lang::Device *device) { // Retrieve kernels/fields/etc from AOT module auto root_size = vk_module->get_root_size(); - EXPECT_EQ(root_size, 64); + EXPECT_EQ(root_size, 40); gfx_runtime->add_root_buffer(root_size); auto simple_ret_kernel = vk_module->get_kernel("simple_ret"); diff --git a/tests/python/bls_test_template.py b/tests/python/bls_test_template.py index fb7c3a1f4b670..0e15bf281eb7d 100644 --- a/tests/python/bls_test_template.py +++ b/tests/python/bls_test_template.py @@ -39,7 +39,7 @@ def bls_test_template(dim, create_block().dense(index, bs).place(y) create_block().dense(index, bs).place(y2) - ndrange = ((bs[i], N - bs[i]) for i in range(dim)) + ndrange = ((bs[i] * 2, N - bs[i] * 2) for i in range(dim)) if block_dim is None: block_dim = 1 diff --git a/tests/python/test_bitmasked.py b/tests/python/test_bitmasked.py index 38babb36aab4a..8fc082ef5fc3f 100644 --- a/tests/python/test_bitmasked.py +++ b/tests/python/test_bitmasked.py @@ -35,8 +35,7 @@ def test_basic(): _test_basic() -@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], - packed=True) +@test_utils.test(require=ti.extension.sparse, packed=True) def test_basic_packed(): _test_basic() @@ -216,8 +215,7 @@ def test_sparsity_changes(): _test_sparsity_changes() -@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], - packed=True) +@test_utils.test(require=ti.extension.sparse, packed=True) def test_sparsity_changes_packed(): _test_sparsity_changes() diff --git a/tests/python/test_mpm_particle_list.py b/tests/python/test_mpm_particle_list.py index b85a69ad508f1..c23f4f1acf7e6 100644 --- a/tests/python/test_mpm_particle_list.py +++ b/tests/python/test_mpm_particle_list.py @@ -57,7 +57,7 @@ def test_mpm_particle_list_no_leakage(): @pytest.mark.run_in_serial -@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], +@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal], device_memory_GB=1.0, packed=True) diff --git a/tests/python/test_packed_size.py b/tests/python/test_packed_size.py index a5802a52aed25..cfb6ee4e2d9da 100644 --- a/tests/python/test_packed_size.py +++ b/tests/python/test_packed_size.py @@ -2,7 +2,7 @@ from tests import test_utils -@test_utils.test(require=ti.extension.packed, packed=True) +@test_utils.test(arch=[ti.cpu, ti.cuda], packed=True) def test_packed_size(): x = ti.field(ti.i32) ti.root.dense(ti.l, 3).dense(ti.ijk, 129).place(x) diff --git a/tests/python/test_sparse_basics.py b/tests/python/test_sparse_basics.py index 4f51e4df11b54..f0d8b744c4f54 100644 --- a/tests/python/test_sparse_basics.py +++ b/tests/python/test_sparse_basics.py @@ -108,8 +108,7 @@ def test_pointer2(): _test_pointer2() -@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], - packed=True) +@test_utils.test(require=ti.extension.sparse, packed=True) def test_pointer2_packed(): _test_pointer2() diff --git a/tests/python/test_struct_for_intermediate.py b/tests/python/test_struct_for_intermediate.py index 4971f20f19f61..91c2ed4e0545b 100644 --- a/tests/python/test_struct_for_intermediate.py +++ b/tests/python/test_struct_for_intermediate.py @@ -33,15 +33,13 @@ def test_nested_demote(): _test_nested() -@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], +@test_utils.test(require=ti.extension.sparse, demote_dense_struct_fors=False, packed=True) def test_nested_packed(): _test_nested() -@test_utils.test(require=ti.extension.packed, - demote_dense_struct_fors=True, - packed=True) +@test_utils.test(demote_dense_struct_fors=True, packed=True) def test_nested_demote_packed(): _test_nested() diff --git a/tests/python/test_struct_for_non_pot.py b/tests/python/test_struct_for_non_pot.py index dd35434dfe517..9632fde3383e8 100644 --- a/tests/python/test_struct_for_non_pot.py +++ b/tests/python/test_struct_for_non_pot.py @@ -27,7 +27,7 @@ def test_1d(): _test_1d() -@test_utils.test(require=ti.extension.packed, packed=True) +@test_utils.test(packed=True) def test_1d_packed(): _test_1d() @@ -63,21 +63,6 @@ def test_2d(): _test_2d() -@test_utils.test(require=ti.extension.packed, packed=True) +@test_utils.test(packed=True) def test_2d_packed(): _test_2d() - - -@test_utils.test(require=ti.extension.packed, packed=True) -def test_2d_overflow_if_not_packed(): - n, m, p = 2**9 + 1, 2**9 + 1, 2**10 + 1 - arr = ti.field(ti.u8, (n, m, p)) - - @ti.kernel - def count() -> ti.i32: - res = 0 - for _ in ti.grouped(arr): - res += 1 - return res - - assert count() == n * m * p