diff --git a/taichi/codegen/metal/codegen_metal.cpp b/taichi/codegen/metal/codegen_metal.cpp index 44530daadec2b..933b706aef033 100644 --- a/taichi/codegen/metal/codegen_metal.cpp +++ b/taichi/codegen/metal/codegen_metal.cpp @@ -248,8 +248,14 @@ class KernelCodegenImpl : public IRVisitor { } void visit(AllocaStmt *alloca) override { - emit(R"({} {}(0);)", metal_data_type_name(alloca->element_type()), - alloca->raw_name()); + if (alloca->ret_type->is()) { + auto tensor_type = alloca->ret_type->as(); + emit("{} {}[{}];", metal_data_type_name(tensor_type->get_element_type()), + alloca->raw_name(), tensor_type->get_num_elements()); + } else { + emit(R"({} {}(0);)", metal_data_type_name(alloca->element_type()), + alloca->raw_name()); + } } void visit(ConstStmt *const_stmt) override { @@ -437,6 +443,23 @@ class KernelCodegenImpl : public IRVisitor { } } + void visit(MatrixPtrStmt *stmt) override { + const auto dt = stmt->origin->ret_type.ptr_removed().get_element_type(); + if (stmt->offset_used_as_index()) { + const auto fmt_str = stmt->origin->is() + ? "thread {}& {} = {}[{}];" + : "device {}* {} = {} + {};"; + emit(fmt_str, metal_data_type_name(dt), stmt->raw_name(), + stmt->origin->raw_name(), stmt->offset->raw_name()); + } else { // offset used as bytes + emit( + "device {}* {} = reinterpret_cast(reinterpret_cast({}) + {});", + metal_data_type_name(dt), stmt->raw_name(), metal_data_type_name(dt), + stmt->origin->raw_name(), stmt->offset->raw_name()); + } + } + void visit(ExternalPtrStmt *stmt) override { // Used mostly for transferring data between host (e.g. numpy array) and // Metal. @@ -484,7 +507,8 @@ class KernelCodegenImpl : public IRVisitor { } void visit(GlobalTemporaryStmt *stmt) override { - const auto dt = metal_data_type_name(stmt->element_type().ptr_removed()); + const auto dt = metal_data_type_name( + stmt->element_type().ptr_removed().get_element_type()); emit("device {}* {} = reinterpret_cast({} + {});", dt, stmt->raw_name(), dt, kGlobalTmpsBufferName, stmt->offset); } diff --git a/taichi/codegen/metal/struct_metal.cpp b/taichi/codegen/metal/struct_metal.cpp index d583488da4c8d..9365bec2b5b78 100644 --- a/taichi/codegen/metal/struct_metal.cpp +++ b/taichi/codegen/metal/struct_metal.cpp @@ -289,7 +289,7 @@ class StructCompiler { emit(""); } - size_t compute_snode_size(const SNode *sn) { + size_t compute_snode_size(SNode *sn) { if (sn->is_place()) { return metal_data_type_bytes(to_metal_type(sn->dt)); } @@ -312,12 +312,13 @@ class StructCompiler { } else { for (const auto &ch : sn->ch) { const size_t ch_offset = ch_size; - const auto *ch_sn = ch.get(); + auto *ch_sn = ch.get(); ch_size += compute_snode_size(ch_sn); if (!ch_sn->is_place()) { snode_descriptors_.find(ch_sn->id)->second.mem_offset_in_parent = ch_offset; } + ch_sn->offset_bytes_in_parent_cell = ch_offset; } } @@ -341,6 +342,7 @@ class StructCompiler { TI_ASSERT(snode_descriptors_.find(sn->id) == snode_descriptors_.end()); snode_descriptors_[sn->id] = sn_desc; + sn->cell_size_bytes = sn_desc.stride; return sn_desc.stride; } diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index 1c7b86315ffb5..3536c9c35dbba 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -21,7 +21,8 @@ bool is_extension_supported(Arch arch, Extension ext) { Extension::assertion, Extension::dynamic_index, Extension::mesh}}, // TODO: supporting quant in metal(tests randomly crashed) {Arch::metal, - {Extension::adstack, Extension::assertion, Extension::sparse}}, + {Extension::adstack, Extension::assertion, Extension::dynamic_index, + Extension::sparse}}, {Arch::opengl, {Extension::extfunc}}, {Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}}, }; diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 26c22f85117f9..bf55ed1538d5c 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -517,6 +517,17 @@ class Scalarize : public BasicStmtVisitor { scalarize_load_stmt(stmt); } + void visit(ArgLoadStmt *stmt) override { + auto ret_type = stmt->ret_type.ptr_removed().get_element_type(); + auto arg_load = + std::make_unique(stmt->arg_id, ret_type, stmt->is_ptr); + + stmt->replace_usages_with(arg_load.get()); + + modifier_.insert_before(stmt, std::move(arg_load)); + modifier_.erase(stmt); + } + private: using BasicStmtVisitor::visit; }; @@ -633,17 +644,6 @@ class ScalarizePointers : public BasicStmtVisitor { } } - void visit(ArgLoadStmt *stmt) override { - auto ret_type = stmt->ret_type.ptr_removed().get_element_type(); - auto arg_load = - std::make_unique(stmt->arg_id, ret_type, stmt->is_ptr); - - stmt->replace_usages_with(arg_load.get()); - - modifier_.insert_before(stmt, std::move(arg_load)); - modifier_.erase(stmt); - } - private: using BasicStmtVisitor::visit; }; diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 2b21f0f3501b2..bdbe050df8f60 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -435,7 +435,7 @@ def run(): assert v[i][j] == i * j -@test_utils.test(arch=[ti.cpu, ti.cuda]) +@test_utils.test(require=ti.extension.dynamic_index) def test_matrix_field_dynamic_index_different_path_length(): v = ti.Vector.field(2, ti.i32) x = v.get_scalar_field(0) @@ -448,7 +448,7 @@ def test_matrix_field_dynamic_index_different_path_length(): assert v._get_dynamic_index_stride() is None -@test_utils.test(arch=[ti.cpu, ti.cuda]) +@test_utils.test(require=ti.extension.dynamic_index) def test_matrix_field_dynamic_index_not_pure_dense(): v = ti.Vector.field(2, ti.i32) x = v.get_scalar_field(0) @@ -461,7 +461,7 @@ def test_matrix_field_dynamic_index_not_pure_dense(): assert v._get_dynamic_index_stride() is None -@test_utils.test(arch=[ti.cpu, ti.cuda]) +@test_utils.test(require=ti.extension.dynamic_index) def test_matrix_field_dynamic_index_different_cell_size_bytes(): temp = ti.field(ti.f32) @@ -476,7 +476,7 @@ def test_matrix_field_dynamic_index_different_cell_size_bytes(): assert v._get_dynamic_index_stride() is None -@test_utils.test(arch=[ti.cpu, ti.cuda]) +@test_utils.test(require=ti.extension.dynamic_index) def test_matrix_field_dynamic_index_different_offset_bytes_in_parent_cell(): temp_a = ti.field(ti.f32) temp_b = ti.field(ti.f32) @@ -492,7 +492,7 @@ def test_matrix_field_dynamic_index_different_offset_bytes_in_parent_cell(): assert v._get_dynamic_index_stride() is None -@test_utils.test(arch=[ti.cpu, ti.cuda]) +@test_utils.test(require=ti.extension.dynamic_index) def test_matrix_field_dynamic_index_different_stride(): temp = ti.field(ti.f32) @@ -507,7 +507,7 @@ def test_matrix_field_dynamic_index_different_stride(): assert v._get_dynamic_index_stride() is None -@test_utils.test(arch=[ti.cpu, ti.cuda], dynamic_index=True) +@test_utils.test(require=ti.extension.dynamic_index, dynamic_index=True) def test_matrix_field_dynamic_index_multiple_materialize(): @ti.kernel def empty(): @@ -529,7 +529,9 @@ def func(): assert a[i][j] == (i if j == i % 3 else 0) -@test_utils.test(arch=[ti.cpu, ti.cuda], dynamic_index=True, debug=True) +@test_utils.test(require=ti.extension.dynamic_index, + dynamic_index=True, + debug=True) def test_local_vector_initialized_in_a_loop(): @ti.kernel def foo(): diff --git a/tests/python/test_matrix_slice.py b/tests/python/test_matrix_slice.py index 06dd0a4b81049..563ffdfe448f1 100644 --- a/tests/python/test_matrix_slice.py +++ b/tests/python/test_matrix_slice.py @@ -49,23 +49,23 @@ def foo2(): foo2() -@test_utils.test(dynamic_index=True) +@test_utils.test(require=ti.extension.dynamic_index, dynamic_index=True) def test_matrix_slice_with_variable(): @ti.kernel - def test_one_row_slice() -> ti.types.matrix(2, 1, dtype=ti.i32): + def test_one_row_slice( + index: ti.i32) -> ti.types.matrix(2, 1, dtype=ti.i32): m = ti.Matrix([[1, 2, 3], [4, 5, 6]]) - index = 1 return m[:, index] @ti.kernel - def test_one_col_slice() -> ti.types.matrix(1, 3, dtype=ti.i32): + def test_one_col_slice( + index: ti.i32) -> ti.types.matrix(1, 3, dtype=ti.i32): m = ti.Matrix([[1, 2, 3], [4, 5, 6]]) - index = 1 return m[index, :] - r1 = test_one_row_slice() + r1 = test_one_row_slice(1) assert (r1 == ti.Matrix([[2], [5]])).all() - c1 = test_one_col_slice() + c1 = test_one_col_slice(1) assert (c1 == ti.Matrix([[4, 5, 6]])).all()