Skip to content

Commit

Permalink
Add @ti.parametrize decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye committed Feb 25, 2020
1 parent 328bdbe commit c06e021
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 45 deletions.
42 changes: 42 additions & 0 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,48 @@ def func__(*args, **kwargs):
return decorator


# Poor man's parametrized test decorator
# The usage is identical to how @pytest.mark.parametrize is commonly used. For
# example:
#
# @ti.parametrize('foo', [1, 2, 3])
# @ti.all_archs
# def test_xx(foo):
# ...
#
# @ti.parametrize('foo,bar', [
# (1, 'a'),
# (2, 'b'),
# (3, 'c'),
# ])
# @ti.all_archs
# def test_yy(foo, bar):
# ...
def parametrize(argnames: str, argvalues):
# @pytest.mark.parametrize only works for canonical function args, and doesn't
# support *args or **kwargs. This makes it difficult to play along with other
# decorators like @ti.all_archs. As a result, we implement our own.
argnames = [s.strip() for s in argnames.split(',')]
def iterable(x):
try:
_ = iter(x)
return True
except:
return False

def decorator(test):
def wrapped(*test_args, **test_kwargs):
for vals in argvalues:
if isinstance(vals, str) or not iterable(vals):
vals = (vals, )
kwargs = {k: v for k, v in zip(argnames, vals)}
assert len(kwargs.keys() & test_kwargs.keys()) == 0
kwargs.update(test_kwargs)
test(*test_args, **kwargs)
return wrapped
return decorator


def complex_kernel(func):

def decorated(*args, **kwargs):
Expand Down
74 changes: 29 additions & 45 deletions tests/python/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,18 @@ def func(value: dt):
func(3)
assert x[None] == 3

@pytest.mark.parametrize('dt', _TI_TYPES)
@ti.parametrize('dt', _TI_TYPES)
# Metal backend doesn't support arg type other than 32-bit yet.
@ti.archs_excluding(ti.metal)
def test_type_assign_argument(dt):
# Metal backend doesn't support arg type other than 32-bit yet.
@ti.archs_excluding(ti.metal)
def run():
_test_type_assign_argument(dt)
run()
_test_type_assign_argument(dt)


@pytest.mark.parametrize('dt', _TI_64_TYPES)
@ti.parametrize('dt', _TI_64_TYPES)
@ti.require(ti.extension.data64)
@ti.all_archs
def test_type_assign_argument64(dt):
@ti.require(ti.extension.data64)
@ti.all_archs
def run():
_test_type_assign_argument(dt)
run()
_test_type_assign_argument(dt)

def _test_type_operator(dt):
x = ti.var(dt, shape=())
Expand All @@ -50,20 +46,16 @@ def func():
assert add[None] == x[None] + y[None]
assert mul[None] == x[None] * y[None]

@pytest.mark.parametrize('dt', _TI_TYPES)
@ti.parametrize('dt', _TI_TYPES)
@ti.all_archs
def test_type_operator(dt):
@ti.all_archs
def run():
_test_type_operator(dt)
run()
_test_type_operator(dt)

@pytest.mark.parametrize('dt', _TI_64_TYPES)
@ti.parametrize('dt', _TI_64_TYPES)
@ti.require(ti.extension.data64)
@ti.all_archs
def test_type_operator64(dt):
@ti.require(ti.extension.data64)
@ti.all_archs
def run():
_test_type_operator(dt)
run()
_test_type_operator(dt)

def _test_type_tensor(dt):
x = ti.var(dt, shape=(3, 2))
Expand All @@ -78,20 +70,16 @@ def func(i: ti.i32, j: ti.i32):
assert x[i, j] == 3


@pytest.mark.parametrize('dt', _TI_TYPES)
@ti.parametrize('dt', _TI_TYPES)
@ti.all_archs
def test_type_tensor(dt):
@ti.all_archs
def run():
_test_type_tensor(dt)
run()
_test_type_tensor(dt)

@pytest.mark.parametrize('dt', _TI_64_TYPES)
@ti.parametrize('dt', _TI_64_TYPES)
@ti.require(ti.extension.data64)
@ti.all_archs
def test_type_tensor64(dt):
@ti.require(ti.extension.data64)
@ti.all_archs
def run():
_test_type_tensor(dt)
run()
_test_type_tensor(dt)

def _test_overflow(dt, n):
a = ti.var(dt, shape=())
Expand All @@ -115,27 +103,23 @@ def func():
else:
assert c[None] == 2 ** n // 3 * 2 # does not overflow

@pytest.mark.parametrize('dt,n', [
@ti.parametrize('dt,n', [
(ti.i8, 8),
(ti.u8, 8),
(ti.i16, 16),
(ti.u16, 16),
(ti.i32, 32),
(ti.u32, 32),
])
@ti.all_archs
def test_overflow(dt, n):
@ti.all_archs
def run():
_test_overflow(dt, n)
run()
_test_overflow(dt, n)

@pytest.mark.parametrize('dt,n', [
@ti.parametrize('dt,n', [
(ti.i64, 64),
(ti.u64, 64),
])
@ti.require(ti.extension.data64)
@ti.all_archs
def test_overflow64(dt, n):
@ti.require(ti.extension.data64)
@ti.all_archs
def run():
_test_overflow(dt, n)
run()
_test_overflow(dt, n)

0 comments on commit c06e021

Please sign in to comment.