[ir] Except shared array from demote atomics pass. (taichi-dev#7513)
Fixes: taichi-dev#7510

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and quadpixels committed May 13, 2023
1 parent ed21306 commit 76f91b7
Showing 3 changed files with 53 additions and 6 deletions.
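
For context: the demote-atomics pass rewrites an atomic read-modify-write whose destination is provably thread-private into a plain load/op/store, which is cheaper and simpler for later passes. A shared array is also represented as an AllocaStmt, but its memory is visible to every thread in the block, so the same rewrite loses concurrent updates (the miscompile reported in taichi-dev#7510). A minimal sketch of the distinction, in pseudocode rather than Taichi's actual IR API:

    # demotable: dest is private to one thread
    #   old = atomic_add(acc, v)   ->   old = load(acc); store(acc, old + v)
    #
    # not demotable: sharr is block-shared memory. If two threads interleave
    # load/add/store on sharr[0], both can read the same old value and one
    # increment is lost, so the real atomic must be kept.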
10 changes: 9 additions & 1 deletion taichi/codegen/spirv/spirv_codegen.cpp
@@ -260,6 +260,9 @@ class TaskCodegen : public IRVisitor {
           ir_->get_primitive_type(dt), origin_val.stype.storage_class);
       ptr_val = ir_->make_value(spv::OpAccessChain, ptr_type, origin_val,
                                 offset_val);
+      if (stmt->origin->as<AllocaStmt>()->is_shared) {
+        ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
+      }
     } else if (stmt->origin->is<GlobalTemporaryStmt>()) {
       spirv::Value dt_bytes = ir_->int_immediate_number(
           ir_->i32_type(), ir_->get_primitive_type_size(dt), false);
@@ -1453,7 +1456,12 @@
         addr_ptr = at_buffer(stmt->dest, ir_->get_taichi_uint_type(dt));
       }
     } else {
-      addr_ptr = at_buffer(stmt->dest, dt);
+      if (stmt->dest->is<MatrixPtrStmt>()) {
+        // Shared arrays have already created an accesschain, use it directly.
+        addr_ptr = ir_->query_value(stmt->dest->raw_name());
+      } else {
+        addr_ptr = at_buffer(stmt->dest, dt);
+      }
     }
 
     auto ret_type = ir_->get_primitive_type(dt);
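
The first hunk records the OpAccessChain pointer computed for a shared-array element under the MatrixPtrStmt itself; the second hunk lets the atomic lowering look that pointer up with query_value instead of recomputing a buffer offset through at_buffer. A line like the following (taken from the new test below) exercises this path, since indexing a shared array produces a MatrixPtrStmt whose origin is a shared AllocaStmt:

    sharr = ti.simt.block.SharedArray((block_dim, ), ti.i32)
    sharr[0] += val  # atomic dest is a MatrixPtrStmt into shared memory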
21 changes: 16 additions & 5 deletions taichi/transforms/demote_atomics.cpp
@@ -92,11 +92,22 @@ class DemoteAtomics : public BasicStmtVisitor {
         }
       }
     }
-    if (stmt->dest->is<AllocaStmt>() ||
-        (stmt->dest->is<MatrixPtrStmt>() &&
-         stmt->dest->cast<MatrixPtrStmt>()->origin->is<AllocaStmt>())) {
-      demote = true;
-      is_local = true;
+    if (stmt->dest->is<AllocaStmt>()) {
+      // Except shared array
+      if (!stmt->dest->as<AllocaStmt>()->is_shared) {
+        demote = true;
+        is_local = true;
+      }
+    }
+    if (stmt->dest->is<MatrixPtrStmt>() &&
+        stmt->dest->cast<MatrixPtrStmt>()->origin->is<AllocaStmt>()) {
+      // Except shared array
+      if (!stmt->dest->cast<MatrixPtrStmt>()
+               ->origin->as<AllocaStmt>()
+               ->is_shared) {
+        demote = true;
+        is_local = true;
+      }
     }
 
     if (auto dest_pointer_type = stmt->dest->ret_type->cast<PointerType>()) {
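
Both new branches demote only when the destination, either directly or through a MatrixPtrStmt index, is an AllocaStmt with is_shared == false. A thread-local accumulator is the canonical case that is still demoted; a minimal sketch (kernel name and bounds are illustrative, not from this PR):

    @ti.kernel
    def local_sum(out: ti.types.ndarray()):
        for i in range(256):      # parallel outer loop
            acc = 0               # plain local AllocaStmt, private to one thread
            for j in range(4):    # serial inner loop
                acc += j          # safe to demote to load/add/store
            out[i] = acc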
28 changes: 28 additions & 0 deletions tests/python/test_shared_array.py
@@ -45,3 +45,31 @@ def calc_shared_array(v: ti.types.ndarray(ndim=1),
     calc(v_arr, d_arr, reference)
     calc_shared_array(v_arr, d_arr, a_arr)
     assert np.allclose(reference, a_arr)
+
+
+@test_utils.test(arch=[ti.cuda, ti.vulkan, ti.amdgpu])
+def test_shared_array_atomics():
+    N = 256
+    block_dim = 32
+
+    @ti.kernel
+    def atomic_test(out: ti.types.ndarray()):
+        ti.loop_config(block_dim=block_dim)
+        for i in range(N):
+            tid = i % block_dim
+            val = tid
+            sharr = ti.simt.block.SharedArray((block_dim, ), ti.i32)
+            sharr[tid] = val
+            ti.simt.block.sync()
+            sharr[0] += val
+            ti.simt.block.sync()
+            out[i] = sharr[tid]
+
+    arr = ti.ndarray(ti.i32, (N))
+    atomic_test(arr)
+    ti.sync()
+    sum = block_dim * (block_dim - 1) // 2
+    assert arr[0] == sum
+    assert arr[32] == sum
+    assert arr[128] == sum
+    assert arr[224] == sum
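
For reference on the expected value: within each 32-thread block, thread tid writes val = tid into sharr[tid], and after the first barrier every thread atomically adds its val into sharr[0], so

    expected = block_dim * (block_dim - 1) // 2  # 0 + 1 + ... + 31 = 496

Indices 0, 32, 128, and 224 are the tid == 0 lanes of blocks 0, 1, 4, and 7, and those lanes read sharr[0].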
