From 15771392eb19b4842902eda5572b2df2dedc8fd7 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Mon, 31 Oct 2022 17:21:54 +0800 Subject: [PATCH] [test] MatrixType refactor: Add tests for writing to matrix slice (#6480) Issue: #5819 ### Brief Summary This is a follow-up PR of #6430. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/ir/statements.h | 5 +- tests/python/test_matrix_slice.py | 79 ++++++++++++++++++++----------- 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index d1c3c067f9de5..3501717ac88c9 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -703,9 +703,8 @@ class LocalStoreStmt : public Stmt { Stmt *val; LocalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) { - TI_ASSERT(dest->is() || - (dest->is() && - dest->cast()->offset_used_as_index())); + TI_ASSERT(dest->is() || dest->is() || + dest->is()); TI_STMT_REG_FIELDS; } diff --git a/tests/python/test_matrix_slice.py b/tests/python/test_matrix_slice.py index aff516ab9ae3f..4765beede1d81 100644 --- a/tests/python/test_matrix_slice.py +++ b/tests/python/test_matrix_slice.py @@ -122,39 +122,64 @@ def test_one_col_slice() -> ti.types.matrix(1, 3, dtype=ti.i32): test_one_col_slice() -@test_utils.test(debug=True) -def test_matrix_slice_write(): +def _test_matrix_slice_write(): @ti.kernel - def foo(): - m = ti.Matrix([[0., 0., 0., 0.] for _ in range(3)]) - vec = ti.Vector([1., 2., 3., 4.]) - m[0, :] = vec.transpose() - ref = ti.Matrix([[1., 2., 3., 4.], [0., 0., 0., 0.], [0., 0., 0., 0.]]) - assert all(m == ref) + def assign_row() -> ti.types.matrix(3, 4, ti.i32): + mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)]) + row = ti.Matrix([[1, 2, 3, 4]]) + mat[0, :] = row + return mat - m[1, 1:3] = ti.Vector([1., 2.]).transpose() - ref = ti.Matrix([[1., 2., 3., 4.], [0., 1., 2., 0.], [0., 0., 0., 0.]]) - assert all(m == ref) + @ti.kernel + def assign_partial_row() -> ti.types.matrix(3, 4, ti.i32): + mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)]) + mat[1, 1:3] = ti.Matrix([[1, 2]]) + return mat - m1 = ti.Matrix([[1., 1., 1., 1.] for _ in range(2)]) - m[:2, :] += m1 - ref = ti.Matrix([[2., 3., 4., 5.], [1., 2., 3., 1.], [0., 0., 0., 0.]]) - assert all(m == ref) + @ti.kernel + def augassign_rows() -> ti.types.matrix(3, 4, ti.i32): + mat = ti.Matrix([[1, 1, 1, 1] for _ in range(3)]) + rows = ti.Matrix([[1, 2, 3, 4] for _ in range(2)]) + mat[:2, :] += rows + return mat - foo() + assert (assign_row() == ti.Matrix([[1, 2, 3, 4], [0, 0, 0, 0], + [0, 0, 0, 0]])).all() + assert (assign_partial_row() == ti.Matrix([[0, 0, 0, 0], [0, 1, 2, 0], + [0, 0, 0, 0]])).all() + assert (augassign_rows() == ti.Matrix([[2, 3, 4, 5], [2, 3, 4, 5], + [1, 1, 1, 1]])).all() -@test_utils.test(debug=True, dynamic_index=True) -def test_matrix_slice_write_dynamic_index(): +@test_utils.test() +def test_matrix_slice_write(): + _test_matrix_slice_write() + + +@test_utils.test(real_matrix=True, real_matrix_scalarize=True) +def test_matrix_slice_write_real_matrix_scalarize(): + _test_matrix_slice_write() + + +def _test_matrix_slice_write_dynamic_index(): @ti.kernel - def foo(i: ti.i32, ref: ti.template()): - m = ti.Matrix([[0., 0., 0., 0.] for _ in range(3)]) - vec = ti.Vector([1., 2., 3., 4.]) - m[i, :] = vec.transpose() - assert all(m == ref) + def foo(i: ti.i32) -> ti.types.matrix(3, 4, ti.i32): + mat = ti.Matrix([[0, 0, 0, 0] for _ in range(3)]) + mat[i, :] = ti.Matrix([[1, 2, 3, 4]]) + return mat for i in range(3): - foo( - i, - ti.Matrix([[1., 2., 3., 4.] if j == i else [0., 0., 0., 0.] - for j in range(3)])) + assert (foo(i) == ti.Matrix([[1, 2, 3, 4] if j == i else [0, 0, 0, 0] + for j in range(3)])).all() + + +@test_utils.test(dynamic_index=True) +def test_matrix_slice_write_dynamic_index(): + _test_matrix_slice_write_dynamic_index() + + +@test_utils.test(real_matrix=True, + real_matrix_scalarize=True, + dynamic_index=True) +def test_matrix_slice_write_dynamic_index_real_matrix_scalarize(): + _test_matrix_slice_write_dynamic_index()