diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index 15da341ca4fb3d..032710df4e8c1d 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -3,6 +3,7 @@ from .transformer import TaichiSyntaxError from .ndrange import ndrange, GroupedNDRange from copy import deepcopy as _deepcopy +import functools import os core = taichi_lang_core @@ -208,6 +209,13 @@ def _get_or_make_arch_checkers(kwargs): def all_archs_with(**kwargs): kwargs = _deepcopy(kwargs) def decorator(test): + # @pytest.mark.parametrize decorator only knows about regular function args, + # without *args or **kwargs. By decorating with @functools.wraps, the + # signature of |test| is preserved, so that @ti.all_archs can be used after + # the parametrization decorator. + # + # Full discussion: https://github.com/pytest-dev/pytest/issues/6810 + @functools.wraps(test) def wrapped(*test_args, **test_kwargs): import taichi as ti can_run_on = test_kwargs.pop( @@ -252,6 +260,7 @@ def archs_excluding(*excluded_archs, **kwargs): excluded_archs = set(excluded_archs) def decorator(test): + @functools.wraps(test) def wrapped(*test_args, **test_kwargs): def checker(arch): return arch not in excluded_archs _get_or_make_arch_checkers(test_kwargs).register(checker) @@ -275,6 +284,7 @@ def require(*exts): assert all([isinstance(e, core.Extension) for e in exts]) def decorator(test): + @functools.wraps(test) def wrapped(*test_args, **test_kwargs): def checker(arch): return all([is_supported(arch, e) for e in exts]) _get_or_make_arch_checkers(test_kwargs).register(checker) @@ -330,48 +340,6 @@ def func__(*args, **kwargs): return decorator -# Poor man's parametrized test decorator -# The usage is identical to how @pytest.mark.parametrize is commonly used. For -# example: -# -# @ti.parametrize('foo', [1, 2, 3]) -# @ti.all_archs -# def test_xx(foo): -# ... -# -# @ti.parametrize('foo,bar', [ -# (1, 'a'), -# (2, 'b'), -# (3, 'c'), -# ]) -# @ti.all_archs -# def test_yy(foo, bar): -# ... -def parametrize(argnames: str, argvalues): - # @pytest.mark.parametrize only works for canonical function args, and doesn't - # support *args or **kwargs. This makes it difficult to play along with other - # decorators like @ti.all_archs. As a result, we implement our own. - argnames = [s.strip() for s in argnames.split(',')] - def iterable(x): - try: - _ = iter(x) - return True - except: - return False - - def decorator(test): - def wrapped(*test_args, **test_kwargs): - for vals in argvalues: - if isinstance(vals, str) or not iterable(vals): - vals = (vals, ) - kwargs = {k: v for k, v in zip(argnames, vals)} - assert len(kwargs.keys() & test_kwargs.keys()) == 0 - kwargs.update(test_kwargs) - test(*test_args, **kwargs) - return wrapped - return decorator - - def complex_kernel(func): def decorated(*args, **kwargs): diff --git a/tests/python/test_types.py b/tests/python/test_types.py index dbe4affc4c5cb9..62b4a5eb9b464b 100644 --- a/tests/python/test_types.py +++ b/tests/python/test_types.py @@ -14,14 +14,14 @@ def func(value: dt): func(3) assert x[None] == 3 -@ti.parametrize('dt', _TI_TYPES) +@pytest.mark.parametrize('dt', _TI_TYPES) # Metal backend doesn't support arg type other than 32-bit yet. @ti.archs_excluding(ti.metal) def test_type_assign_argument(dt): _test_type_assign_argument(dt) -@ti.parametrize('dt', _TI_64_TYPES) +@pytest.mark.parametrize('dt', _TI_64_TYPES) @ti.require(ti.extension.data64) @ti.all_archs def test_type_assign_argument64(dt): @@ -46,12 +46,12 @@ def func(): assert add[None] == x[None] + y[None] assert mul[None] == x[None] * y[None] -@ti.parametrize('dt', _TI_TYPES) +@pytest.mark.parametrize('dt', _TI_TYPES) @ti.all_archs def test_type_operator(dt): _test_type_operator(dt) -@ti.parametrize('dt', _TI_64_TYPES) +@pytest.mark.parametrize('dt', _TI_64_TYPES) @ti.require(ti.extension.data64) @ti.all_archs def test_type_operator64(dt): @@ -70,12 +70,12 @@ def func(i: ti.i32, j: ti.i32): assert x[i, j] == 3 -@ti.parametrize('dt', _TI_TYPES) +@pytest.mark.parametrize('dt', _TI_TYPES) @ti.all_archs def test_type_tensor(dt): _test_type_tensor(dt) -@ti.parametrize('dt', _TI_64_TYPES) +@pytest.mark.parametrize('dt', _TI_64_TYPES) @ti.require(ti.extension.data64) @ti.all_archs def test_type_tensor64(dt): @@ -103,7 +103,7 @@ def func(): else: assert c[None] == 2 ** n // 3 * 2 # does not overflow -@ti.parametrize('dt,n', [ +@pytest.mark.parametrize('dt,n', [ (ti.i8, 8), (ti.u8, 8), (ti.i16, 16), @@ -115,7 +115,7 @@ def func(): def test_overflow(dt, n): _test_overflow(dt, n) -@ti.parametrize('dt,n', [ +@pytest.mark.parametrize('dt,n', [ (ti.i64, 64), (ti.u64, 64), ])