Skip to content

Commit

Permalink
TST: test manipulation functions (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber authored Dec 1, 2024
1 parent d1ac508 commit 9962115
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,53 @@ def test_searchsorted(side, xp=strict, seed=None):
# Use Array API tests to test the following:
# Creation Functions (same behavior but with all-False mask)
# Indexing (same behavior as indexing data and mask separately)
# Manipulation functions (apply to data and mask separately)


@pytest.mark.parametrize('f_name, n_arrays, n_dims, kwargs', [
# Try to pass options that change output compared to default
('broadcast_arrays', 3, (3, 5), dict()),
('broadcast_to', 1, (3, 5), dict(shape=None)),
('concat', 3, (3, 5), dict(axis=1)),
('expand_dims', 1, (3, 5), dict(axis=1)),
('flip', 1, (3, 5), dict(axis=1)),
('moveaxis', 1, (3, 5), dict(source=1, destination=2)),
('permute_dims', 1, 3, dict(axes=[2, 0, 1])),
('repeat', 1, (3, 5), dict(repeats=2, axis=1)),
('reshape', 1, (3, 5), dict(shape=(-1,), copy=False)),
('roll', 1, (3, 5), dict(shift=3, axis=1)),
('squeeze', 1, (3, 5), dict(axis=1)),
('stack', 3, (3, 5), dict(axis=1)),
('tile', 1, (3, 5), dict(reps=(2, 3))),
('unstack', 1, (3, 5), dict(axis=1)),
])
def test_creation(f_name, n_arrays, n_dims, kwargs, seed=None, xp=np):
mxp = marray.masked_array(xp)
marrays, _, seed = get_arrays(n_arrays, ndim=n_dims, dtype=xp.float64, seed=seed)
if f_name in {'broadcast_to', 'squeeze'}:
original_shape = marrays[0].shape
marrays[0] = marrays[0][:, 0:1, ...]
if f_name == "broadcast_to":
kwargs['shape'] = original_shape

f_mxp = getattr(mxp, f_name)
f_xp = getattr(xp, f_name)

if f_name in {'concat', 'stack'}:
marrays = mxp.broadcast_arrays(*marrays)
res = (f_mxp(marrays, **kwargs))
ref_data = f_xp([marray.data for marray in marrays], **kwargs)
ref_mask = f_xp([marray.mask for marray in marrays], **kwargs)
else:
res = f_mxp(*marrays, **kwargs)
ref_data = f_xp(*[marray.data for marray in marrays], **kwargs)
ref_mask = f_xp(*[marray.mask for marray in marrays], **kwargs)

ref = np.ma.masked_array(ref_data, mask=ref_mask)

if f_name in {'broadcast_arrays', 'unstack'}:
[assert_equal(res_i, ref_i, seed) for res_i, ref_i in zip(res, ref)]
else:
assert_equal(res, ref, seed)


@pytest.mark.filterwarnings('ignore::numpy.exceptions.ComplexWarning')
Expand Down

0 comments on commit 9962115

Please sign in to comment.