From 50e5e3fda7baa602250e443be324c74e82a3d393 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 29 Sep 2022 10:29:27 +0800 Subject: [PATCH] [Lang] MatrixField refactor 9/n: Allow dynamic index of matrix field when real_matrix=True (#6194) Issue: #5959 ### Brief Summary This PR makes the rightmost path of #5959 running end-to-end. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/transforms/lower_matrix_ptr.cpp | 34 ++++++++++++++++++++------ tests/python/test_matrix.py | 19 +++++++++++--- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/taichi/transforms/lower_matrix_ptr.cpp b/taichi/transforms/lower_matrix_ptr.cpp index 69f3721e72181..de385457a572e 100644 --- a/taichi/transforms/lower_matrix_ptr.cpp +++ b/taichi/transforms/lower_matrix_ptr.cpp @@ -13,14 +13,34 @@ class LowerMatrixPtr : public BasicStmtVisitor { void visit(MatrixPtrStmt *stmt) override { if (stmt->origin->is()) { - TI_ASSERT(stmt->offset->is()); auto origin = stmt->origin->as(); - auto offset = stmt->offset->as(); - auto lowered = std::make_unique( - origin->snodes[offset->val.val_int()], origin->indices); - stmt->replace_usages_with(lowered.get()); - modifier.insert_before(stmt, std::move(lowered)); - modifier.erase(stmt); + if (stmt->offset->is()) { + auto offset = stmt->offset->as(); + auto lowered = std::make_unique( + origin->snodes[offset->val.val_int()], origin->indices); + stmt->replace_usages_with(lowered.get()); + modifier.insert_before(stmt, std::move(lowered)); + modifier.erase(stmt); + } else { + TI_ASSERT_INFO( + origin->dynamic_indexable, + "Element of the MatrixField is not dynamic indexable.\n{}", + stmt->tb); + auto stride = std::make_unique( + TypedConstant(origin->dynamic_index_stride)); + auto offset = std::make_unique( + BinaryOpType::mul, stmt->offset, stride.get()); + auto ptr_base = + std::make_unique(origin->snodes[0], origin->indices); + auto lowered = + std::make_unique(ptr_base.get(), offset.get()); + stmt->replace_usages_with(lowered.get()); + modifier.insert_before(stmt, std::move(stride)); + modifier.insert_before(stmt, std::move(offset)); + modifier.insert_before(stmt, std::move(ptr_base)); + modifier.insert_before(stmt, std::move(lowered)); + modifier.erase(stmt); + } } } diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 74ddb3f706617..3cf75032ba5cc 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -193,10 +193,7 @@ def func2(b: ti.types.ndarray(element_dim=1)): assert v[3][9] == 9 -@test_utils.test(require=ti.extension.dynamic_index, - dynamic_index=True, - debug=True) -def test_matrix_non_constant_index(): +def _test_matrix_non_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) v = ti.Vector.field(10, ti.i32, 5) @@ -248,6 +245,20 @@ def func4(k: ti.i32): func4(10) +@test_utils.test(require=ti.extension.dynamic_index, + dynamic_index=True, + debug=True) +def test_matrix_non_constant_index(): + _test_matrix_non_constant_index() + + +@test_utils.test(require=ti.extension.dynamic_index, + real_matrix=True, + debug=True) +def test_matrix_non_constant_index_real_matrix(): + _test_matrix_non_constant_index() + + def _test_matrix_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5)