Skip to content

Commit

Permalink
TST: test multiple backends, all dtypes (#29)
Browse files Browse the repository at this point in the history
* TST: test multiple backends, all dtypes
* TST: loosen tolerances
  • Loading branch information
mdhaber authored Dec 1, 2024
1 parent fd84308 commit c9c8388
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 32 deletions.
112 changes: 80 additions & 32 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import operator
import functools
import itertools
import operator

import array_api_strict as strict
import numpy as np
import pytest
import array_api_strict as strict

import marray

xps = [np, strict]
dtypes_boolean = ['bool']
dtypes_integral = ['uint8', 'uint16', 'uint32', 'uint64', 'int8', 'int16', 'int32', 'int64']
dtypes_uint = ['uint8', 'uint16', 'uint32', 'uint64', ]
dtypes_int = ['int8', 'int16', 'int32', 'int64']
dtypes_real = ['float32', 'float64']
dtypes_complex = ['complex64', 'complex128']
dtypes_integral = dtypes_uint + dtypes_int
dtypes_numeric = dtypes_integral + dtypes_real + dtypes_complex
dtypes_all = dtypes_boolean + dtypes_integral + dtypes_real + dtypes_complex


Expand All @@ -33,7 +39,8 @@ def get_arrays(n_arrays, *, ndim=(1, 4), dtype='float64', xp=np, seed=None):
if dtype == 'bool':
data = data > 0
else:
data = data.astype(dtype)
# multiply by 10 to get some variety in integers
data = (data*10).astype(dtype)

datas.append(data)
# for now, make masks same shape as array
Expand All @@ -51,7 +58,11 @@ def get_arrays(n_arrays, *, ndim=(1, 4), dtype='float64', xp=np, seed=None):
return marrays, masked_arrays, entropy


def assert_comparison(res, ref, seed, comparison, **kwargs):
def assert_comparison(res, ref, seed, xp, comparison, **kwargs):
if xp is not None:
array_type = type(xp.asarray(1.))
assert isinstance(res.data, array_type)
assert isinstance(res.mask, array_type)
ref_mask = np.broadcast_to(ref.mask, ref.data.shape)
try:
comparison(res.data[~res.mask], ref.data[~ref_mask], strict=True, **kwargs)
Expand All @@ -60,12 +71,37 @@ def assert_comparison(res, ref, seed, comparison, **kwargs):
raise AssertionError(seed) from e


def assert_equal(res, ref, seed=None, **kwargs):
return assert_comparison(res, ref, seed, np.testing.assert_equal, **kwargs)
def assert_equal(res, ref, seed, xp=None, **kwargs):
return assert_comparison(res, ref, seed, xp, np.testing.assert_equal, **kwargs)


def assert_allclose(res, ref, seed, **kwargs):
return assert_comparison(res, ref, seed, np.testing.assert_allclose, **kwargs)
def assert_allclose(res, ref, seed, xp=None, **kwargs):
return assert_comparison(res, ref, seed, xp, np.testing.assert_allclose, **kwargs)


def pass_exceptions(allowed=[]):
def outer(f):
@functools.wraps(f)
def inner(*args, seed=None, **kwargs):
try:
return f(*args, seed=seed, **kwargs)
except (ValueError, TypeError) as e:
for message in allowed:
if str(e).startswith(message):
return
else:
raise AssertionError(seed) from e
return inner
return outer


def get_rtol(dtype, xp):
if isinstance(dtype, str):
dtype = getattr(xp, dtype)
if xp.isdtype(dtype, ('real floating', 'complex floating')):
return xp.finfo(dtype).eps**0.5
else:
return 0


arithmetic_unary = [lambda x: +x, lambda x: -x, abs]
Expand Down Expand Up @@ -148,41 +184,53 @@ def irshift(x, y): x >>= y
utility_array = ['all', 'any']


@pytest.mark.parametrize("f", arithmetic_unary + arithmetic_methods_unary)
@pytest.mark.parametrize('dtype', dtypes_real)
def test_arithmetic_unary(f, dtype, seed=None):
marrays, masked_arrays, seed = get_arrays(1, seed=seed)
@pytest.mark.parametrize("f", arithmetic_unary[:1] + arithmetic_methods_unary)
@pytest.mark.parametrize('dtype', dtypes_numeric)
@pytest.mark.parametrize('xp', xps)
def test_arithmetic_unary(f, dtype, xp, seed=None):
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
res = f(marrays[0])
ref = f(masked_arrays[0])
assert_equal(res, ref, seed)
assert_equal(res, ref, seed=seed, xp=xp)


arithetic_binary_exceptions = [
"Integers to negative integer powers are not allowed.",
"Only floating-point dtypes are allowed in __truediv__",
"ufunc 'floor_divide' not supported for the input types",
"ufunc 'remainder' not supported for the input types,",
"Only real numeric dtypes are allowed in __floordiv__",
"Only real numeric dtypes are allowed in __mod__"
]


@pytest.mark.parametrize("f", arithmetic_binary + arithmetic_methods_binary)
def test_arithmetic_binary(f, seed=None):
marrays, masked_arrays, seed = get_arrays(2, seed=seed)
@pytest.mark.parametrize('dtype', dtypes_numeric)
@pytest.mark.parametrize('xp', xps)
@pass_exceptions(allowed=arithetic_binary_exceptions)
def test_arithmetic_binary(f, dtype, xp, seed=None):
marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed)
res = f(marrays[0], marrays[1])
ref_data = f(masked_arrays[0].data, masked_arrays[1].data)
ref_mask = masked_arrays[0].mask | masked_arrays[1].mask
ref = np.ma.masked_array(ref_data, mask=ref_mask)
assert_equal(res, ref, seed)
assert_equal(res, ref, seed=seed, xp=xp)


@pytest.mark.parametrize('xp', xps)
@pytest.mark.parametrize("f", array_binary + array_methods_binary)
def test_array_binary(f, seed=None):
marrays, masked_arrays, seed = get_arrays(1, seed=seed)
if marrays[0].ndim < 2:
with pytest.raises(ValueError, match="undefined"):
f(marrays[0], marrays[0].mT)
else:
res = f(marrays[0], marrays[0].mT)

x = masked_arrays[0].data
mask = masked_arrays[0].mask
x[mask] = 0
data = f(x, x.mT)
mask = ~f(~mask, ~mask.mT)
ref = np.ma.masked_array(data, mask=mask)
assert_allclose(res, ref, seed)
@pytest.mark.parametrize('dtype', dtypes_all)
@pass_exceptions(allowed=["Only numeric dtypes are allowed in matmul"])
def test_array_binary(f, dtype, xp, seed=None):
marrays, masked_arrays, seed = get_arrays(1, ndim=(2, 4), xp=xp, dtype=dtype, seed=seed)
res = f(marrays[0], marrays[0].mT)
x = masked_arrays[0].data
mask = masked_arrays[0].mask
x[mask] = 0
data = f(x, x.mT)
mask = ~f(~mask, ~mask.mT)
ref = np.ma.masked_array(data, mask=mask)
assert_allclose(res, ref, seed=seed, xp=xp, rtol=get_rtol(dtype, xp))


@pytest.mark.parametrize("dtype", dtypes_integral + dtypes_boolean)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ filterwarnings = [
"error:::marray.*",
"ignore:invalid value encountered:RuntimeWarning",
"ignore:divide by zero encountered:RuntimeWarning",
"ignore:overflow encountered:RuntimeWarning",
]

0 comments on commit c9c8388

Please sign in to comment.