Skip to content

Commit

Permalink
TST: test/fix array creation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Nov 24, 2024
1 parent 44d46ed commit f8afbb1
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 13 deletions.
19 changes: 11 additions & 8 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,19 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None):
raise NotImplementedError()

data = getattr(obj, 'data', obj)
data = xp.asarray(data, dtype=dtype, device=device, copy=copy)

mask = (getattr(obj, 'mask', xp.full(data.shape, False))
if mask is None else mask)

data = xp.asarray(data, dtype=dtype, device=device, copy=copy)
mask = xp.asarray(mask, dtype=dtype, device=device, copy=copy)
return MaskedArray(data, mask=mask)
mod.asarray = asarray

creation_functions = ['arange', 'empty', 'empty_like', 'eye', 'from_dlpack',
'full', 'full_like', 'linspace', 'meshgrid', 'ones',
'ones_like', 'tril', 'triu', 'zeros', 'zeros_like']
'full', 'full_like', 'linspace', 'ones', 'ones_like',
'zeros', 'zeros_like']
# handled with array manipulation functions
creation_manip_functions = ['tril', 'triu', 'meshgrid']
for name in creation_functions:
def fun(*args, name=name, **kwargs):
data = getattr(xp, name)(*args, **kwargs)
Expand Down Expand Up @@ -308,8 +310,8 @@ def linalg_fun(x1, x2, /, **kwargs):
mod.matrix_transpose = lambda x: x.mT

## Manipulation Functions ##
first_arg_arrays = {'broadcast_arrays', 'concat', 'stack'}
output_arrays = {'broadcast_arrays', 'unstack'}
first_arg_arrays = {'broadcast_arrays', 'concat', 'stack', 'meshgrid'}
output_arrays = {'broadcast_arrays', 'unstack', 'meshgrid'}

def get_manip_fun(name):
def manip_fun(x, *args, **kwargs):
Expand All @@ -322,7 +324,7 @@ def manip_fun(x, *args, **kwargs):

fun = getattr(xp, name)

if name == 'broadcast_arrays':
if name in {'broadcast_arrays', 'meshgrid'}:
res = fun(*data, *args, **kwargs)
mask = fun(*mask, *args, **kwargs)
else:
Expand All @@ -337,9 +339,10 @@ def manip_fun(x, *args, **kwargs):
manip_names = ['broadcast_arrays', 'broadcast_to', 'concat', 'expand_dims',
'flip', 'moveaxis', 'permute_dims', 'repeat', 'reshape',
'roll', 'squeeze', 'stack', 'tile', 'unstack']
for name in manip_names:
for name in manip_names + creation_manip_functions:
setattr(mod, name, get_manip_fun(name))
mod.broadcast_arrays = lambda *arrays: get_manip_fun('broadcast_arrays')(arrays)
mod.meshgrid = lambda *arrays, **kwargs: get_manip_fun('meshgrid')(arrays, **kwargs)

# This is just for regular arrays; not masked arrays
def xp_swapaxes(arr, axis1, axis2):
Expand Down
66 changes: 61 additions & 5 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
dtypes_all = dtypes_boolean + dtypes_integral + dtypes_real + dtypes_complex


def get_arrays(n_arrays, *, dtype='float64', xp=np, seed=None):
def get_arrays(n_arrays, *, ndim=(1, 4), dtype='float64', xp=np, seed=None):
xpm = marray.masked_array(xp)

entropy = np.random.SeedSequence(seed).entropy
rng = np.random.default_rng(entropy)

ndim = rng.integers(1, 4)
ndim = rng.integers(*ndim) if isinstance(ndim, tuple) else ndim
shape = rng.integers(1, 20, size=ndim)

datas = []
Expand Down Expand Up @@ -60,7 +60,7 @@ def assert_comparison(res, ref, seed, comparison, **kwargs):
raise AssertionError(seed) from e


def assert_equal(res, ref, seed, **kwargs):
def assert_equal(res, ref, seed=None, **kwargs):
return assert_comparison(res, ref, seed, np.testing.assert_equal, **kwargs)


Expand Down Expand Up @@ -413,10 +413,66 @@ def test_statistical_array(f_name, keepdims, xp=np, dtype='float64', seed=None):
ref = np.ma.masked_array(ref.data, getattr(ref, 'mask', False))
assert_equal(res, ref, seed)

# Test Linear Algebra functions

# Test Creation functions
@pytest.mark.parametrize('f_name, args, kwargs', [
# Try to pass options that change output compared to default
('arange', (-1.5, 10, 2), dict(dtype=int)),
('asarray', ([1, 2, 3],), dict(dtype=float, copy=True)),
('empty', ((4, 3, 2),), dict(dtype=int)),
('empty_like', (np.empty((4, 3, 2)),), dict(dtype=int)),
('eye', (10, 11), dict(k=2, dtype=int)),
('full', ((4, 3, 2), 5), dict(dtype=float)),
('full_like', (np.empty((4, 3, 2)), 5.), dict(dtype=int)),
('linspace', (1, 20, 100), dict(dtype=int, endpoint=False)),
('ones', ((4, 3, 2),), dict(dtype=int)),
('ones_like', (np.empty((4, 3, 2)),), dict(dtype=int)),
('zeros', ((4, 3, 2),), dict(dtype=int)),
('zeros_like', (np.empty((4, 3, 2)),), dict(dtype=int)),
])
# Should `_like` functions inherit the mask of the argument?
def test_creation(f_name, args, kwargs, xp=np):
mxp = marray.masked_array(xp)
f_xp = getattr(xp, f_name)
f_mxp = getattr(mxp, f_name)
res = f_mxp(*args, **kwargs)
ref = f_xp(*args, **kwargs)
if f_name.startswith('empty'):
assert res.data.shape == ref.shape
else:
np.testing.assert_equal(res.data, ref, strict=True)
np.testing.assert_equal(res.mask, xp.full(ref.shape, False), strict=True)


@pytest.mark.parametrize('f_name', ['tril', 'triu'])
@pytest.mark.parametrize('dtype', dtypes_all)
def test_tri(f_name, dtype, seed=None, xp=np):
mxp = marray.masked_array(xp)
f_xp = getattr(xp, f_name)
f_mxp = getattr(mxp, f_name)
marrays, _, seed = get_arrays(1, ndim=(2, 4), dtype=dtype, seed=seed)

res = f_mxp(marrays[0], k=1)
ref_data = f_xp(marrays[0].data, k=1)
ref_mask = f_xp(marrays[0].mask, k=1)
ref = np.ma.masked_array(ref_data, mask=ref_mask)
assert_equal(res, ref, seed)


@pytest.mark.parametrize('indexing', ['ij', 'xy'])
@pytest.mark.parametrize('dtype', dtypes_all)
def test_meshgrid(indexing, dtype, seed=None, xp=np):
mxp = marray.masked_array(xp)
marrays, _, seed = get_arrays(1, ndim=1, dtype=dtype, seed=seed)

res = mxp.meshgrid(*marrays, indexing=indexing)
ref_data = xp.meshgrid([marray.data for marray in marrays], indexing=indexing)
ref_mask = xp.meshgrid([marray.mask for marray in marrays], indexing=indexing)
ref = [np.ma.masked_array(data, mask=mask) for data, mask in zip(ref_data, ref_mask)]
[assert_equal(res_array, ref_array, seed) for res_array, ref_array in zip(res, ref)]


# Use Array API tests to test the following:
# Creation Functions (same behavior but with all-False mask)
# Data Type Functions (only `astype` remains to be tested)
# Elementwise function `clip` (all others are tested above)
# Indexing (same behavior as indexing data and mask separately)
Expand Down

0 comments on commit f8afbb1

Please sign in to comment.