Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[test] Remove tests with real_matrix=True and real_matrix_scalarize=True #6873

Merged
merged 3 commits into from
Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions tests/python/test_ad_gdar_diffmpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from tests import test_utils


def _test_gdar_mpm():
@test_utils.test(require=ti.extension.assertion, debug=True, exclude=[ti.cc])
def test_gdar_mpm():
real = ti.f32

dim = 2
Expand Down Expand Up @@ -182,17 +183,3 @@ def substep(s):
learning_rate = 10
init_v[None][0] -= learning_rate * grad[0]
init_v[None][1] -= learning_rate * grad[1]


@test_utils.test(require=ti.extension.assertion, debug=True, exclude=[ti.cc])
def test_gdar_mpm():
_test_gdar_mpm()


@test_utils.test(require=ti.extension.assertion,
debug=True,
exclude=[ti.cc],
real_matrix=True,
real_matrix_scalarize=True)
def test_gdar_mpm_real_matrix_scalarize():
_test_gdar_mpm()
75 changes: 12 additions & 63 deletions tests/python/test_ast_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ def func():
assert x[i, j, k] == 0


def _test_grouped_ndrange_for():
@test_utils.test(print_preprocessed_ir=True)
def test_grouped_ndrange_for():
x = ti.field(ti.i32, shape=(6, 6, 6))
y = ti.field(ti.i32, shape=(6, 6, 6))

Expand All @@ -457,18 +458,6 @@ def func():
assert x[i, j, k] == y[i, j, k]


@test_utils.test(print_preprocessed_ir=True)
def test_grouped_ndrange_for():
_test_grouped_ndrange_for()


@test_utils.test(print_preprocessed_ir=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_grouped_ndrange_for_matrix_scalarize():
_test_grouped_ndrange_for()


@test_utils.test(print_preprocessed_ir=True)
def test_static_for_break():
n = 10
Expand All @@ -493,7 +482,8 @@ def foo(a: ti.template()):
assert a[i] == 0


def _test_static_grouped_for_break():
@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_break():
n = 4

@ti.kernel
Expand All @@ -517,18 +507,6 @@ def foo(a: ti.template()):
assert a[i, j] == 0


@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_break():
_test_static_grouped_for_break()


@test_utils.test(print_preprocessed_ir=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_static_grouped_for_break_matrix_scalarize():
_test_static_grouped_for_break()


@test_utils.test(print_preprocessed_ir=True)
def test_static_for_continue():
n = 10
Expand All @@ -551,7 +529,8 @@ def foo(a: ti.template()):
assert a[i] == 3


def _test_static_grouped_for_continue():
@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_continue():
n = 4

@ti.kernel
Expand All @@ -573,18 +552,6 @@ def foo(a: ti.template()):
assert a[i, j] == 3


@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_continue():
_test_static_grouped_for_continue()


@test_utils.test(print_preprocessed_ir=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_static_grouped_for_continue_matrix_scalarize():
_test_static_grouped_for_continue()


@test_utils.test(print_preprocessed_ir=True)
def test_for_break():
n = 4
Expand Down Expand Up @@ -885,8 +852,8 @@ def foo(x: ti.template()) -> ti.i32:
foo(2)


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_single_listcomp_matrix_scalarize():
@test_utils.test()
def test_single_listcomp():
@ti.func
def identity(dt, n: ti.template()):
return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
Expand All @@ -904,7 +871,8 @@ def foo(n: ti.template()) -> ti.i32:
assert foo(5) == 1


def _test_listcomp():
@test_utils.test()
def test_listcomp():
@ti.func
def identity(dt, n: ti.template()):
return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
Expand All @@ -923,16 +891,6 @@ def foo(n: ti.template()) -> ti.i32:
assert foo(5) == 1 + 4 + 9 + 16


@test_utils.test()
def test_listcomp():
_test_listcomp()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_listcomp_matrix_scalarize():
_test_listcomp()


@test_utils.test()
def test_dictcomp():
@ti.kernel
Expand Down Expand Up @@ -1061,7 +1019,8 @@ def foo() -> ti.i32:
assert foo() == 123


def _test_grouped_static_for_cast():
@test_utils.test()
def test_grouped_static_for_cast():
@ti.kernel
def foo() -> ti.f32:
ret = 0.
Expand All @@ -1071,13 +1030,3 @@ def foo() -> ti.f32:
return ret

assert foo() == test_utils.approx(10)


@test_utils.test()
def test_grouped_static_for_cast():
_test_grouped_static_for_cast()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_grouped_static_for_cast_matrix_scalarize():
_test_grouped_static_for_cast()
120 changes: 16 additions & 104 deletions tests/python/test_bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,142 +62,54 @@ def _test_bls_stencil(*args, **kwargs):
bls_test_template(*args, **kwargs)


def _test_gather_1d_trivial():
@test_utils.test(require=ti.extension.bls)
def test_gather_1d_trivial():
# y[i] = x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ))


def _test_gather_1d():
@test_utils.test(require=ti.extension.bls)
def test_gather_1d():
# y[i] = x[i - 1] + x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((-1, ), (0, )))


def _test_gather_2d():
@test_utils.test(require=ti.extension.bls)
def test_gather_2d():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil)


def _test_gather_2d_nonsquare():
@test_utils.test(require=ti.extension.bls)
def test_gather_2d_nonsquare():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=(4, 16), stencil=stencil)


def _test_gather_3d():
@test_utils.test(require=ti.extension.bls)
def test_gather_3d():
stencil = [(-1, -1, -1), (2, 0, 1)]
_test_bls_stencil(3, 64, bs=(4, 8, 16), stencil=stencil)


def _test_scatter_1d_trivial():
@test_utils.test(require=ti.extension.bls)
def test_scatter_1d_trivial():
# y[i] = x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ), scatter=True)


def _test_scatter_1d():
@test_utils.test(require=ti.extension.bls)
def test_scatter_1d():
_test_bls_stencil(1, 128, bs=32, stencil=(
(1, ),
(0, ),
), scatter=True)


def _test_scatter_2d():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True)


@test_utils.test(require=ti.extension.bls)
def test_gather_1d_trivial():
_test_gather_1d_trivial()


@test_utils.test(require=ti.extension.bls)
def test_gather_1d():
_test_gather_1d()


@test_utils.test(require=ti.extension.bls)
def test_gather_2d():
_test_gather_2d()


@test_utils.test(require=ti.extension.bls)
def test_gather_2d_nonsquare():
_test_gather_2d_nonsquare()


@test_utils.test(require=ti.extension.bls)
def test_gather_3d():
_test_gather_3d()


@test_utils.test(require=ti.extension.bls)
def test_scatter_1d_trivial():
_test_scatter_1d_trivial()


@test_utils.test(require=ti.extension.bls)
def test_scatter_1d():
_test_scatter_1d()


@test_utils.test(require=ti.extension.bls)
def test_scatter_2d():
_test_scatter_2d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_1d_trivial_matrix_scalarize():
_test_gather_1d_trivial()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_1d_matrix_scalarize():
_test_gather_1d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_2d_matrix_scalarize():
_test_gather_2d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_2d_nonsquare_matrix_scalarize():
_test_gather_2d_nonsquare()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gather_3d_matrix_scalarize():
_test_gather_3d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scatter_1d_trivial_matrix_scalarize():
_test_scatter_1d_trivial()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scatter_1d_matrix_scalarize():
_test_scatter_1d()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scatter_2d_matrix_scalarize():
_test_scatter_2d()
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True)


@test_utils.test(require=ti.extension.bls)
Expand Down
Loading