Skip to content

Commit

Permalink
[Lang] MatrixField refactor 6/n: Add tests for MatrixField scalarizat…
Browse files Browse the repository at this point in the history
…ion (#6137)

Issue: #5959

### Brief Summary

This PR demonstrates that the following path of #5959 can also run
end-to-end:


![MatrixField_Part6](https://user-images.githubusercontent.com/3251060/191686873-839fd9c2-11fb-4fea-b648-3350a3992779.png)
  • Loading branch information
strongoier authored Sep 22, 2022
1 parent 0fa2eed commit b5072a6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 32 deletions.
7 changes: 2 additions & 5 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,8 @@ def build_Assign(ctx, node):
@staticmethod
def build_assign_slice(ctx, node_target, values, is_static_assign):
target = ASTTransformer.build_Subscript(ctx, node_target, get_ref=True)
if current_cfg().real_matrix:
if isinstance(node_target.value.ptr,
any_array.AnyArray) and isinstance(
values, (list, tuple)):
values = make_matrix(values)
if current_cfg().real_matrix and isinstance(values, (list, tuple)):
values = make_matrix(values)

if isinstance(node_target.value.ptr, Matrix):
if isinstance(node_target.value.ptr._impl, _TiScopeMatrixImpl):
Expand Down
75 changes: 48 additions & 27 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,63 +828,84 @@ def foo() -> ti.types.matrix(2, 2, ti.f32):
assert foo() == [[1.0, 2.0], [2.0, 4.0]]


def _test_field_and_ndarray(field, ndarray, func, verify):
@ti.kernel
def kern_field(a: ti.template()):
func(a)

@ti.kernel
def kern_ndarray(a: ti.types.ndarray()):
func(a)

kern_field(field)
verify(field)
kern_ndarray(ndarray)
verify(ndarray)


@test_utils.test(arch=[ti.cuda, ti.cpu],
real_matrix=True,
real_matrix_scalarize=True)
def test_store_scalarize():
@ti.kernel
def func(a: ti.types.ndarray()):
@ti.func
def func(a: ti.template()):
for i in range(5):
a[i] = [[i, i + 1], [i + 2, i + 3]]

x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
func(x)
def verify(x):
assert (x[0] == [[0, 1], [2, 3]]).all()
assert (x[1] == [[1, 2], [3, 4]]).all()
assert (x[2] == [[2, 3], [4, 5]]).all()
assert (x[3] == [[3, 4], [5, 6]]).all()
assert (x[4] == [[4, 5], [6, 7]]).all()

assert (x[0] == [[0, 1], [2, 3]]).all()
assert (x[1] == [[1, 2], [3, 4]]).all()
assert (x[2] == [[2, 3], [4, 5]]).all()
assert (x[3] == [[3, 4], [5, 6]]).all()
assert (x[4] == [[4, 5], [6, 7]]).all()
field = ti.Matrix.field(2, 2, ti.i32, shape=5)
ndarray = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
_test_field_and_ndarray(field, ndarray, func, verify)


@test_utils.test(arch=[ti.cuda, ti.cpu],
real_matrix=True,
real_matrix_scalarize=True)
def test_load_store_scalarize():
@ti.kernel
def func(a: ti.types.ndarray()):
@ti.func
def func(a: ti.template()):
for i in range(3):
a[i] = [[i, i + 1], [i + 2, i + 3]]

a[3] = a[1]
a[4] = a[2]

x = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
func(x)
def verify(x):
assert (x[3] == [[1, 2], [3, 4]]).all()
assert (x[4] == [[2, 3], [4, 5]]).all()

assert (x[3] == [[1, 2], [3, 4]]).all()
assert (x[4] == [[2, 3], [4, 5]]).all()
field = ti.Matrix.field(2, 2, ti.i32, shape=5)
ndarray = ti.Matrix.ndarray(2, 2, ti.i32, shape=5)
_test_field_and_ndarray(field, ndarray, func, verify)


@test_utils.test(arch=[ti.cuda, ti.cpu],
real_matrix=True,
real_matrix_scalarize=True)
def test_unary_op_scalarize():
@ti.kernel
def func(a: ti.types.ndarray()):
@ti.func
def func(a: ti.template()):
a[0] = [[0, 1], [2, 3]]
a[1] = [[3, 4], [5, 6]]
a[2] = -a[0]
a[3] = ti.exp(a[1])
a[4] = ti.sqrt(a[3])

x = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
func(x)

assert (x[0] == [[0., 1.], [2., 3.]]).all()
assert (x[1] == [[3., 4.], [5., 6.]]).all()
assert (x[2] == [[-0., -1.], [-2., -3.]]).all()
assert (x[3] < [[20.086, 54.60], [148.42, 403.43]]).all()
assert (x[3] > [[20.085, 54.59], [148.41, 403.42]]).all()
assert (x[4] < [[4.49, 7.39], [12.19, 20.09]]).all()
assert (x[4] > [[4.48, 7.38], [12.18, 20.08]]).all()
def verify(x):
assert (x[0] == [[0., 1.], [2., 3.]]).all()
assert (x[1] == [[3., 4.], [5., 6.]]).all()
assert (x[2] == [[-0., -1.], [-2., -3.]]).all()
assert (x[3] < [[20.086, 54.60], [148.42, 403.43]]).all()
assert (x[3] > [[20.085, 54.59], [148.41, 403.42]]).all()
assert (x[4] < [[4.49, 7.39], [12.19, 20.09]]).all()
assert (x[4] > [[4.48, 7.38], [12.18, 20.08]]).all()

field = ti.Matrix.field(2, 2, ti.f32, shape=5)
ndarray = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
_test_field_and_ndarray(field, ndarray, func, verify)

0 comments on commit b5072a6

Please sign in to comment.