Skip to content

Commit

Permalink
[Lang] MatrixField refactor 9/n: Allow dynamic index of matrix field …
Browse files Browse the repository at this point in the history
…when real_matrix=True (#6194)

Issue: #5959

### Brief Summary

This PR makes the rightmost path of #5959 running end-to-end.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Sep 29, 2022
1 parent f56ea84 commit 50e5e3f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
34 changes: 27 additions & 7 deletions taichi/transforms/lower_matrix_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,34 @@ class LowerMatrixPtr : public BasicStmtVisitor {

void visit(MatrixPtrStmt *stmt) override {
if (stmt->origin->is<MatrixOfGlobalPtrStmt>()) {
TI_ASSERT(stmt->offset->is<ConstStmt>());
auto origin = stmt->origin->as<MatrixOfGlobalPtrStmt>();
auto offset = stmt->offset->as<ConstStmt>();
auto lowered = std::make_unique<GlobalPtrStmt>(
origin->snodes[offset->val.val_int()], origin->indices);
stmt->replace_usages_with(lowered.get());
modifier.insert_before(stmt, std::move(lowered));
modifier.erase(stmt);
if (stmt->offset->is<ConstStmt>()) {
auto offset = stmt->offset->as<ConstStmt>();
auto lowered = std::make_unique<GlobalPtrStmt>(
origin->snodes[offset->val.val_int()], origin->indices);
stmt->replace_usages_with(lowered.get());
modifier.insert_before(stmt, std::move(lowered));
modifier.erase(stmt);
} else {
TI_ASSERT_INFO(
origin->dynamic_indexable,
"Element of the MatrixField is not dynamic indexable.\n{}",
stmt->tb);
auto stride = std::make_unique<ConstStmt>(
TypedConstant(origin->dynamic_index_stride));
auto offset = std::make_unique<BinaryOpStmt>(
BinaryOpType::mul, stmt->offset, stride.get());
auto ptr_base =
std::make_unique<GlobalPtrStmt>(origin->snodes[0], origin->indices);
auto lowered =
std::make_unique<MatrixPtrStmt>(ptr_base.get(), offset.get());
stmt->replace_usages_with(lowered.get());
modifier.insert_before(stmt, std::move(stride));
modifier.insert_before(stmt, std::move(offset));
modifier.insert_before(stmt, std::move(ptr_base));
modifier.insert_before(stmt, std::move(lowered));
modifier.erase(stmt);
}
}
}

Expand Down
19 changes: 15 additions & 4 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,7 @@ def func2(b: ti.types.ndarray(element_dim=1)):
assert v[3][9] == 9


@test_utils.test(require=ti.extension.dynamic_index,
dynamic_index=True,
debug=True)
def test_matrix_non_constant_index():
def _test_matrix_non_constant_index():
m = ti.Matrix.field(2, 2, ti.i32, 5)
v = ti.Vector.field(10, ti.i32, 5)

Expand Down Expand Up @@ -248,6 +245,20 @@ def func4(k: ti.i32):
func4(10)


@test_utils.test(require=ti.extension.dynamic_index,
dynamic_index=True,
debug=True)
def test_matrix_non_constant_index():
_test_matrix_non_constant_index()


@test_utils.test(require=ti.extension.dynamic_index,
real_matrix=True,
debug=True)
def test_matrix_non_constant_index_real_matrix():
_test_matrix_non_constant_index()


def _test_matrix_constant_index():
m = ti.Matrix.field(2, 2, ti.i32, 5)

Expand Down

0 comments on commit 50e5e3f

Please sign in to comment.