Skip to content

Commit

Permalink
[Lang] [spirv] Support dynamic indexing in spirv (taichi-dev#6990)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#2590

### Brief Summary

This PR supports dynamic indexing in spirv (sister PR: taichi-dev#6985).
`_test_local_matrix_non_constant_index()` is modified due to lack of
`AssertStmt` support in spirv.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and feisuzhu committed Jan 5, 2023
1 parent b5b28de commit 6dacae7
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 35 deletions.
1 change: 1 addition & 0 deletions taichi/codegen/spirv/snode_struct_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class StructCompiler {
cell_stride += snode_size;
snode_descriptors_.find(ch_snode->id)
->second.mem_offset_in_parent_cell = child_offset;
ch_snode->offset_bytes_in_parent_cell = child_offset;
}
sn_desc.cell_stride = cell_stride;

Expand Down
53 changes: 33 additions & 20 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,40 +219,53 @@ class TaskCodegen : public IRVisitor {
}

void visit(AllocaStmt *alloca) override {
spirv::Value ptr_val;
if (alloca->ret_type->is<TensorType>()) {
// Alloca for shared memory / workgroup memory
if (!alloca->is_shared) {
TI_ERROR(
"Tensor type for dyanmic index is not yet supported on Vulkan.");
}
auto tensor_type = alloca->ret_type->cast<TensorType>();
auto elem_num = tensor_type->get_num_elements();
spirv::SType elem_type =
ir_->get_primitive_type(tensor_type->get_element_type());

spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num);
spirv::Value ptr_val = ir_->alloca_workgroup_array(arr_type);
shared_array_binds_.push_back(ptr_val);
ir_->register_value(alloca->raw_name(), ptr_val);
if (alloca->is_shared) { // for shared memory / workgroup memory
ptr_val = ir_->alloca_workgroup_array(arr_type);
shared_array_binds_.push_back(ptr_val);
} else { // for function memory
ptr_val = ir_->alloca_variable(arr_type);
}
} else {
// Alloca for a single variable
spirv::SType src_type = ir_->get_primitive_type(alloca->element_type());
spirv::Value ptr_val = ir_->alloca_variable(src_type);
ptr_val = ir_->alloca_variable(src_type);
ir_->store_variable(ptr_val, ir_->get_zero(src_type));
ir_->register_value(alloca->raw_name(), ptr_val);
}
ir_->register_value(alloca->raw_name(), ptr_val);
}

void visit(MatrixPtrStmt *stmt) override {
spirv::SType data_type =
ir_->get_primitive_type(stmt->element_type().ptr_removed());
spirv::SType ptr_type =
ir_->get_pointer_type(data_type, spv::StorageClassWorkgroup);
auto origin_val = ir_->query_value(stmt->origin->raw_name());
auto offset_val = ir_->query_value(stmt->offset->raw_name());
Value offset_ptr =
ir_->make_value(spv::OpAccessChain, ptr_type, origin_val, offset_val);
ir_->register_value(stmt->raw_name(), offset_ptr);
spirv::Value ptr_val;
spirv::Value origin_val = ir_->query_value(stmt->origin->raw_name());
spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
auto dt = stmt->element_type().ptr_removed();
if (stmt->offset_used_as_index()) {
if (stmt->origin->is<AllocaStmt>()) {
spirv::SType ptr_type = ir_->get_pointer_type(
ir_->get_primitive_type(dt), origin_val.stype.storage_class);
ptr_val = ir_->make_value(spv::OpAccessChain, ptr_type, origin_val,
offset_val);
} else if (stmt->origin->is<GlobalTemporaryStmt>()) {
spirv::Value dt_bytes = ir_->int_immediate_number(
ir_->i32_type(), ir_->get_primitive_type_size(dt), false);
spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
ptr_val = ir_->add(origin_val, offset_bytes);
ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
} else {
TI_NOT_IMPLEMENTED;
}
} else { // offset used as bytes
ptr_val = ir_->add(origin_val, ir_->cast(origin_val.stype, offset_val));
ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
}
ir_->register_value(stmt->raw_name(), ptr_val);
}

void visit(LocalLoadStmt *stmt) override {
Expand Down
4 changes: 3 additions & 1 deletion taichi/program/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ bool is_extension_supported(Arch arch, Extension ext) {
{Arch::metal,
{Extension::adstack, Extension::assertion, Extension::dynamic_index,
Extension::sparse}},
{Arch::opengl, {Extension::extfunc}},
{Arch::opengl, {Extension::dynamic_index, Extension::extfunc}},
{Arch::gles, {}},
{Arch::vulkan, {Extension::dynamic_index}},
{Arch::dx11, {Extension::dynamic_index}},
{Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}},
};
// if (with_opengl_extension_data64())
Expand Down
22 changes: 8 additions & 14 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,39 +166,33 @@ def run():

def _test_local_matrix_non_constant_index():
@ti.kernel
def func1():
def func1() -> ti.types.vector(3, ti.i32):
tmp = ti.Vector([1, 2, 3])
for i in range(3):
vec = ti.Vector([4, 5, 6])
for j in range(3):
vec[tmp[i] % 3] += vec[j]
tmp[i] = vec[tmp[i] % 3]
assert tmp[0] == 24
assert tmp[1] == 30
assert tmp[2] == 19
return tmp

func1()
assert (func1() == ti.Vector([24, 30, 19])).all()

@ti.kernel
def func2(i: ti.i32, j: ti.i32, k: ti.i32):
def func2(i: ti.i32, j: ti.i32, k: ti.i32) -> ti.i32:
tmp = ti.Matrix([[k, k * 2], [k * 2, k * 3]])
assert tmp[i, j] == k * (i + j + 1)
return tmp[i, j]

for i in range(2):
for j in range(2):
func2(i, j, 10)
assert func2(i, j, 10) == 10 * (i + j + 1)


@test_utils.test(require=ti.extension.dynamic_index,
dynamic_index=True,
debug=True)
@test_utils.test(require=ti.extension.dynamic_index, dynamic_index=True)
def test_local_matrix_non_constant_index():
_test_local_matrix_non_constant_index()


@test_utils.test(arch=[ti.cuda, ti.cpu],
real_matrix_scalarize=False,
debug=True)
@test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix_scalarize=False)
def test_local_matrix_non_constant_index_real_matrix():
_test_local_matrix_non_constant_index()

Expand Down

0 comments on commit 6dacae7

Please sign in to comment.