Skip to content

Commit

Permalink
TST: test sorting functions (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber authored Dec 1, 2024
1 parent 9962115 commit ee51465
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
30 changes: 0 additions & 30 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,29 +280,6 @@ def take(x, indices, /, *, axis=None):
return MaskedArray(data, mask=mask)
mod.take = take

def xp_take_along_axis(arr, indices, axis):
# This is just for regular arrays; not masked arrays
arr = xp_swapaxes(arr, axis, -1)
indices = xp_swapaxes(indices, axis, -1)

m = arr.shape[-1]
n = indices.shape[-1]

shape = list(arr.shape)
shape.pop(-1)
shape = shape + [n,]

arr = xp.reshape(arr, (-1,))
indices = xp.reshape(indices, (-1, n))

offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
indices = xp.reshape(offset + indices, (-1,))

out = arr[indices]
out = xp.reshape(out, shape)
return xp_swapaxes(out, axis, -1)
mod._xp_take_along_axis = xp_take_along_axis

## Inspection ##
# Included with dtype functions above

Expand Down Expand Up @@ -366,13 +343,6 @@ def manip_fun(x, *args, **kwargs):
mod.broadcast_arrays = lambda *arrays: get_manip_fun('broadcast_arrays')(arrays)
mod.meshgrid = lambda *arrays, **kwargs: get_manip_fun('meshgrid')(arrays, **kwargs)

# This is just for regular arrays; not masked arrays
def xp_swapaxes(arr, axis1, axis2):
axes = list(range(arr.ndim))
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
return xp.permute_dims(arr, axes)
mod.xp_swapaxes = xp_swapaxes

## Searching Functions
def searchsorted(x1, x2, /, *, side='left', sorter=None):
if sorter is not None:
Expand Down
34 changes: 33 additions & 1 deletion marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,41 @@ def test_clip(dtype, xp=np, seed=None):

#?
# Set functions
# Sorting functions
# __array_namespace__


@pytest.mark.parametrize("f_name", ['sort', 'argsort'])
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("stable", [False]) # NumPy masked arrays don't support True
@pytest.mark.parametrize('dtype', dtypes_real)
def test_sorting(f_name, descending, stable, dtype, xp=strict, seed=None):
mxp = marray.masked_array(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, seed=seed)
f_mxp = getattr(mxp, f_name)
f_xp = getattr(np.ma, f_name)
res = f_mxp(marrays[0], axis=-1, descending=descending, stable=stable)
if descending:
ref = f_xp(-masked_arrays[0], axis=-1, stable=stable)
ref = -ref if f_name=='sort' else ref
else:
ref = f_xp(masked_arrays[0], axis=-1, stable=stable)

if f_name == 'sort':
assert_equal(res, np.ma.masked_array(ref), seed)
else:
# We can't just compare the indices because sometimes `np.ma.argsort`
# doesn't sort the masked elements the same way. Instead, we use the
# indices to sort the arrays, then compare the sorted masked arrays.
# (The difference is that we don't compare the masked values.)
i_sorted = np.asarray(res.data)
res_data = np.take_along_axis(marrays[0].data, i_sorted, axis=-1)
res_mask = np.take_along_axis(marrays[0].mask, i_sorted, axis=-1)
res = mxp.asarray(res_data, mask=res_mask)
ref_data = np.take_along_axis(masked_arrays[0].data, ref, axis=-1)
ref_mask = np.take_along_axis(masked_arrays[0].mask, ref, axis=-1)
ref = np.ma.MaskedArray(ref_data, mask=ref_mask)
assert_equal(res, ref, seed)

def test_test():
seed = 149020664425889521094089537542803361848
# test_statistical_array('argmin', True, seed=seed)
Expand Down

0 comments on commit ee51465

Please sign in to comment.