From b5072a682408ed461373bfe82c8a5f923d0361e4 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 22 Sep 2022 21:41:26 +0800 Subject: [PATCH] [Lang] MatrixField refactor 6/n: Add tests for MatrixField scalarization (#6137) Issue: #5959 ### Brief Summary This PR demonstrates that the following path of #5959 can also run end-to-end: ![MatrixField_Part6](https://user-images.githubusercontent.com/3251060/191686873-839fd9c2-11fb-4fea-b648-3350a3992779.png) --- python/taichi/lang/ast/ast_transformer.py | 7 +-- tests/python/test_matrix.py | 75 +++++++++++++++-------- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 10be365e8e908..d8b28436321ec 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -115,11 +115,8 @@ def build_Assign(ctx, node): @staticmethod def build_assign_slice(ctx, node_target, values, is_static_assign): target = ASTTransformer.build_Subscript(ctx, node_target, get_ref=True) - if current_cfg().real_matrix: - if isinstance(node_target.value.ptr, - any_array.AnyArray) and isinstance( - values, (list, tuple)): - values = make_matrix(values) + if current_cfg().real_matrix and isinstance(values, (list, tuple)): + values = make_matrix(values) if isinstance(node_target.value.ptr, Matrix): if isinstance(node_target.value.ptr._impl, _TiScopeMatrixImpl): diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 7108cc1208ef0..b00034344275d 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -828,63 +828,84 @@ def foo() -> ti.types.matrix(2, 2, ti.f32): assert foo() == [[1.0, 2.0], [2.0, 4.0]] +def _test_field_and_ndarray(field, ndarray, func, verify): + @ti.kernel + def kern_field(a: ti.template()): + func(a) + + @ti.kernel + def kern_ndarray(a: ti.types.ndarray()): + func(a) + + kern_field(field) + verify(field) + kern_ndarray(ndarray) + verify(ndarray) + + @test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True, real_matrix_scalarize=True) def test_store_scalarize(): - @ti.kernel - def func(a: ti.types.ndarray()): + @ti.func + def func(a: ti.template()): for i in range(5): a[i] = [[i, i + 1], [i + 2, i + 3]] - x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5) - func(x) + def verify(x): + assert (x[0] == [[0, 1], [2, 3]]).all() + assert (x[1] == [[1, 2], [3, 4]]).all() + assert (x[2] == [[2, 3], [4, 5]]).all() + assert (x[3] == [[3, 4], [5, 6]]).all() + assert (x[4] == [[4, 5], [6, 7]]).all() - assert (x[0] == [[0, 1], [2, 3]]).all() - assert (x[1] == [[1, 2], [3, 4]]).all() - assert (x[2] == [[2, 3], [4, 5]]).all() - assert (x[3] == [[3, 4], [5, 6]]).all() - assert (x[4] == [[4, 5], [6, 7]]).all() + field = ti.Matrix.field(2, 2, ti.i32, shape=5) + ndarray = ti.Matrix.ndarray(2, 2, ti.i32, shape=5) + _test_field_and_ndarray(field, ndarray, func, verify) @test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True, real_matrix_scalarize=True) def test_load_store_scalarize(): - @ti.kernel - def func(a: ti.types.ndarray()): + @ti.func + def func(a: ti.template()): for i in range(3): a[i] = [[i, i + 1], [i + 2, i + 3]] a[3] = a[1] a[4] = a[2] - x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5) - func(x) + def verify(x): + assert (x[3] == [[1, 2], [3, 4]]).all() + assert (x[4] == [[2, 3], [4, 5]]).all() - assert (x[3] == [[1, 2], [3, 4]]).all() - assert (x[4] == [[2, 3], [4, 5]]).all() + field = ti.Matrix.field(2, 2, ti.i32, shape=5) + ndarray = ti.Matrix.ndarray(2, 2, ti.i32, shape=5) + _test_field_and_ndarray(field, ndarray, func, verify) @test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True, real_matrix_scalarize=True) def test_unary_op_scalarize(): - @ti.kernel - def func(a: ti.types.ndarray()): + @ti.func + def func(a: ti.template()): a[0] = [[0, 1], [2, 3]] a[1] = [[3, 4], [5, 6]] a[2] = -a[0] a[3] = ti.exp(a[1]) a[4] = ti.sqrt(a[3]) - x = ti.Matrix.ndarray(2, 2, ti.f32, shape=5) - func(x) - - assert (x[0] == [[0., 1.], [2., 3.]]).all() - assert (x[1] == [[3., 4.], [5., 6.]]).all() - assert (x[2] == [[-0., -1.], [-2., -3.]]).all() - assert (x[3] < [[20.086, 54.60], [148.42, 403.43]]).all() - assert (x[3] > [[20.085, 54.59], [148.41, 403.42]]).all() - assert (x[4] < [[4.49, 7.39], [12.19, 20.09]]).all() - assert (x[4] > [[4.48, 7.38], [12.18, 20.08]]).all() + def verify(x): + assert (x[0] == [[0., 1.], [2., 3.]]).all() + assert (x[1] == [[3., 4.], [5., 6.]]).all() + assert (x[2] == [[-0., -1.], [-2., -3.]]).all() + assert (x[3] < [[20.086, 54.60], [148.42, 403.43]]).all() + assert (x[3] > [[20.085, 54.59], [148.41, 403.42]]).all() + assert (x[4] < [[4.49, 7.39], [12.19, 20.09]]).all() + assert (x[4] > [[4.48, 7.38], [12.18, 20.08]]).all() + + field = ti.Matrix.field(2, 2, ti.f32, shape=5) + ndarray = ti.Matrix.ndarray(2, 2, ti.f32, shape=5) + _test_field_and_ndarray(field, ndarray, func, verify)