Skip to content

Commit

Permalink
[Lang] Enable packed mode by default (taichi-dev#6721)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#6660

### Brief Summary

This PR does the following:
- Provides the missing part of packed-mode support for Metal sparse;
- Removes `ti.extension.packed`;
- Sets `packed=True` by default;
- Fixes tests that are invalid under packed mode (e.g. tests requiring the removed `ti.extension.packed` or assuming shapes padded to powers of two).

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 53df283 commit cca99ec
Show file tree
Hide file tree
Showing 15 changed files with 38 additions and 52 deletions.
1 change: 0 additions & 1 deletion taichi/inc/extensions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ PER_EXTENSION(adstack) // For keeping the history of mutable local variables
PER_EXTENSION(bls) // Block-local storage
PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels
PER_EXTENSION(extfunc) // Invoke external functions or backend source
PER_EXTENSION(packed) // Shape will not be padded to a power of two
PER_EXTENSION(
dynamic_index) // Dynamic index support for both global and local tensors
2 changes: 1 addition & 1 deletion taichi/program/compile_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ CompileConfig::CompileConfig() {
simd_width = default_simd_width(arch);
opt_level = 1;
external_optimization_level = 3;
packed = false;
packed = true;
print_ir = false;
print_preprocessed_ir = false;
print_accessor_ir = false;
Expand Down
8 changes: 3 additions & 5 deletions taichi/program/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@ bool is_extension_supported(Arch arch, Extension ext) {
{Arch::x64,
{Extension::sparse, Extension::quant, Extension::quant_basic,
Extension::data64, Extension::adstack, Extension::assertion,
Extension::extfunc, Extension::packed, Extension::dynamic_index,
Extension::mesh}},
Extension::extfunc, Extension::dynamic_index, Extension::mesh}},
{Arch::arm64,
{Extension::sparse, Extension::quant, Extension::quant_basic,
Extension::data64, Extension::adstack, Extension::assertion,
Extension::packed, Extension::dynamic_index, Extension::mesh}},
Extension::dynamic_index, Extension::mesh}},
{Arch::cuda,
{Extension::sparse, Extension::quant, Extension::quant_basic,
Extension::data64, Extension::adstack, Extension::bls,
Extension::assertion, Extension::packed, Extension::dynamic_index,
Extension::mesh}},
Extension::assertion, Extension::dynamic_index, Extension::mesh}},
// TODO: supporting quant in metal(tests randomly crashed)
{Arch::metal,
{Extension::adstack, Extension::assertion, Extension::sparse}},
Expand Down
9 changes: 5 additions & 4 deletions taichi/runtime/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,14 +949,15 @@ class KernelManager::Impl {
const auto *sn = iter->second.snode;
SNodeExtractors *rtm_ext = reinterpret_cast<SNodeExtractors *>(addr) + i;
TI_DEBUG("SNodeExtractors snode={}", i);
rtm_ext->packed = config_->packed;
for (int j = 0; j < taichi_max_num_indices; ++j) {
const auto &ext = sn->extractors[j];
rtm_ext->extractors[j].num_bits = ext.num_bits;
rtm_ext->extractors[j].acc_offset = ext.acc_offset;
rtm_ext->extractors[j].num_elements_from_root =
ext.num_elements_from_root;
TI_DEBUG(" [{}] num_bits={} acc_offset={} num_elements_from_root={}",
j, ext.num_bits, ext.acc_offset, ext.num_elements_from_root);
rtm_ext->extractors[j].shape = ext.shape;
rtm_ext->extractors[j].acc_shape = ext.acc_shape;
TI_DEBUG(" [{}] num_bits={} acc_offset={} shape={} acc_shape={}", j,
ext.num_bits, ext.acc_offset, ext.shape, ext.acc_shape);
}
TI_DEBUG("");
}
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ STR(

struct SNodeExtractors {
struct Extractor {
int32_t start = 0;
int32_t num_bits = 0;
int32_t acc_offset = 0;
int32_t num_elements_from_root = 0;
int32_t shape = 0;
int32_t acc_shape = 0;
};

Extractor extractors[kTaichiMaxNumIndices];
bool packed;
};

struct ElementCoords { int32_t at[kTaichiMaxNumIndices]; };
Expand Down
18 changes: 13 additions & 5 deletions taichi/runtime/metal/shaders/runtime_utils.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,19 @@ STR(
device const SNodeExtractors &child_extrators,
int l,
thread ElementCoords *child) {
for (int i = 0; i < kTaichiMaxNumIndices; ++i) {
device const auto &ex = child_extrators.extractors[i];
const int mask = ((1 << ex.num_bits) - 1);
const int addition = ((l >> ex.acc_offset) & mask);
child->at[i] = ((parent.at[i] << ex.num_bits) | addition);
if (child_extrators.packed) {
for (int i = 0; i < kTaichiMaxNumIndices; ++i) {
device const auto &ex = child_extrators.extractors[i];
const int addition = l % (ex.acc_shape * ex.shape) / ex.acc_shape;
child->at[i] = parent.at[i] * ex.shape + addition;
}
} else {
for (int i = 0; i < kTaichiMaxNumIndices; ++i) {
device const auto &ex = child_extrators.extractors[i];
const int mask = ((1 << ex.num_bits) - 1);
const int addition = ((l >> ex.acc_offset) & mask);
child->at[i] = ((parent.at[i] << ex.num_bits) | addition);
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y) {

Stmt *generate_div(VecStatement *stmts, Stmt *x, int y) {
if (bit::is_power_of_two(y)) {
auto const_stmt =
stmts->push_back<ConstStmt>(TypedConstant(bit::log2int(y)));
auto const_stmt = stmts->push_back<ConstStmt>(
TypedConstant(PrimitiveType::i32, bit::log2int(y)));
return stmts->push_back<BinaryOpStmt>(BinaryOpType::bit_shr, x, const_stmt);
}
auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y));
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/aot/gfx_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void run_dense_field_kernel(Arch arch, taichi::lang::Device *device) {

// Retrieve kernels/fields/etc from AOT module
auto root_size = vk_module->get_root_size();
EXPECT_EQ(root_size, 64);
EXPECT_EQ(root_size, 40);
gfx_runtime->add_root_buffer(root_size);

auto simple_ret_kernel = vk_module->get_kernel("simple_ret");
Expand Down
2 changes: 1 addition & 1 deletion tests/python/bls_test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def bls_test_template(dim,
create_block().dense(index, bs).place(y)
create_block().dense(index, bs).place(y2)

ndrange = ((bs[i], N - bs[i]) for i in range(dim))
ndrange = ((bs[i] * 2, N - bs[i] * 2) for i in range(dim))

if block_dim is None:
block_dim = 1
Expand Down
6 changes: 2 additions & 4 deletions tests/python/test_bitmasked.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def test_basic():
_test_basic()


@test_utils.test(require=[ti.extension.sparse, ti.extension.packed],
packed=True)
@test_utils.test(require=ti.extension.sparse, packed=True)
def test_basic_packed():
_test_basic()

Expand Down Expand Up @@ -216,8 +215,7 @@ def test_sparsity_changes():
_test_sparsity_changes()


@test_utils.test(require=[ti.extension.sparse, ti.extension.packed],
packed=True)
@test_utils.test(require=ti.extension.sparse, packed=True)
def test_sparsity_changes_packed():
_test_sparsity_changes()

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_mpm_particle_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_mpm_particle_list_no_leakage():


@pytest.mark.run_in_serial
@test_utils.test(require=[ti.extension.sparse, ti.extension.packed],
@test_utils.test(require=ti.extension.sparse,
exclude=[ti.metal],
device_memory_GB=1.0,
packed=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_packed_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from tests import test_utils


@test_utils.test(require=ti.extension.packed, packed=True)
@test_utils.test(arch=[ti.cpu, ti.cuda], packed=True)
def test_packed_size():
x = ti.field(ti.i32)
ti.root.dense(ti.l, 3).dense(ti.ijk, 129).place(x)
Expand Down
3 changes: 1 addition & 2 deletions tests/python/test_sparse_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ def test_pointer2():
_test_pointer2()


@test_utils.test(require=[ti.extension.sparse, ti.extension.packed],
packed=True)
@test_utils.test(require=ti.extension.sparse, packed=True)
def test_pointer2_packed():
_test_pointer2()

Expand Down
6 changes: 2 additions & 4 deletions tests/python/test_struct_for_intermediate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ def test_nested_demote():
_test_nested()


@test_utils.test(require=[ti.extension.sparse, ti.extension.packed],
@test_utils.test(require=ti.extension.sparse,
demote_dense_struct_fors=False,
packed=True)
def test_nested_packed():
_test_nested()


@test_utils.test(require=ti.extension.packed,
demote_dense_struct_fors=True,
packed=True)
@test_utils.test(demote_dense_struct_fors=True, packed=True)
def test_nested_demote_packed():
_test_nested()
19 changes: 2 additions & 17 deletions tests/python/test_struct_for_non_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_1d():
_test_1d()


@test_utils.test(require=ti.extension.packed, packed=True)
@test_utils.test(packed=True)
def test_1d_packed():
_test_1d()

Expand Down Expand Up @@ -63,21 +63,6 @@ def test_2d():
_test_2d()


@test_utils.test(require=ti.extension.packed, packed=True)
@test_utils.test(packed=True)
def test_2d_packed():
_test_2d()


@test_utils.test(require=ti.extension.packed, packed=True)
def test_2d_overflow_if_not_packed():
n, m, p = 2**9 + 1, 2**9 + 1, 2**10 + 1
arr = ti.field(ti.u8, (n, m, p))

@ti.kernel
def count() -> ti.i32:
res = 0
for _ in ti.grouped(arr):
res += 1
return res

assert count() == n * m * p

0 comments on commit cca99ec

Please sign in to comment.