Skip to content

Commit

Permalink
[test] Remove tests with real_matrix=True and real_matrix_scalarize=True
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Dec 12, 2022
1 parent f735fe1 commit b96e1e8
Show file tree
Hide file tree
Showing 30 changed files with 175 additions and 1,184 deletions.
17 changes: 2 additions & 15 deletions tests/python/test_ad_gdar_diffmpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from tests import test_utils


def _test_gdar_mpm():
@test_utils.test(require=ti.extension.assertion, debug=True, exclude=[ti.cc])
def test_gdar_mpm():
real = ti.f32

dim = 2
Expand Down Expand Up @@ -182,17 +183,3 @@ def substep(s):
learning_rate = 10
init_v[None][0] -= learning_rate * grad[0]
init_v[None][1] -= learning_rate * grad[1]


@test_utils.test(require=ti.extension.assertion, debug=True, exclude=[ti.cc])
def test_gdar_mpm():
_test_gdar_mpm()


@test_utils.test(require=ti.extension.assertion,
debug=True,
exclude=[ti.cc],
real_matrix=True,
real_matrix_scalarize=True)
def test_gdar_mpm_real_matrix_scalarize():
_test_gdar_mpm()
75 changes: 12 additions & 63 deletions tests/python/test_ast_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ def func():
assert x[i, j, k] == 0


def _test_grouped_ndrange_for():
@test_utils.test(print_preprocessed_ir=True)
def test_grouped_ndrange_for():
x = ti.field(ti.i32, shape=(6, 6, 6))
y = ti.field(ti.i32, shape=(6, 6, 6))

Expand All @@ -457,18 +458,6 @@ def func():
assert x[i, j, k] == y[i, j, k]


@test_utils.test(print_preprocessed_ir=True)
def test_grouped_ndrange_for():
_test_grouped_ndrange_for()


@test_utils.test(print_preprocessed_ir=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_grouped_ndrange_for_matrix_scalarize():
_test_grouped_ndrange_for()


@test_utils.test(print_preprocessed_ir=True)
def test_static_for_break():
n = 10
Expand All @@ -493,7 +482,8 @@ def foo(a: ti.template()):
assert a[i] == 0


def _test_static_grouped_for_break():
@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_break():
n = 4

@ti.kernel
Expand All @@ -517,18 +507,6 @@ def foo(a: ti.template()):
assert a[i, j] == 0


@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_break():
_test_static_grouped_for_break()


@test_utils.test(print_preprocessed_ir=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_static_grouped_for_break_matrix_scalarize():
_test_static_grouped_for_break()


@test_utils.test(print_preprocessed_ir=True)
def test_static_for_continue():
n = 10
Expand All @@ -551,7 +529,8 @@ def foo(a: ti.template()):
assert a[i] == 3


def _test_static_grouped_for_continue():
@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_continue():
n = 4

@ti.kernel
Expand All @@ -573,18 +552,6 @@ def foo(a: ti.template()):
assert a[i, j] == 3


@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_continue():
_test_static_grouped_for_continue()


@test_utils.test(print_preprocessed_ir=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_static_grouped_for_continue_matrix_scalarize():
_test_static_grouped_for_continue()


@test_utils.test(print_preprocessed_ir=True)
def test_for_break():
n = 4
Expand Down Expand Up @@ -885,8 +852,8 @@ def foo(x: ti.template()) -> ti.i32:
foo(2)


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_single_listcomp_matrix_scalarize():
@test_utils.test()
def test_single_listcomp():
@ti.func
def identity(dt, n: ti.template()):
return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
Expand All @@ -904,7 +871,8 @@ def foo(n: ti.template()) -> ti.i32:
assert foo(5) == 1


def _test_listcomp():
@test_utils.test()
def test_listcomp():
@ti.func
def identity(dt, n: ti.template()):
return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
Expand All @@ -923,16 +891,6 @@ def foo(n: ti.template()) -> ti.i32:
assert foo(5) == 1 + 4 + 9 + 16


@test_utils.test()
def test_listcomp():
_test_listcomp()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_listcomp_matrix_scalarize():
_test_listcomp()


@test_utils.test()
def test_dictcomp():
@ti.kernel
Expand Down Expand Up @@ -1061,7 +1019,8 @@ def foo() -> ti.i32:
assert foo() == 123


def _test_grouped_static_for_cast():
@test_utils.test()
def test_grouped_static_for_cast():
@ti.kernel
def foo() -> ti.f32:
ret = 0.
Expand All @@ -1071,13 +1030,3 @@ def foo() -> ti.f32:
return ret

assert foo() == test_utils.approx(10)


@test_utils.test()
def test_grouped_static_for_cast():
_test_grouped_static_for_cast()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_grouped_static_for_cast_matrix_scalarize():
_test_grouped_static_for_cast()
120 changes: 16 additions & 104 deletions tests/python/test_bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,142 +62,54 @@ def _test_bls_stencil(*args, **kwargs):
bls_test_template(*args, **kwargs)


def _test_gather_1d_trivial():
@test_utils.test(require=ti.extension.bls)
def test_gather_1d_trivial():
# y[i] = x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ))


def _test_gather_1d():
@test_utils.test(require=ti.extension.bls)
def test_gather_1d():
# y[i] = x[i - 1] + x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((-1, ), (0, )))


def _test_gather_2d():
@test_utils.test(require=ti.extension.bls)
def test_gather_2d():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil)


def _test_gather_2d_nonsquare():
@test_utils.test(require=ti.extension.bls)
def test_gather_2d_nonsquare():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=(4, 16), stencil=stencil)


def _test_gather_3d():
@test_utils.test(require=ti.extension.bls)
def test_gather_3d():
stencil = [(-1, -1, -1), (2, 0, 1)]
_test_bls_stencil(3, 64, bs=(4, 8, 16), stencil=stencil)


def _test_scatter_1d_trivial():
@test_utils.test(require=ti.extension.bls)
def test_scatter_1d_trivial():
# y[i] = x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ), scatter=True)


def _test_scatter_1d():
@test_utils.test(require=ti.extension.bls)
def test_scatter_1d():
_test_bls_stencil(1, 128, bs=32, stencil=(
(1, ),
(0, ),
), scatter=True)


def _test_scatter_2d():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True)


@test_utils.test(require=ti.extension.bls)
def test_gather_1d_trivial():
_test_gather_1d_trivial()


@test_utils.test(require=ti.extension.bls)
def test_gather_1d():
_test_gather_1d()


@test_utils.test(require=ti.extension.bls)
def test_gather_2d():
_test_gather_2d()


@test_utils.test(require=ti.extension.bls)
def test_gather_2d_nonsquare():
_test_gather_2d_nonsquare()


@test_utils.test(require=ti.extension.bls)
def test_gather_3d():
_test_gather_3d()


@test_utils.test(require=ti.extension.bls)
def test_scatter_1d_trivial():
_test_scatter_1d_trivial()


@test_utils.test(require=ti.extension.bls)
def test_scatter_1d():
_test_scatter_1d()


@test_utils.test(require=ti.extension.bls)
def test_scatter_2d():
_test_scatter_2d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_1d_trivial_matrix_scalarize():
_test_gather_1d_trivial()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_1d_matrix_scalarize():
_test_gather_1d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_2d_matrix_scalarize():
_test_gather_2d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_2d_nonsquare_matrix_scalarize():
_test_gather_2d_nonsquare()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_3d_matrix_scalarize():
_test_gather_3d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scatter_1d_trivial_matrix_scalarize():
_test_scatter_1d_trivial()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scatter_1d_matrix_scalarize():
_test_scatter_1d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scatter_2d_matrix_scalarize():
_test_scatter_2d()
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True)


@test_utils.test(require=ti.extension.bls)
Expand Down
Loading

0 comments on commit b96e1e8

Please sign in to comment.