Skip to content

Commit

Permalink
Use @functools.wrap, thanks to @archibate!
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye committed Feb 25, 2020
1 parent c06e021 commit 369c4ee
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 50 deletions.
52 changes: 10 additions & 42 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .transformer import TaichiSyntaxError
from .ndrange import ndrange, GroupedNDRange
from copy import deepcopy as _deepcopy
import functools
import os

core = taichi_lang_core
Expand Down Expand Up @@ -208,6 +209,13 @@ def _get_or_make_arch_checkers(kwargs):
def all_archs_with(**kwargs):
kwargs = _deepcopy(kwargs)
def decorator(test):
# @pytest.mark.parametrize decorator only knows about regular function args,
# without *args or **kwargs. By decorating with @functools.wraps, the
# signature of |test| is preserved, so that @ti.all_archs can be used after
# the parametrization decorator.
#
# Full discussion: https://github.com/pytest-dev/pytest/issues/6810
@functools.wraps(test)
def wrapped(*test_args, **test_kwargs):
import taichi as ti
can_run_on = test_kwargs.pop(
Expand Down Expand Up @@ -252,6 +260,7 @@ def archs_excluding(*excluded_archs, **kwargs):
excluded_archs = set(excluded_archs)

def decorator(test):
@functools.wraps(test)
def wrapped(*test_args, **test_kwargs):
def checker(arch): return arch not in excluded_archs
_get_or_make_arch_checkers(test_kwargs).register(checker)
Expand All @@ -275,6 +284,7 @@ def require(*exts):
assert all([isinstance(e, core.Extension) for e in exts])

def decorator(test):
@functools.wraps(test)
def wrapped(*test_args, **test_kwargs):
def checker(arch): return all([is_supported(arch, e) for e in exts])
_get_or_make_arch_checkers(test_kwargs).register(checker)
Expand Down Expand Up @@ -330,48 +340,6 @@ def func__(*args, **kwargs):
return decorator


# Poor man's parametrized test decorator
# The usage is identical to how @pytest.mark.parametrize is commonly used. For
# example:
#
# @ti.parametrize('foo', [1, 2, 3])
# @ti.all_archs
# def test_xx(foo):
# ...
#
# @ti.parametrize('foo,bar', [
# (1, 'a'),
# (2, 'b'),
# (3, 'c'),
# ])
# @ti.all_archs
# def test_yy(foo, bar):
# ...
def parametrize(argnames: str, argvalues):
# @pytest.mark.parametrize only works for canonical function args, and doesn't
# support *args or **kwargs. This makes it difficult to play along with other
# decorators like @ti.all_archs. As a result, we implement our own.
argnames = [s.strip() for s in argnames.split(',')]
def iterable(x):
try:
_ = iter(x)
return True
except:
return False

def decorator(test):
def wrapped(*test_args, **test_kwargs):
for vals in argvalues:
if isinstance(vals, str) or not iterable(vals):
vals = (vals, )
kwargs = {k: v for k, v in zip(argnames, vals)}
assert len(kwargs.keys() & test_kwargs.keys()) == 0
kwargs.update(test_kwargs)
test(*test_args, **kwargs)
return wrapped
return decorator


def complex_kernel(func):

def decorated(*args, **kwargs):
Expand Down
16 changes: 8 additions & 8 deletions tests/python/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ def func(value: dt):
func(3)
assert x[None] == 3

@ti.parametrize('dt', _TI_TYPES)
@pytest.mark.parametrize('dt', _TI_TYPES)
# Metal backend doesn't support arg type other than 32-bit yet.
@ti.archs_excluding(ti.metal)
def test_type_assign_argument(dt):
_test_type_assign_argument(dt)


@ti.parametrize('dt', _TI_64_TYPES)
@pytest.mark.parametrize('dt', _TI_64_TYPES)
@ti.require(ti.extension.data64)
@ti.all_archs
def test_type_assign_argument64(dt):
Expand All @@ -46,12 +46,12 @@ def func():
assert add[None] == x[None] + y[None]
assert mul[None] == x[None] * y[None]

@ti.parametrize('dt', _TI_TYPES)
@pytest.mark.parametrize('dt', _TI_TYPES)
@ti.all_archs
def test_type_operator(dt):
_test_type_operator(dt)

@ti.parametrize('dt', _TI_64_TYPES)
@pytest.mark.parametrize('dt', _TI_64_TYPES)
@ti.require(ti.extension.data64)
@ti.all_archs
def test_type_operator64(dt):
Expand All @@ -70,12 +70,12 @@ def func(i: ti.i32, j: ti.i32):
assert x[i, j] == 3


@ti.parametrize('dt', _TI_TYPES)
@pytest.mark.parametrize('dt', _TI_TYPES)
@ti.all_archs
def test_type_tensor(dt):
_test_type_tensor(dt)

@ti.parametrize('dt', _TI_64_TYPES)
@pytest.mark.parametrize('dt', _TI_64_TYPES)
@ti.require(ti.extension.data64)
@ti.all_archs
def test_type_tensor64(dt):
Expand Down Expand Up @@ -103,7 +103,7 @@ def func():
else:
assert c[None] == 2 ** n // 3 * 2 # does not overflow

@ti.parametrize('dt,n', [
@pytest.mark.parametrize('dt,n', [
(ti.i8, 8),
(ti.u8, 8),
(ti.i16, 16),
Expand All @@ -115,7 +115,7 @@ def func():
def test_overflow(dt, n):
_test_overflow(dt, n)

@ti.parametrize('dt,n', [
@pytest.mark.parametrize('dt,n', [
(ti.i64, 64),
(ti.u64, 64),
])
Expand Down

0 comments on commit 369c4ee

Please sign in to comment.