diff --git a/taichi/codegen/spirv/snode_struct_compiler.cpp b/taichi/codegen/spirv/snode_struct_compiler.cpp
index 498b76482364ef..1904d5f32c7dd8 100644
--- a/taichi/codegen/spirv/snode_struct_compiler.cpp
+++ b/taichi/codegen/spirv/snode_struct_compiler.cpp
@@ -94,6 +94,7 @@ class StructCompiler {
       cell_stride += snode_size;
       snode_descriptors_.find(ch_snode->id)
           ->second.mem_offset_in_parent_cell = child_offset;
+      ch_snode->offset_bytes_in_parent_cell = child_offset;
     }
     sn_desc.cell_stride = cell_stride;
diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp
index 62ea16f185959c..ea40d7e82dcdc8 100644
--- a/taichi/codegen/spirv/spirv_codegen.cpp
+++ b/taichi/codegen/spirv/spirv_codegen.cpp
@@ -219,40 +219,53 @@ class TaskCodegen : public IRVisitor {
   }
 
   void visit(AllocaStmt *alloca) override {
+    spirv::Value ptr_val;
     if (alloca->ret_type->is<TensorType>()) {
-      // Alloca for shared memory / workgroup memory
-      if (!alloca->is_shared) {
-        TI_ERROR(
-            "Tensor type for dyanmic index is not yet supported on Vulkan.");
-      }
       auto tensor_type = alloca->ret_type->cast<TensorType>();
       auto elem_num = tensor_type->get_num_elements();
       spirv::SType elem_type =
           ir_->get_primitive_type(tensor_type->get_element_type());
       spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num);
-      spirv::Value ptr_val = ir_->alloca_workgroup_array(arr_type);
-      shared_array_binds_.push_back(ptr_val);
-      ir_->register_value(alloca->raw_name(), ptr_val);
+      if (alloca->is_shared) {  // for shared memory / workgroup memory
+        ptr_val = ir_->alloca_workgroup_array(arr_type);
+        shared_array_binds_.push_back(ptr_val);
+      } else {  // for function memory
+        ptr_val = ir_->alloca_variable(arr_type);
+      }
     } else {
       // Alloca for a single variable
       spirv::SType src_type = ir_->get_primitive_type(alloca->element_type());
-      spirv::Value ptr_val = ir_->alloca_variable(src_type);
+      ptr_val = ir_->alloca_variable(src_type);
       ir_->store_variable(ptr_val, ir_->get_zero(src_type));
-      ir_->register_value(alloca->raw_name(), ptr_val);
     }
+    ir_->register_value(alloca->raw_name(), ptr_val);
   }
 
   void visit(MatrixPtrStmt *stmt) override {
-    spirv::SType data_type =
-        ir_->get_primitive_type(stmt->element_type().ptr_removed());
-    spirv::SType ptr_type =
-        ir_->get_pointer_type(data_type, spv::StorageClassWorkgroup);
-    auto origin_val = ir_->query_value(stmt->origin->raw_name());
-    auto offset_val = ir_->query_value(stmt->offset->raw_name());
-    Value offset_ptr =
-        ir_->make_value(spv::OpAccessChain, ptr_type, origin_val, offset_val);
-    ir_->register_value(stmt->raw_name(), offset_ptr);
+    spirv::Value ptr_val;
+    spirv::Value origin_val = ir_->query_value(stmt->origin->raw_name());
+    spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
+    auto dt = stmt->element_type().ptr_removed();
+    if (stmt->offset_used_as_index()) {
+      if (stmt->origin->is<AllocaStmt>()) {
+        spirv::SType ptr_type = ir_->get_pointer_type(
+            ir_->get_primitive_type(dt), origin_val.stype.storage_class);
+        ptr_val = ir_->make_value(spv::OpAccessChain, ptr_type, origin_val,
+                                  offset_val);
+      } else if (stmt->origin->is<GlobalTemporaryStmt>()) {
+        spirv::Value dt_bytes = ir_->int_immediate_number(
+            ir_->i32_type(), ir_->get_primitive_type_size(dt), false);
+        spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
+        ptr_val = ir_->add(origin_val, offset_bytes);
+        ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
+      } else {
+        TI_NOT_IMPLEMENTED;
+      }
+    } else {  // offset used as bytes
+      ptr_val = ir_->add(origin_val, ir_->cast(origin_val.stype, offset_val));
+      ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
+    }
+    ir_->register_value(stmt->raw_name(), ptr_val);
   }
 
   void visit(LocalLoadStmt *stmt) override {
diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp
index 4ff20f3160ce6f..07e5ffbcc61359 100644
--- a/taichi/program/extension.cpp
+++ b/taichi/program/extension.cpp
@@ -23,8 +23,10 @@ bool is_extension_supported(Arch arch, Extension ext) {
       {Arch::metal,
        {Extension::adstack, Extension::assertion, Extension::dynamic_index,
         Extension::sparse}},
-      {Arch::opengl, {Extension::extfunc}},
+      {Arch::opengl, {Extension::dynamic_index, Extension::extfunc}},
       {Arch::gles, {}},
+      {Arch::vulkan, {Extension::dynamic_index}},
+      {Arch::dx11, {Extension::dynamic_index}},
       {Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}},
   };
   // if (with_opengl_extension_data64())
diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py
index bdbe050df8f608..e1a74ae3535a69 100644
--- a/tests/python/test_matrix.py
+++ b/tests/python/test_matrix.py
@@ -166,39 +166,33 @@ def run():
 
 def _test_local_matrix_non_constant_index():
     @ti.kernel
-    def func1():
+    def func1() -> ti.types.vector(3, ti.i32):
         tmp = ti.Vector([1, 2, 3])
         for i in range(3):
             vec = ti.Vector([4, 5, 6])
             for j in range(3):
                 vec[tmp[i] % 3] += vec[j]
             tmp[i] = vec[tmp[i] % 3]
-        assert tmp[0] == 24
-        assert tmp[1] == 30
-        assert tmp[2] == 19
+        return tmp
 
-    func1()
+    assert (func1() == ti.Vector([24, 30, 19])).all()
 
     @ti.kernel
-    def func2(i: ti.i32, j: ti.i32, k: ti.i32):
+    def func2(i: ti.i32, j: ti.i32, k: ti.i32) -> ti.i32:
         tmp = ti.Matrix([[k, k * 2], [k * 2, k * 3]])
-        assert tmp[i, j] == k * (i + j + 1)
+        return tmp[i, j]
 
     for i in range(2):
         for j in range(2):
-            func2(i, j, 10)
+            assert func2(i, j, 10) == 10 * (i + j + 1)
 
 
-@test_utils.test(require=ti.extension.dynamic_index,
-                 dynamic_index=True,
-                 debug=True)
+@test_utils.test(require=ti.extension.dynamic_index, dynamic_index=True)
 def test_local_matrix_non_constant_index():
     _test_local_matrix_non_constant_index()
 
 
-@test_utils.test(arch=[ti.cuda, ti.cpu],
-                 real_matrix_scalarize=False,
-                 debug=True)
+@test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix_scalarize=False)
 def test_local_matrix_non_constant_index_real_matrix():
     _test_local_matrix_non_constant_index()
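For context, a minimal sketch of what this change enables from the Python side (illustrative only, not part of the diff; the kernel name `pick` and the choice of `ti.vulkan` are assumptions). Previously, `visit(AllocaStmt)` raised `TI_ERROR` for any non-shared tensor alloca on Vulkan; with this patch it emits a function-local array instead, and `visit(MatrixPtrStmt)` addresses local arrays with `OpAccessChain` (or byte offsets for global temporaries), so a local matrix can be indexed by a runtime value on the SPIR-V-based backends:

```python
import taichi as ti

# A usage sketch under the assumption that a Vulkan-capable device is present;
# ti.opengl or ti.dx11 should work the same way per the extension.cpp change.
ti.init(arch=ti.vulkan, dynamic_index=True)

@ti.kernel
def pick(i: ti.i32) -> ti.i32:
    v = ti.Vector([10, 20, 30])
    # `i` is only known at runtime, so this indexing lowers to a
    # MatrixPtrStmt with a non-constant offset, the case handled above.
    return v[i]

assert pick(2) == 30
```

This is the same pattern the updated `test_local_matrix_non_constant_index` exercises via `@test_utils.test(require=ti.extension.dynamic_index, dynamic_index=True)`.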