Skip to content

Commit

Permalink
[test] MatrixType refactor: Add tests for writing to matrix slice (ta…
Browse files Browse the repository at this point in the history
…ichi-dev#6480)

Issue: taichi-dev#5819

### Brief Summary

This is a follow-up PR of taichi-dev#6430.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent c0d43cc commit 1577139
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 30 deletions.
5 changes: 2 additions & 3 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -703,9 +703,8 @@ class LocalStoreStmt : public Stmt {
Stmt *val;

LocalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) {
TI_ASSERT(dest->is<AllocaStmt>() ||
(dest->is<MatrixPtrStmt>() &&
dest->cast<MatrixPtrStmt>()->offset_used_as_index()));
TI_ASSERT(dest->is<AllocaStmt>() || dest->is<MatrixPtrStmt>() ||
dest->is<MatrixOfMatrixPtrStmt>());
TI_STMT_REG_FIELDS;
}

Expand Down
79 changes: 52 additions & 27 deletions tests/python/test_matrix_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,39 +122,64 @@ def test_one_col_slice() -> ti.types.matrix(1, 3, dtype=ti.i32):
test_one_col_slice()


@test_utils.test(debug=True)
def test_matrix_slice_write():
def _test_matrix_slice_write():
@ti.kernel
def foo():
m = ti.Matrix([[0., 0., 0., 0.] for _ in range(3)])
vec = ti.Vector([1., 2., 3., 4.])
m[0, :] = vec.transpose()
ref = ti.Matrix([[1., 2., 3., 4.], [0., 0., 0., 0.], [0., 0., 0., 0.]])
assert all(m == ref)
def assign_row() -> ti.types.matrix(3, 4, ti.i32):
mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)])
row = ti.Matrix([[1, 2, 3, 4]])
mat[0, :] = row
return mat

m[1, 1:3] = ti.Vector([1., 2.]).transpose()
ref = ti.Matrix([[1., 2., 3., 4.], [0., 1., 2., 0.], [0., 0., 0., 0.]])
assert all(m == ref)
@ti.kernel
def assign_partial_row() -> ti.types.matrix(3, 4, ti.i32):
mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)])
mat[1, 1:3] = ti.Matrix([[1, 2]])
return mat

m1 = ti.Matrix([[1., 1., 1., 1.] for _ in range(2)])
m[:2, :] += m1
ref = ti.Matrix([[2., 3., 4., 5.], [1., 2., 3., 1.], [0., 0., 0., 0.]])
assert all(m == ref)
@ti.kernel
def augassign_rows() -> ti.types.matrix(3, 4, ti.i32):
mat = ti.Matrix([[1, 1, 1, 1] for _ in range(3)])
rows = ti.Matrix([[1, 2, 3, 4] for _ in range(2)])
mat[:2, :] += rows
return mat

foo()
assert (assign_row() == ti.Matrix([[1, 2, 3, 4], [0, 0, 0, 0],
[0, 0, 0, 0]])).all()
assert (assign_partial_row() == ti.Matrix([[0, 0, 0, 0], [0, 1, 2, 0],
[0, 0, 0, 0]])).all()
assert (augassign_rows() == ti.Matrix([[2, 3, 4, 5], [2, 3, 4, 5],
[1, 1, 1, 1]])).all()


@test_utils.test(debug=True, dynamic_index=True)
def test_matrix_slice_write_dynamic_index():
@test_utils.test()
def test_matrix_slice_write():
_test_matrix_slice_write()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_matrix_slice_write_real_matrix_scalarize():
_test_matrix_slice_write()


def _test_matrix_slice_write_dynamic_index():
@ti.kernel
def foo(i: ti.i32, ref: ti.template()):
m = ti.Matrix([[0., 0., 0., 0.] for _ in range(3)])
vec = ti.Vector([1., 2., 3., 4.])
m[i, :] = vec.transpose()
assert all(m == ref)
def foo(i: ti.i32) -> ti.types.matrix(3, 4, ti.i32):
mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)])
mat[i, :] = ti.Matrix([[1, 2, 3, 4]])
return mat

for i in range(3):
foo(
i,
ti.Matrix([[1., 2., 3., 4.] if j == i else [0., 0., 0., 0.]
for j in range(3)]))
assert (foo(i) == ti.Matrix([[1, 2, 3, 4] if j == i else [0, 0, 0, 0]
for j in range(3)])).all()


@test_utils.test(dynamic_index=True)
def test_matrix_slice_write_dynamic_index():
_test_matrix_slice_write_dynamic_index()


@test_utils.test(real_matrix=True,
real_matrix_scalarize=True,
dynamic_index=True)
def test_matrix_slice_write_dynamic_index_real_matrix_scalarize():
_test_matrix_slice_write_dynamic_index()

0 comments on commit 1577139

Please sign in to comment.