Skip to content

Commit

Permalink
[refactor] Remove legacy helper functions for testing (#3874)
Browse files Browse the repository at this point in the history
* Remove legacy helper functions for testing

* Remove must_throw()

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
strongoier and taichi-gardener authored Dec 24, 2021
1 parent 76ffb1b commit 0a7a7c4
Show file tree
Hide file tree
Showing 19 changed files with 130 additions and 305 deletions.
4 changes: 2 additions & 2 deletions benchmarks/fill_sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import taichi as ti


@ti.archs_support_sparse
@ti.test(require=ti.extension.sparse)
def benchmark_nested_struct():
a = ti.field(dtype=ti.f32)
N = 512
Expand All @@ -18,7 +18,7 @@ def fill():
return ti.benchmark(fill)


@ti.archs_support_sparse
@ti.test(require=ti.extension.sparse)
def benchmark_nested_struct_fill_and_clear():
a = ti.field(dtype=ti.f32)
N = 512
Expand Down
186 changes: 0 additions & 186 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from contextlib import contextmanager
from copy import deepcopy as _deepcopy
from urllib import request
from urllib.error import HTTPError

import taichi.lang.linalg_impl
import taichi.lang.meta
Expand Down Expand Up @@ -58,7 +57,6 @@
from taichi.profiler.kernelmetrics import (CuptiMetric, default_cupti_metrics,
get_predefined_cupti_metrics)
from taichi.snode.fields_builder import FieldsBuilder
from taichi.tools.util import get_traceback
from taichi.types.annotations import any_arr, ext_arr, template
from taichi.types.primitive_types import (f16, f32, f64, i32, i64,
integer_types, u32, u64)
Expand Down Expand Up @@ -1111,192 +1109,8 @@ def adaptive_arch_select(arch, enable_fallback, use_gles):
return cpu


class _ArchCheckers:
def __init__(self):
self._checkers = []

def register(self, c):
self._checkers.append(c)

def __call__(self, arch):
assert isinstance(arch, _ti_core.Arch)
return all([c(arch) for c in self._checkers])


_tests_arch_checkers_argname = '_tests_arch_checkers'


def _get_or_make_arch_checkers(kwargs):
_k = _tests_arch_checkers_argname
if _k not in kwargs:
kwargs[_k] = _ArchCheckers()
return kwargs[_k]


# test with all archs
def all_archs_with(**kwargs):
kwargs = _deepcopy(kwargs)

def decorator(test):
# @pytest.mark.parametrize decorator only knows about regular function args,
# without *args or **kwargs. By decorating with @functools.wraps, the
# signature of |test| is preserved, so that @ti.all_archs can be used after
# the parametrization decorator.
#
# Full discussion: https://github.com/pytest-dev/pytest/issues/6810
@functools.wraps(test)
def wrapped(*test_args, **test_kwargs):
can_run_on = test_kwargs.pop(_tests_arch_checkers_argname,
_ArchCheckers())
# Filter away archs that don't support 64-bit data.
fp = kwargs.get('default_fp', ti.f32)
ip = kwargs.get('default_ip', ti.i32)
if fp == ti.f64 or ip == ti.i64:
can_run_on.register(lambda arch: is_extension_supported(
arch, extension.data64))

for arch in ti._testing.expected_archs():
if can_run_on(arch):
print(f'Running test on arch={arch}')
ti.init(arch=arch, **kwargs)
test(*test_args, **test_kwargs)
else:
print(f'Skipped test on arch={arch}')

return wrapped

return decorator


# test with all archs
def all_archs(test):
return all_archs_with()(test)


# Exclude the given archs when running the tests
#
# Example usage:
#
# @ti.archs_excluding(ti.cuda, ti.metal)
# def test_xx():
# ...
#
# @ti.archs_excluding(ti.cuda, default_fp=ti.f64)
# def test_yy():
# ...
def archs_excluding(*excluded_archs, **kwargs):
# |kwargs| will be passed to all_archs_with(**kwargs)
assert all([isinstance(a, _ti_core.Arch) for a in excluded_archs])
excluded_archs = set(excluded_archs)

def decorator(test):
@functools.wraps(test)
def wrapped(*test_args, **test_kwargs):
def checker(arch):
return arch not in excluded_archs

_get_or_make_arch_checkers(test_kwargs).register(checker)
return all_archs_with(**kwargs)(test)(*test_args, **test_kwargs)

return wrapped

return decorator


# Specifies the extension features the archs are required to support in order
# to run the test.
#
# Example usage:
#
# @ti.require(ti.extension.data64)
# @ti.all_archs_with(default_fp=ti.f64)
# def test_xx():
# ...
def require(*exts):
# Because this decorator injects an arch checker, its usage must be followed
# with all_archs_with(), either directly or indirectly.
assert all([isinstance(e, _ti_core.Extension) for e in exts])

def decorator(test):
@functools.wraps(test)
def wrapped(*test_args, **test_kwargs):
def checker(arch):
return all([is_extension_supported(arch, e) for e in exts])

_get_or_make_arch_checkers(test_kwargs).register(checker)
test(*test_args, **test_kwargs)

return wrapped

return decorator


def archs_support_sparse(test, **kwargs):
wrapped = all_archs_with(**kwargs)(test)
return require(extension.sparse)(wrapped)


def torch_test(_func):
if ti.has_pytorch():
# OpenGL somehow crashes torch test without a reason, unforturnately
return ti.test(exclude=[opengl])(_func)
return lambda: None


def get_host_arch_list():
return [_ti_core.host_arch()]


# test with host arch only
def host_arch_only(_func):
@functools.wraps(_func)
def test(*args, **kwargs):
archs = [_ti_core.host_arch()]
for arch in archs:
ti.init(arch=arch)
_func(*args, **kwargs)

return test


def archs_with(archs, **init_kwags):
"""
Run the test on the given archs with the given init args.
Args:
archs: a list of Taichi archs
init_kwargs: kwargs passed to ti.init()
"""
def decorator(test):
@functools.wraps(test)
def wrapped(*test_args, **test_kwargs):
for arch in archs:
ti.init(arch=arch, **init_kwags)
test(*test_args, **test_kwargs)

return wrapped

return decorator


def must_throw(ex):
def decorator(_func):
def func__(*args, **kwargs):
finishes = False
try:
_func(*args, **kwargs)
finishes = True
except ex:
# throws. test passed
pass
except Exception as err_actual:
assert False, f'Exception {str(type(err_actual))} instead of {str(ex)} thrown'
if finishes:
assert False, f'Test successfully finished instead of throwing {str(ex)}'

return func__

return decorator


__all__ = [s for s in dir() if not s.startswith('_')]
4 changes: 2 additions & 2 deletions tests/python/test_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ def func():


@ti.test(arch=ti.get_host_arch_list())
@ti.must_throw(ti.TaichiCompilationError)
def test_static_assert_message():
x = 3

@ti.kernel
def func():
ti.static_assert(x == 4, "Oh, no!")

func()
with pytest.raises(ti.TaichiCompilationError):
func()


@ti.test(arch=ti.get_host_arch_list())
Expand Down
20 changes: 12 additions & 8 deletions tests/python/test_customized_grad.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

import taichi as ti


Expand Down Expand Up @@ -169,7 +171,6 @@ def backward(self, mul):


@ti.test()
@ti.must_throw(RuntimeError)
def test_decorated_primal_is_taichi_kernel():
x = ti.field(ti.f32)
total = ti.field(ti.f32)
Expand All @@ -185,16 +186,17 @@ def func(mul: ti.f32):
for i in range(n):
ti.atomic_add(total[None], x[i] * mul)

@ti.ad.grad_for(func)
def backward(mul):
func.grad(mul)
with pytest.raises(RuntimeError):

@ti.ad.grad_for(func)
def backward(mul):
func.grad(mul)

with ti.Tape(loss=total):
func(4)


@ti.test()
@ti.must_throw(RuntimeError)
def test_decorated_primal_missing_decorator():
x = ti.field(ti.f32)
total = ti.field(ti.f32)
Expand All @@ -214,9 +216,11 @@ def foward(mul):
func(mul)
func(mul)

@ti.ad.grad_for(func)
def backward(mul):
func.grad(mul)
with pytest.raises(RuntimeError):

@ti.ad.grad_for(func)
def backward(mul):
func.grad(mul)

with ti.Tape(loss=total):
func(4)
2 changes: 1 addition & 1 deletion tests/python/test_for_break.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def func():
assert x[i, j] == 100 * i + j


@ti.archs_excluding(ti.vulkan)
@ti.test(exclude=ti.vulkan)
def test_for_break3():
x = ti.field(ti.i32)
N, M = 8, 8
Expand Down
Loading

0 comments on commit 0a7a7c4

Please sign in to comment.