From c00cda3e0f12b709c006b12da5f546b89b3676a3 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Tue, 25 Oct 2022 18:04:52 +0800 Subject: [PATCH 1/6] [Lang] MatrixType refactor: Support matrix slice --- python/taichi/lang/impl.py | 22 ++++-- taichi/analysis/gen_offline_cache_key.cpp | 5 +- taichi/inc/statements.inc.h | 1 + taichi/ir/expression_printer.h | 13 +++- taichi/ir/frontend_ir.cpp | 84 ++++++++++++++--------- taichi/ir/frontend_ir.h | 16 ++++- taichi/ir/statements.cpp | 7 +- taichi/ir/statements.h | 15 ++++ taichi/python/export_lang.cpp | 2 + taichi/transforms/ir_printer.cpp | 12 ++++ taichi/transforms/lower_matrix_ptr.cpp | 18 ++++- tests/python/test_matrix_slice.py | 12 +++- 12 files changed, 163 insertions(+), 44 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 6c01ab59deb5f..747c78e3c6c13 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -182,7 +182,7 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): indices = () if has_slice: - if not isinstance(value, Matrix): + if not isinstance(value, Matrix) and not (isinstance(value, Expr) and value.is_tensor()): raise SyntaxError( f"The type {type(value)} do not support index of slice type") else: @@ -269,9 +269,23 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): if isinstance(value, Expr): # Index into TensorType # value: IndexExpression with ret_type = TensorType - assert current_cfg().real_matrix is True - assert is_tensor(value.ptr.get_ret_type()) - + assert current_cfg().real_matrix + assert value.is_tensor() + + if has_slice: + shape = value.get_shape() + dim = len(shape) + assert dim == len(indices) + indices = [_calc_slice(index, shape[i]) if isinstance(index, slice) + else [index] for i, index in enumerate(indices)] + if dim == 1: + multiple_indices = [make_expr_group(i) for i in indices[0]] + return_shape = (len(indices[0]), ) + else: + assert dim == 2 + multiple_indices = [make_expr_group(i, j) for i in indices[0] for j in indices[1]] + return_shape = (len(indices[0]), len(indices[1])) + return Expr(_ti_core.subscript_with_multiple_indices(value.ptr, multiple_indices, return_shape, get_runtime().get_current_src_info())) return Expr( _ti_core.subscript(value.ptr, indices_expr_group, get_runtime().get_current_src_info())) diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 138c47fd1dd82..377451f51c066 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -169,7 +169,10 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { void visit(IndexExpression *expr) override { emit(ExprOpCode::IndexExpression); emit(expr->var); - emit(expr->indices.exprs); + for (auto &indices : expr->indices_group) { + emit(indices.exprs); + } + emit(expr->ret_shape); } void visit(MatrixExpression *expr) override { diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index a3e6045a9d2a6..2e73b54cbf562 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -60,6 +60,7 @@ PER_STATEMENT(GetChStmt) PER_STATEMENT(LocalLoadStmt) PER_STATEMENT(GlobalPtrStmt) PER_STATEMENT(MatrixOfGlobalPtrStmt) +PER_STATEMENT(MatrixOfMatrixPtrStmt) // Offloaded PER_STATEMENT(OffloadedStmt) diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 16c85905bed77..19271ff335bc1 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -130,7 +130,18 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { void visit(IndexExpression *expr) override { expr->var->accept(this); emit('['); - emit_vector(expr->indices.exprs); + if (expr->ret_shape.empty()) { + emit_vector(expr->indices_group[0].exprs); + } else { + for (auto &indices : expr->indices_group) { + emit('('); + emit_vector(indices.exprs); + emit("), "); + } + emit("shape=("); + emit_vector(expr->ret_shape); + emit(')'); + } emit(']'); } diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 926904a1e726d..f3431031e6f81 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -5,6 +5,8 @@ #include "taichi/program/program.h" #include "taichi/common/exceptions.h" +#include + namespace taichi::lang { #define TI_ASSERT_TYPE_CHECKED(x) \ @@ -572,17 +574,11 @@ Stmt *make_ndarray_access(Expression::FlattenContext *ctx, return ctx->push_back(std::move(external_ptr_stmt)); } -Stmt *make_tensor_access(Expression::FlattenContext *ctx, - Expr var, - ExprGroup indices, - std::vector shape, - int stride) { - flatten_lvalue(var, ctx); - if (!var->is_lvalue()) { - auto alloca_stmt = ctx->push_back(var->ret_type); - ctx->push_back(alloca_stmt, var->stmt); - var->stmt = alloca_stmt; - } +Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx, + const Expr &var, + const ExprGroup &indices, + const std::vector &shape, + int stride) { bool needs_dynamic_index = false; for (int i = 0; i < (int)indices.size(); ++i) { if (!indices[i].is()) { @@ -616,6 +612,28 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, return ctx->push_back(var->stmt, offset_stmt); } +Stmt *make_tensor_access(Expression::FlattenContext *ctx, + Expr var, + const std::vector &indices_group, + DataType ret_type, + std::vector shape, + int stride) { + flatten_lvalue(var, ctx); + if (!var->is_lvalue()) { + auto alloca_stmt = ctx->push_back(var->ret_type); + ctx->push_back(alloca_stmt, var->stmt); + var->stmt = alloca_stmt; + } + if (is_tensor(ret_type)) { + std::vector stmts; + for (auto &indices : indices_group) { + stmts.push_back(make_tensor_access_single_element(ctx, var, indices, shape, stride)); + } + return ctx->push_back(stmts, ret_type); + } + return make_tensor_access_single_element(ctx, var, indices_group[0], shape, stride); +} + void MatrixExpression::type_check(CompileConfig *config) { // TODO: typecheck matrix for (auto &arg : elements) { @@ -671,7 +689,12 @@ bool IndexExpression::is_global() const { void IndexExpression::type_check(CompileConfig *) { // TODO: Change to type-based solution // Currently, dimension compatibility check happens in Python - if (is_field()) { // field + TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape), end(ret_shape), 1, std::multiplies<>())); + if (!ret_shape.empty()) { + TI_ASSERT_INFO(is_tensor(), "Slice or swizzle can only apply on matrices"); + auto element_type = var->ret_type->as()->get_element_type(); + ret_type = TypeFactory::create_tensor_type(ret_shape, element_type); + } else if (is_field()) { // field ret_type = var.cast()->dt->get_compute_type(); } else if (is_matrix_field()) { auto matrix_field_expr = var.cast(); @@ -682,7 +705,7 @@ void IndexExpression::type_check(CompileConfig *) { } else if (is_ndarray()) { // ndarray auto external_tensor_expr = var.cast(); int total_dim = external_tensor_expr->dim; - int index_dim = indices.exprs.size(); + int index_dim = indices_group[0].exprs.size(); if (index_dim == total_dim) { // Access all the way to a single element @@ -693,9 +716,9 @@ void IndexExpression::type_check(CompileConfig *) { } } else if (is_tensor()) { // local tensor auto shape = var->ret_type->as()->get_shape(); - if (indices.size() != shape.size()) { + if (indices_group[0].size() != shape.size()) { TI_ERROR("Expected {} indices, but got {}.", shape.size(), - indices.size()); + indices_group[0].size()); } ret_type = var->ret_type->cast()->get_element_type(); } else { @@ -704,28 +727,30 @@ void IndexExpression::type_check(CompileConfig *) { "local tensor"); } - for (int i = 0; i < indices.exprs.size(); i++) { - auto &expr = indices.exprs[i]; - TI_ASSERT_TYPE_CHECKED(expr); - if (!is_integral(expr->ret_type)) - throw TaichiTypeError( - fmt::format("indices must be integers, however '{}' is " - "provided as index {}", - expr->ret_type->to_string(), i)); + for (auto &indices : indices_group) { + for (int i = 0; i < indices.exprs.size(); i++) { + auto &expr = indices.exprs[i]; + TI_ASSERT_TYPE_CHECKED(expr); + if (!is_integral(expr->ret_type)) + throw TaichiTypeError( + fmt::format("indices must be integers, however '{}' is " + "provided as index {}", + expr->ret_type->to_string(), i)); + } } } void IndexExpression::flatten(FlattenContext *ctx) { if (is_field()) { - stmt = make_field_access(ctx, *var.cast(), indices); + stmt = make_field_access(ctx, *var.cast(), indices_group[0]); } else if (is_matrix_field()) { stmt = make_matrix_field_access(ctx, *var.cast(), - indices, ret_type); + indices_group[0], ret_type); } else if (is_ndarray()) { - stmt = make_ndarray_access(ctx, var, indices); + stmt = make_ndarray_access(ctx, var, indices_group[0]); } else if (is_tensor()) { stmt = make_tensor_access( - ctx, var, indices, var->ret_type->cast()->get_shape(), 1); + ctx, var, indices_group, ret_type, var->ret_type->cast()->get_shape(), 1); } else { throw TaichiTypeError( "Invalid IndexExpression: the source is not among field, ndarray or " @@ -746,7 +771,7 @@ void StrideExpression::type_check(CompileConfig *) { } void StrideExpression::flatten(FlattenContext *ctx) { - stmt = make_tensor_access(ctx, var, indices, shape, stride); + stmt = make_tensor_access(ctx, var, {indices}, ret_type, shape, stride); } void RangeAssumptionExpression::type_check(CompileConfig *) { @@ -1505,9 +1530,6 @@ void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { } } else if (ptr.is()) { flatten_global_load(ptr, ctx); - } else if (ptr.is()) { - TI_ASSERT(ptr.cast()->snode->num_active_indices == 0); - flatten_global_load(ptr[ExprGroup()], ctx); } else if (ptr.is() && ptr.cast()->is_ptr) { flatten_global_load(ptr, ctx); diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 508a77e8284f4..7aac30acea7c2 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -579,12 +579,24 @@ class IndexExpression : public Expression { // `var` is one of FieldExpression, MatrixFieldExpression, // ExternalTensorExpression, IdExpression Expr var; - ExprGroup indices; + // In the cases of matrix slice and vector swizzle, there can be multiple + // indices, and the corresponding ret_shape should also be recorded. In normal + // index expressions ret_shape will be left empty. + std::vector indices_group; + std::vector ret_shape; IndexExpression(const Expr &var, const ExprGroup &indices, std::string tb = "") - : var(var), indices(indices) { + : var(var), indices_group({indices}) { + this->tb = tb; + } + + IndexExpression(const Expr &var, + const std::vector &indices_group, + const std::vector &ret_shape, + std::string tb = "") + : var(var), indices_group(indices_group), ret_shape(ret_shape) { this->tb = tb; } diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 1cc717934ae31..36ebc6b7c0376 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -78,11 +78,16 @@ MatrixOfGlobalPtrStmt::MatrixOfGlobalPtrStmt(const std::vector &snodes, TI_STMT_REG_FIELDS; } +MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector &stmts, DataType dt) : stmts(stmts) { + ret_type = dt; + TI_STMT_REG_FIELDS; +} + MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input, Stmt *offset_input) { origin = origin_input; offset = offset_input; if (origin->is() || origin->is() || - origin->is() || origin->is()) { + origin->is() || origin->is() || origin->is()) { auto tensor_type = origin->ret_type.ptr_removed()->cast(); TI_ASSERT(tensor_type != nullptr); element_type() = tensor_type->get_element_type(); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index e13f201040ceb..8847c57e57bcd 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -406,6 +406,21 @@ class MatrixOfGlobalPtrStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +/** + * A matrix of MatrixPtrStmts. The purpose of this stmt is to handle matrix + * slice and vector swizzle. This stmt will be eliminated after the + * lower_matrix_ptr pass. + */ +class MatrixOfMatrixPtrStmt : public Stmt { + public: + std::vector stmts; + + MatrixOfMatrixPtrStmt(const std::vector &stmts, DataType dt); + + TI_STMT_DEF_FIELDS(ret_type, stmts); + TI_DEFINE_ACCEPT_AND_CLONE +}; + /** * A pointer to an element of a matrix. */ diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index b52eee1d04f42..bc282f0606ddb 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -1006,6 +1006,8 @@ void export_lang(py::module &m) { return idx_expr; }); + m.def("subscript_with_multiple_indices", Expr::make &, const std::vector &, std::string>); + m.def("make_stride_expr", Expr::make &, int>); diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 090453ef742f6..231cce1482661 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -426,6 +426,18 @@ class IRPrinter : public IRVisitor { print_raw(s); } + void visit(MatrixOfMatrixPtrStmt *stmt) override { + std::string s = fmt::format("{}{} = matrix of matrix ptr [", stmt->type_hint(), stmt->name()); + for (int i = 0; i < (int)stmt->stmts.size(); i++) { + s += fmt::format("{}", stmt->stmts[i]->name()); + if (i + 1 < (int)stmt->stmts.size()) { + s += ", "; + } + } + s += "]"; + print_raw(s); + } + void visit(MatrixPtrStmt *stmt) override { std::string s = fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(), diff --git a/taichi/transforms/lower_matrix_ptr.cpp b/taichi/transforms/lower_matrix_ptr.cpp index 9b06f7be0b611..4a639293bdf78 100644 --- a/taichi/transforms/lower_matrix_ptr.cpp +++ b/taichi/transforms/lower_matrix_ptr.cpp @@ -70,6 +70,14 @@ class LowerMatrixPtr : public BasicStmtVisitor { modifier_.erase(stmt); return; } + if (stmt->origin->is()) { + auto origin = stmt->origin->as(); + TI_ASSERT(stmt->offset->is()); + auto offset = stmt->offset->as(); + stmt->replace_usages_with(origin->stmts[offset->val.val_int()]); + modifier_.erase(stmt); + return; + } } static void run(IRNode *node) { @@ -79,7 +87,7 @@ class LowerMatrixPtr : public BasicStmtVisitor { } }; -class RemoveMatrixOfGlobalPtr : public BasicStmtVisitor { +class RemoveMatrixOfPtr : public BasicStmtVisitor { private: using BasicStmtVisitor::visit; DelayedIRModifier modifier_; @@ -89,8 +97,12 @@ class RemoveMatrixOfGlobalPtr : public BasicStmtVisitor { modifier_.erase(stmt); } + void visit(MatrixOfMatrixPtrStmt *stmt) override { + modifier_.erase(stmt); + } + static void run(IRNode *node) { - RemoveMatrixOfGlobalPtr pass; + RemoveMatrixOfPtr pass; node->accept(&pass); pass.modifier_.modify_ir(); } @@ -101,7 +113,7 @@ namespace irpass { void lower_matrix_ptr(IRNode *root) { TI_AUTO_PROF; LowerMatrixPtr::run(root); - RemoveMatrixOfGlobalPtr::run(root); + RemoveMatrixOfPtr::run(root); } } // namespace irpass diff --git a/tests/python/test_matrix_slice.py b/tests/python/test_matrix_slice.py index e5ce40fca5d52..84a77eab7936d 100644 --- a/tests/python/test_matrix_slice.py +++ b/tests/python/test_matrix_slice.py @@ -5,7 +5,7 @@ @test_utils.test() -def test_matrix_slice_read(): +def _test_matrix_slice_read(): b = 6 @ti.kernel @@ -28,6 +28,16 @@ def foo2() -> ti.types.matrix(2, 3, dtype=ti.i32): assert (m2 == ti.Matrix([[3]])).all() +@test_utils.test() +def test_matrix_slice_read(): + _test_matrix_slice_read() + + +@test_utils.test(real_matrix=True, real_matrix_scalarize=True) +def test_matrix_slice_read_real_matrix_scalarize(): + _test_matrix_slice_read() + + @test_utils.test() def test_matrix_slice_invalid(): @ti.kernel From 62a24c1a70d45798f20c446ba0eb11fd6fd66d3d Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Tue, 25 Oct 2022 19:48:43 +0800 Subject: [PATCH 2/6] Add tests --- tests/python/test_matrix_slice.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/python/test_matrix_slice.py b/tests/python/test_matrix_slice.py index 84a77eab7936d..51247b6558d01 100644 --- a/tests/python/test_matrix_slice.py +++ b/tests/python/test_matrix_slice.py @@ -4,7 +4,6 @@ from tests import test_utils -@test_utils.test() def _test_matrix_slice_read(): b = 6 @@ -38,8 +37,7 @@ def test_matrix_slice_read_real_matrix_scalarize(): _test_matrix_slice_read() -@test_utils.test() -def test_matrix_slice_invalid(): +def _test_matrix_slice_invalid(): @ti.kernel def foo1(i: ti.i32): a = ti.Vector([0, 1, 2, 3, 4, 5, 6]) @@ -59,8 +57,17 @@ def foo2(): foo2() -@test_utils.test(dynamic_index=True) -def test_matrix_slice_with_variable(): +@test_utils.test() +def test_matrix_slice_invalid(): + _test_matrix_slice_invalid() + + +@test_utils.test(real_matrix=True, real_matrix_scalarize=True) +def test_matrix_slice_invalid_real_matrix_scalarize(): + _test_matrix_slice_invalid() + + +def _test_matrix_slice_with_variable(): @ti.kernel def test_one_row_slice() -> ti.types.matrix(2, 1, dtype=ti.i32): m = ti.Matrix([[1, 2, 3], [4, 5, 6]]) @@ -79,6 +86,16 @@ def test_one_col_slice() -> ti.types.matrix(1, 3, dtype=ti.i32): assert (c1 == ti.Matrix([[4, 5, 6]])).all() +@test_utils.test(dynamic_index=True) +def test_matrix_slice_with_variable(): + _test_matrix_slice_with_variable() + + +@test_utils.test(real_matrix=True, real_matrix_scalarize=True, dynamic_index=True) +def test_matrix_slice_with_variable_real_matrix_scalarize(): + _test_matrix_slice_with_variable() + + @test_utils.test(dynamic_index=False) def test_matrix_slice_with_variable_invalid(): @ti.kernel From bf7156797df690718b07a111f02afcaf30ad5997 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Oct 2022 12:13:58 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/impl.py | 20 +++++++++++++++----- taichi/ir/frontend_ir.cpp | 18 ++++++++++++------ taichi/ir/statements.cpp | 7 +++++-- taichi/python/export_lang.cpp | 5 ++++- taichi/transforms/ir_printer.cpp | 3 ++- tests/python/test_matrix_slice.py | 4 +++- 6 files changed, 41 insertions(+), 16 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 747c78e3c6c13..6ecaaa742ab0a 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -182,7 +182,8 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): indices = () if has_slice: - if not isinstance(value, Matrix) and not (isinstance(value, Expr) and value.is_tensor()): + if not isinstance(value, Matrix) and not (isinstance(value, Expr) + and value.is_tensor()): raise SyntaxError( f"The type {type(value)} do not support index of slice type") else: @@ -276,16 +277,25 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): shape = value.get_shape() dim = len(shape) assert dim == len(indices) - indices = [_calc_slice(index, shape[i]) if isinstance(index, slice) - else [index] for i, index in enumerate(indices)] + indices = [ + _calc_slice(index, shape[i]) + if isinstance(index, slice) else [index] + for i, index in enumerate(indices) + ] if dim == 1: multiple_indices = [make_expr_group(i) for i in indices[0]] return_shape = (len(indices[0]), ) else: assert dim == 2 - multiple_indices = [make_expr_group(i, j) for i in indices[0] for j in indices[1]] + multiple_indices = [ + make_expr_group(i, j) for i in indices[0] + for j in indices[1] + ] return_shape = (len(indices[0]), len(indices[1])) - return Expr(_ti_core.subscript_with_multiple_indices(value.ptr, multiple_indices, return_shape, get_runtime().get_current_src_info())) + return Expr( + _ti_core.subscript_with_multiple_indices( + value.ptr, multiple_indices, return_shape, + get_runtime().get_current_src_info())) return Expr( _ti_core.subscript(value.ptr, indices_expr_group, get_runtime().get_current_src_info())) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index f3431031e6f81..dcf5e4f47c247 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -627,11 +627,13 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, if (is_tensor(ret_type)) { std::vector stmts; for (auto &indices : indices_group) { - stmts.push_back(make_tensor_access_single_element(ctx, var, indices, shape, stride)); + stmts.push_back( + make_tensor_access_single_element(ctx, var, indices, shape, stride)); } return ctx->push_back(stmts, ret_type); } - return make_tensor_access_single_element(ctx, var, indices_group[0], shape, stride); + return make_tensor_access_single_element(ctx, var, indices_group[0], shape, + stride); } void MatrixExpression::type_check(CompileConfig *config) { @@ -689,7 +691,9 @@ bool IndexExpression::is_global() const { void IndexExpression::type_check(CompileConfig *) { // TODO: Change to type-based solution // Currently, dimension compatibility check happens in Python - TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape), end(ret_shape), 1, std::multiplies<>())); + TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape), + end(ret_shape), 1, + std::multiplies<>())); if (!ret_shape.empty()) { TI_ASSERT_INFO(is_tensor(), "Slice or swizzle can only apply on matrices"); auto element_type = var->ret_type->as()->get_element_type(); @@ -742,15 +746,17 @@ void IndexExpression::type_check(CompileConfig *) { void IndexExpression::flatten(FlattenContext *ctx) { if (is_field()) { - stmt = make_field_access(ctx, *var.cast(), indices_group[0]); + stmt = + make_field_access(ctx, *var.cast(), indices_group[0]); } else if (is_matrix_field()) { stmt = make_matrix_field_access(ctx, *var.cast(), indices_group[0], ret_type); } else if (is_ndarray()) { stmt = make_ndarray_access(ctx, var, indices_group[0]); } else if (is_tensor()) { - stmt = make_tensor_access( - ctx, var, indices_group, ret_type, var->ret_type->cast()->get_shape(), 1); + stmt = + make_tensor_access(ctx, var, indices_group, ret_type, + var->ret_type->cast()->get_shape(), 1); } else { throw TaichiTypeError( "Invalid IndexExpression: the source is not among field, ndarray or " diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 36ebc6b7c0376..e24c22238248e 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -78,7 +78,9 @@ MatrixOfGlobalPtrStmt::MatrixOfGlobalPtrStmt(const std::vector &snodes, TI_STMT_REG_FIELDS; } -MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector &stmts, DataType dt) : stmts(stmts) { +MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector &stmts, + DataType dt) + : stmts(stmts) { ret_type = dt; TI_STMT_REG_FIELDS; } @@ -87,7 +89,8 @@ MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input, Stmt *offset_input) { origin = origin_input; offset = offset_input; if (origin->is() || origin->is() || - origin->is() || origin->is() || origin->is()) { + origin->is() || origin->is() || + origin->is()) { auto tensor_type = origin->ret_type.ptr_removed()->cast(); TI_ASSERT(tensor_type != nullptr); element_type() = tensor_type->get_element_type(); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index bc282f0606ddb..958fdb128652f 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -1006,7 +1006,10 @@ void export_lang(py::module &m) { return idx_expr; }); - m.def("subscript_with_multiple_indices", Expr::make &, const std::vector &, std::string>); + m.def( + "subscript_with_multiple_indices", + Expr::make &, + const std::vector &, std::string>); m.def("make_stride_expr", Expr::maketype_hint(), stmt->name()); + std::string s = fmt::format("{}{} = matrix of matrix ptr [", + stmt->type_hint(), stmt->name()); for (int i = 0; i < (int)stmt->stmts.size(); i++) { s += fmt::format("{}", stmt->stmts[i]->name()); if (i + 1 < (int)stmt->stmts.size()) { diff --git a/tests/python/test_matrix_slice.py b/tests/python/test_matrix_slice.py index 51247b6558d01..aff516ab9ae3f 100644 --- a/tests/python/test_matrix_slice.py +++ b/tests/python/test_matrix_slice.py @@ -91,7 +91,9 @@ def test_matrix_slice_with_variable(): _test_matrix_slice_with_variable() -@test_utils.test(real_matrix=True, real_matrix_scalarize=True, dynamic_index=True) +@test_utils.test(real_matrix=True, + real_matrix_scalarize=True, + dynamic_index=True) def test_matrix_slice_with_variable_real_matrix_scalarize(): _test_matrix_slice_with_variable() From 5327f76c9a953b0c0a0e7e6c0d98c3966852894c Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Tue, 25 Oct 2022 20:17:29 +0800 Subject: [PATCH 4/6] Remove unused import --- python/taichi/lang/impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 6ecaaa742ab0a..dcc00cb8a215d 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -29,7 +29,6 @@ python_scope, taichi_scope, warning) from taichi.types.primitive_types import (all_types, f16, f32, f64, i32, i64, u8, u32, u64) -from taichi.types.utils import is_tensor @taichi_scope From 1fde2536ac316c603253a2a2cf7070a8beb20e8f Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Wed, 26 Oct 2022 13:31:35 +0800 Subject: [PATCH 5/6] Update taichi/ir/statements.h Co-authored-by: Zhanlue Yang --- taichi/ir/statements.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 8847c57e57bcd..2516842fb0d12 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -410,6 +410,8 @@ class MatrixOfGlobalPtrStmt : public Stmt { * A matrix of MatrixPtrStmts. The purpose of this stmt is to handle matrix * slice and vector swizzle. This stmt will be eliminated after the * lower_matrix_ptr pass. + * + * TODO(yi/zhanlue): Keep scalarization pass alive for MatrixOfMatrixPtrStmt operations even with real_matrix_scalarize=False */ class MatrixOfMatrixPtrStmt : public Stmt { public: From ebbb43c4bb4a74e6f42fd3c4e771c903a5ed4295 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Oct 2022 05:33:02 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/statements.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 2516842fb0d12..4bd88bab39cbf 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -411,7 +411,8 @@ class MatrixOfGlobalPtrStmt : public Stmt { * slice and vector swizzle. This stmt will be eliminated after the * lower_matrix_ptr pass. * - * TODO(yi/zhanlue): Keep scalarization pass alive for MatrixOfMatrixPtrStmt operations even with real_matrix_scalarize=False + * TODO(yi/zhanlue): Keep scalarization pass alive for MatrixOfMatrixPtrStmt + * operations even with real_matrix_scalarize=False */ class MatrixOfMatrixPtrStmt : public Stmt { public: