diff --git a/marray/__init__.py b/marray/__init__.py index d98ed8a..4e3a906 100644 --- a/marray/__init__.py +++ b/marray/__init__.py @@ -4,8 +4,8 @@ __version__ = "0.0.4" -import numpy as np # temporarily used in __repr__ and __str__ - +import types, sys +import dataclasses def masked_array(xp): """Returns a masked array namespace for an array API backend @@ -30,7 +30,8 @@ def __init__(self, data, mask=None): data = xp.asarray(getattr(data, '_data', data)) mask = (xp.zeros(data.shape, dtype=xp.bool) if mask is None else xp.asarray(mask, dtype=xp.bool)) - mask = xp.asarray(xp.broadcast_to(mask, data.shape), copy=True) + if mask.shape != data.shape: # avoid copy if possible + mask = xp.asarray(xp.broadcast_to(mask, data.shape), copy=True) self._data = data self._dtype = data.dtype self._device = data.device @@ -86,20 +87,20 @@ def __setitem__(self, key, other): self.mask[key] = getattr(other, 'mask', False) return self.data.__setitem__(key, getattr(other, 'data', other)) + def _data_mask_string(self, fun): + data_str = fun(self.data) + mask_str = fun(self.mask) + if len(data_str) + len(mask_str) <= 66: + return f"MaskedArray({data_str}, {mask_str})" + else: + return f"MaskedArray(\n {data_str},\n {mask_str}\n)" + ## Visualization ## def __repr__(self): - # temporary: fix for CuPy - # eventually: rewrite to avoid masked array - data = np.asarray(self.data) - mask = np.asarray(self.mask) - return np.ma.masked_array(data, mask).__repr__() + return self._data_mask_string(repr) def __str__(self): - # temporary: fix for CuPy - # eventually: rewrite to avoid masked array - data = np.asarray(self.data) - mask = np.asarray(self.mask) - return np.ma.masked_array(data, mask).__str__() + return self._data_mask_string(str) ## Linear Algebra Methods ## def __matmul__(self, other): @@ -125,6 +126,19 @@ def T(self): def mT(self): return MaskedArray(self.data.mT, self.mask.mT) + # dlpack + def __dlpack_device__(self): + return self.data.__dlpack_device__() + + def __dlpack__(self): + # really not sure how to define this + return self.data.__dlpack__() + + def to_device(self, device, /, *, stream=None): + self._data = self._data.to_device(device, stream=stream) + self._mask = self._mask.to_device(device, stream=stream) + + ## Methods ## # Methods that return the result of a unary operation as an array @@ -172,21 +186,20 @@ def fun(self, other, name=name, **kwargs): return self setattr(MaskedArray, name, fun) - # To be added - # __dlpack__, __dlpack_device__ - # to_device? - def info(x): xp = x._xp if xp.isdtype(x.dtype, 'integral'): return xp.iinfo(x.dtype) + elif xp.isdtype(x.dtype, 'bool'): + binfo = dataclasses.make_dataclass("binfo", ['min', 'max']) + return binfo(min=False, max=True) else: return xp.finfo(x.dtype) - class module: - pass + mod = types.ModuleType('mxp') + sys.modules['mxp'] = mod - mod = module() + mod.MaskedArray = MaskedArray ## Constants ## constant_names = ['e', 'inf', 'nan', 'newaxis', 'pi'] @@ -199,17 +212,20 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None): raise NotImplementedError() data = getattr(obj, 'data', obj) + data = xp.asarray(data, dtype=dtype, device=device, copy=copy) + mask = (getattr(obj, 'mask', xp.full(data.shape, False)) if mask is None else mask) + mask = xp.asarray(mask, dtype=xp.bool, device=device, copy=copy) - data = xp.asarray(data, dtype=dtype, device=device, copy=copy) - mask = xp.asarray(mask, dtype=dtype, device=device, copy=copy) return MaskedArray(data, mask=mask) mod.asarray = asarray creation_functions = ['arange', 'empty', 'empty_like', 'eye', 'from_dlpack', - 'full', 'full_like', 'linspace', 'meshgrid', 'ones', - 'ones_like', 'tril', 'triu', 'zeros', 'zeros_like'] + 'full', 'full_like', 'linspace', 'ones', 'ones_like', + 'zeros', 'zeros_like'] + # handled with array manipulation functions + creation_manip_functions = ['tril', 'triu', 'meshgrid'] for name in creation_functions: def fun(*args, name=name, **kwargs): data = getattr(xp, name)(*args, **kwargs) @@ -233,7 +249,7 @@ def fun(*args, name=name, **kwargs): elementwise_names = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', - 'bitwise_xor', 'ceil', 'clip', 'conj', 'copysign', 'cos', + 'bitwise_xor', 'ceil', 'conj', 'copysign', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'imag', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', @@ -252,32 +268,25 @@ def fun(*args, name=name, **kwargs): return MaskedArray(out, mask=xp.any(masks, axis=0)) setattr(mod, name, fun) - ## Indexing Functions - # To be written: - # take - - def xp_take_along_axis(arr, indices, axis): - # This is just for regular arrays; not masked arrays - arr = xp_swapaxes(arr, axis, -1) - indices = xp_swapaxes(indices, axis, -1) - - m = arr.shape[-1] - n = indices.shape[-1] - - shape = list(arr.shape) - shape.pop(-1) - shape = shape + [n,] - arr = xp.reshape(arr, (-1,)) - indices = xp.reshape(indices, (-1, n)) - - offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis] - indices = xp.reshape(offset + indices, (-1,)) + def clip(x, /, min=None, max=None): + args = [x, min, max] + masks = [arg.mask for arg in args if hasattr(arg, 'mask')] + masks = xp.broadcast_arrays(*masks) + mask = xp.any(masks, axis=0) + datas = [getattr(arg, 'data', arg) for arg in args] + data = xp.clip(datas[0], min=datas[1], max=datas[2]) + return MaskedArray(data, mask) + mod.clip = clip - out = arr[indices] - out = xp.reshape(out, shape) - return xp_swapaxes(out, axis, -1) - mod._xp_take_along_axis = xp_take_along_axis + ## Indexing Functions + def take(x, indices, /, *, axis=None): + indices_data = getattr(indices, 'data', indices) + indices_mask = getattr(indices, 'mask', False) + data = xp.take(x.data, indices_data, axis=axis) + mask = xp.take(x.mask, indices_data, axis=axis) | indices_mask + return MaskedArray(data, mask=mask) + mod.take = take ## Inspection ## # Included with dtype functions above @@ -308,8 +317,8 @@ def linalg_fun(x1, x2, /, **kwargs): mod.matrix_transpose = lambda x: x.mT ## Manipulation Functions ## - first_arg_arrays = {'broadcast_arrays', 'concat', 'stack'} - output_arrays = {'broadcast_arrays', 'unstack'} + first_arg_arrays = {'broadcast_arrays', 'concat', 'stack', 'meshgrid'} + output_arrays = {'broadcast_arrays', 'unstack', 'meshgrid'} def get_manip_fun(name): def manip_fun(x, *args, **kwargs): @@ -322,7 +331,7 @@ def manip_fun(x, *args, **kwargs): fun = getattr(xp, name) - if name == 'broadcast_arrays': + if name in {'broadcast_arrays', 'meshgrid'}: res = fun(*data, *args, **kwargs) mask = fun(*mask, *args, **kwargs) else: @@ -337,20 +346,24 @@ def manip_fun(x, *args, **kwargs): manip_names = ['broadcast_arrays', 'broadcast_to', 'concat', 'expand_dims', 'flip', 'moveaxis', 'permute_dims', 'repeat', 'reshape', 'roll', 'squeeze', 'stack', 'tile', 'unstack'] - for name in manip_names: + for name in manip_names + creation_manip_functions: setattr(mod, name, get_manip_fun(name)) mod.broadcast_arrays = lambda *arrays: get_manip_fun('broadcast_arrays')(arrays) - - # This is just for regular arrays; not masked arrays - def xp_swapaxes(arr, axis1, axis2): - axes = list(range(arr.ndim)) - axes[axis1], axes[axis2] = axes[axis2], axes[axis1] - return xp.permute_dims(arr, axes) - mod.xp_swapaxes = xp_swapaxes + mod.meshgrid = lambda *arrays, **kwargs: get_manip_fun('meshgrid')(arrays, **kwargs) ## Searching Functions - # To be added - # searchsorted + def searchsorted(x1, x2, /, *, side='left', sorter=None): + if sorter is not None: + x1 = take(x1, sorter) + + mask_count = xp.cumulative_sum(xp.astype(x1.mask, xp.int64)) + x1_compressed = x1.data[~x1.mask] + count = xp.zeros(x1_compressed.size+1, dtype=xp.int64) + count[:-1] = mask_count[~x1.mask] + count[-1] = count[-2] + i = xp.searchsorted(x1_compressed, x2.data, side=side) + j = i + xp.take(count, i) + return MaskedArray(j, mask=x2.mask) def nonzero(x, /): x = asarray(x) @@ -367,6 +380,7 @@ def where(condition, x1, x2, /): mask = condition.mask | x1.mask | x2.mask return MaskedArray(data, mask) + mod.searchsorted = searchsorted mod.nonzero = nonzero mod.where = where diff --git a/marray/tests/test_marray.py b/marray/tests/test_marray.py index be9a2ca..5db664b 100644 --- a/marray/tests/test_marray.py +++ b/marray/tests/test_marray.py @@ -65,8 +65,8 @@ def assert_comparison(res, ref, seed, xp, comparison, **kwargs): assert isinstance(res.mask, array_type) ref_mask = np.broadcast_to(ref.mask, ref.data.shape) try: - comparison(res.data[~res.mask], ref.data[~ref_mask], **kwargs) - comparison(res.mask, ref_mask, **kwargs) + comparison(res.data[~res.mask], ref.data[~ref_mask], strict=True, **kwargs) + comparison(res.mask, ref_mask, strict=True, **kwargs) except AssertionError as e: raise AssertionError(seed) from e @@ -129,17 +129,13 @@ def get_rtol(dtype, xp): 'bitwise_right_shift': lambda x, y: x.__rshift__(y), 'bitwise_xor': lambda x, y: x.__xor__(y)} -# __array_namespace__ -# __bool__ -# __complex__ -# __dlpack__ -# __dlpack_device__ -# __float__ -# __getitem__ -# __index__ -# __int__ -# __setitem__ -# __to_device__ + +scalar_conversions = {bool: True, int: 10, float: 1.5, complex: 1.5 + 2.5j} + +# tested in test_dlpack +# __dlpack__, __dlpack_device__, to_device +# tested in test_indexing +# __getitem__, __index__, __setitem__, comparison_binary = [lambda x, y: x < y, lambda x, y: x <= y, lambda x, y: x > y, lambda x, y: x >= y, lambda x, y: x == y , lambda x, y: x != y] @@ -182,7 +178,7 @@ def irshift(x, y): x >>= y 'logaddexp', 'logical_and', 'logical_or', 'logical_xor', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract'] -searching_array = ['argmax', 'argmin'] # NumPy masked array funcs not good references +searching_array = ['argmax', 'argmin'] statistical_array = ['cumulative_sum', 'max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] utility_array = ['all', 'any'] @@ -271,6 +267,57 @@ def test_bitwise_binary(f_name_fun, dtype, xp=np, seed=None): assert_equal(res, ref, seed) +@pytest.mark.parametrize('type_val', scalar_conversions.items()) +@pytest.mark.parametrize('mask', [False, True]) +def test_scalar_conversion(type_val, mask, xp=np): + mxp = marray.masked_array(xp) + type, val = type_val + x = mxp.asarray(val) + assert type(x) == val + assert isinstance(type(x), type) + + method = getattr(x, f"__{type.__name__}__") + assert method() == val + assert isinstance(method(), type) + + +def test_indexing(xp=strict): + # The implementations of `__getitem__` and `__setitem__` are trivial. + # This does not make them easy to test exhaustively, but it does make + # them easy to fix if a shortcoming is identified. Include a very basic + # test for now, and improve as needed. + mxp = marray.masked_array(xp) + x = mxp.asarray(xp.arange(3), mask=[False, True, False]) + + # Test `__setitem__`/`__getitem__` roundtrip + x[1] = 10 + assert x[1] == 10 + assert isinstance(x[1], type(x)) + + # Test `__setitem__`/`__getitem__` roundtrip with masked array as index + i = mxp.asarray(1, mask=True) + x[i.__index__()] = 20 + assert x[i.__index__()] == 20 + assert isinstance(x[i.__index__()], type(x)) + + # `__setitem__` can change mask + x[1] = mxp.asarray(30, mask=False) + assert x[1].data == 30 + assert x[1].mask == False + x[2] = mxp.asarray(40, mask=True) + assert x[2].data == 40 + assert x[2].mask == True + + +def test_dlpack(xp=strict, seed=None): + # This is a placeholder for a real test when there is a real implementation + mxp = marray.masked_array(xp) + marrays, _, seed = get_arrays(1, seed=seed) + assert isinstance(marrays[0].__dlpack__(), type(marrays[0].data.__dlpack__())) + assert marrays[0].__dlpack_device__() == marrays[0].data.__dlpack_device__() + marrays[0].to_device('cpu') + + @pytest.mark.parametrize("dtype", dtypes_integral + dtypes_real) @pytest.mark.parametrize("f", comparison_binary + comparison_methods_binary) def test_comparison_binary(f, dtype, seed=None): @@ -442,7 +489,7 @@ def test_elementwise_binary(f_name, xp=np, dtype='float64', seed=None): @pytest.mark.parametrize("keepdims", [False, True]) -@pytest.mark.parametrize("f_name", statistical_array + utility_array) +@pytest.mark.parametrize("f_name", statistical_array + utility_array + searching_array) def test_statistical_array(f_name, keepdims, xp=np, dtype='float64', seed=None): # TODO: confirm that result should never have mask? Only when all are masked? mxp = marray.masked_array(xp) @@ -458,25 +505,240 @@ def test_statistical_array(f_name, keepdims, xp=np, dtype='float64', seed=None): f2 = getattr(xp, f_name2) res = f(marrays[0], axis=axis, **kwargs) ref = f2(masked_arrays[0], axis=axis, **kwargs) - ref = np.ma.masked_array(ref.data, getattr(ref, 'mask', False)) + # `argmin`/`argmax` don't calculate mask correctly + ref_mask = np.all(masked_arrays[0].mask, axis=axis, **kwargs) + ref = np.ma.masked_array(ref.data, getattr(ref, 'mask', ref_mask)) assert_equal(res, ref, seed) + +# Test Creation functions +@pytest.mark.parametrize('f_name, args, kwargs', [ + # Try to pass options that change output compared to default + ('arange', (-1.5, 10, 2), dict(dtype=int)), + ('asarray', ([1, 2, 3],), dict(dtype=float, copy=True)), + ('empty', ((4, 3, 2),), dict(dtype=int)), + ('empty_like', (np.empty((4, 3, 2)),), dict(dtype=int)), + ('eye', (10, 11), dict(k=2, dtype=int)), + ('full', ((4, 3, 2), 5), dict(dtype=float)), + ('full_like', (np.empty((4, 3, 2)), 5.), dict(dtype=int)), + ('linspace', (1, 20, 100), dict(dtype=int, endpoint=False)), + ('ones', ((4, 3, 2),), dict(dtype=int)), + ('ones_like', (np.empty((4, 3, 2)),), dict(dtype=int)), + ('zeros', ((4, 3, 2),), dict(dtype=int)), + ('zeros_like', (np.empty((4, 3, 2)),), dict(dtype=int)), +]) +# Should `_like` functions inherit the mask of the argument? +def test_creation(f_name, args, kwargs, xp=np): + mxp = marray.masked_array(xp) + f_xp = getattr(xp, f_name) + f_mxp = getattr(mxp, f_name) + res = f_mxp(*args, **kwargs) + ref = f_xp(*args, **kwargs) + if f_name.startswith('empty'): + assert res.data.shape == ref.shape + else: + np.testing.assert_equal(res.data, ref, strict=True) + np.testing.assert_equal(res.mask, xp.full(ref.shape, False), strict=True) + + +@pytest.mark.parametrize('f_name', ['tril', 'triu']) +@pytest.mark.parametrize('dtype', dtypes_all) +def test_tri(f_name, dtype, seed=None, xp=np): + mxp = marray.masked_array(xp) + f_xp = getattr(xp, f_name) + f_mxp = getattr(mxp, f_name) + marrays, _, seed = get_arrays(1, ndim=(2, 4), dtype=dtype, seed=seed) + + res = f_mxp(marrays[0], k=1) + ref_data = f_xp(marrays[0].data, k=1) + ref_mask = f_xp(marrays[0].mask, k=1) + ref = np.ma.masked_array(ref_data, mask=ref_mask) + assert_equal(res, ref, seed) + + +@pytest.mark.parametrize('indexing', ['ij', 'xy']) +@pytest.mark.parametrize('dtype', dtypes_all) +def test_meshgrid(indexing, dtype, seed=None, xp=np): + mxp = marray.masked_array(xp) + marrays, _, seed = get_arrays(1, ndim=1, dtype=dtype, seed=seed) + + res = mxp.meshgrid(*marrays, indexing=indexing) + ref_data = xp.meshgrid([marray.data for marray in marrays], indexing=indexing) + ref_mask = xp.meshgrid([marray.mask for marray in marrays], indexing=indexing) + ref = [np.ma.masked_array(data, mask=mask) for data, mask in zip(ref_data, ref_mask)] + [assert_equal(res_array, ref_array, seed) for res_array, ref_array in zip(res, ref)] + + +@pytest.mark.parametrize("side", ['left', 'right']) +def test_searchsorted(side, xp=strict, seed=None): + mxp = marray.masked_array(xp) + + rng = np.random.default_rng(seed) + n = 20 + m = 10 + + x1 = rng.integers(10, size=n) + x1_mask = (rng.random(size=n) > 0.5) + x2 = rng.integers(-2, 12, size=m) + x2_mask = rng.random(size=m) > 0.5 + + x1 = mxp.asarray(x1, mask=x1_mask) + x2 = mxp.asarray(x2, mask=x2_mask) + + # Note that the output of `searchsorted` is the same whether + # a (valid) `sorter` is provided or the array is sorted to begin with + res = xp.searchsorted(x1.data, x2.data, side=side, sorter=xp.argsort(x1.data)) + ref = xp.searchsorted(xp.sort(x1.data), x2.data, side=side, sorter=None) + assert xp.all(res == ref) + + # This is true for `marray`, too + res = mxp.searchsorted(x1, x2, side=side, sorter=mxp.argsort(x1)) + x1 = mxp.sort(x1) + ref = mxp.searchsorted(x1, x2, side=side, sorter=None) + assert mxp.all(res == ref) + + # And the output satisfies the required properties: + for j in range(res.size): + i = res[j] + + if i.mask: + assert x2.mask[j] + continue + + i = i.__index__() + v = x2[j] + if side == 'left': + assert mxp.all(x1[:i] < v) and mxp.all(v <= x1[i:]) + else: + assert mxp.all(x1[:i] <= v) and mxp.all(v < x1[i:]) + + # Test Linear Algebra functions # Use Array API tests to test the following: # Creation Functions (same behavior but with all-False mask) -# Data Type Functions (only `astype` remains to be tested) -# Elementwise function `clip` (all others are tested above) # Indexing (same behavior as indexing data and mask separately) -# Manipulation functions (apply to data and mask separately) + + +@pytest.mark.parametrize('f_name, n_arrays, n_dims, kwargs', [ + # Try to pass options that change output compared to default + ('broadcast_arrays', 3, (3, 5), dict()), + ('broadcast_to', 1, (3, 5), dict(shape=None)), + ('concat', 3, (3, 5), dict(axis=1)), + ('expand_dims', 1, (3, 5), dict(axis=1)), + ('flip', 1, (3, 5), dict(axis=1)), + ('moveaxis', 1, (3, 5), dict(source=1, destination=2)), + ('permute_dims', 1, 3, dict(axes=[2, 0, 1])), + ('repeat', 1, (3, 5), dict(repeats=2, axis=1)), + ('reshape', 1, (3, 5), dict(shape=(-1,), copy=False)), + ('roll', 1, (3, 5), dict(shift=3, axis=1)), + ('squeeze', 1, (3, 5), dict(axis=1)), + ('stack', 3, (3, 5), dict(axis=1)), + ('tile', 1, (3, 5), dict(reps=(2, 3))), + ('unstack', 1, (3, 5), dict(axis=1)), +]) +def test_creation(f_name, n_arrays, n_dims, kwargs, seed=None, xp=np): + mxp = marray.masked_array(xp) + marrays, _, seed = get_arrays(n_arrays, ndim=n_dims, dtype=xp.float64, seed=seed) + if f_name in {'broadcast_to', 'squeeze'}: + original_shape = marrays[0].shape + marrays[0] = marrays[0][:, 0:1, ...] + if f_name == "broadcast_to": + kwargs['shape'] = original_shape + + f_mxp = getattr(mxp, f_name) + f_xp = getattr(xp, f_name) + + if f_name in {'concat', 'stack'}: + marrays = mxp.broadcast_arrays(*marrays) + res = (f_mxp(marrays, **kwargs)) + ref_data = f_xp([marray.data for marray in marrays], **kwargs) + ref_mask = f_xp([marray.mask for marray in marrays], **kwargs) + else: + res = f_mxp(*marrays, **kwargs) + ref_data = f_xp(*[marray.data for marray in marrays], **kwargs) + ref_mask = f_xp(*[marray.mask for marray in marrays], **kwargs) + + ref = np.ma.masked_array(ref_data, mask=ref_mask) + + if f_name in {'broadcast_arrays', 'unstack'}: + [assert_equal(res_i, ref_i, seed) for res_i, ref_i in zip(res, ref)] + else: + assert_equal(res, ref, seed) + + +@pytest.mark.filterwarnings('ignore::numpy.exceptions.ComplexWarning') +@pytest.mark.parametrize('dtype_in', dtypes_all) +@pytest.mark.parametrize('dtype_out', dtypes_all) +@pytest.mark.parametrize('copy', [False, True]) +def test_astype(dtype_in, dtype_out, copy, xp=np, seed=None): + mxp = marray.masked_array(xp) + marrays, masked_arrays, seed = get_arrays(1, dtype=dtype_in, seed=seed) + + res = mxp.astype(marrays[0], dtype_out, copy=copy) + if dtype_in == dtype_out: + if copy: + assert res.data is not marrays[0].data + assert res.mask is not marrays[0].mask + else: + assert res.data is marrays[0].data + assert res.mask is marrays[0].mask + ref = masked_arrays[0].astype(dtype_out, copy=copy) + assert_equal(res, ref, seed) + + +@pytest.mark.parametrize('dtype', dtypes_real) +def test_clip(dtype, xp=np, seed=None): + mxp = marray.masked_array(xp) + marrays, masked_arrays, seed = get_arrays(3, dtype=dtype, seed=seed) + res = mxp.clip(marrays[0], min=marrays[1], max=marrays[2]) + ref = np.ma.clip(*masked_arrays) + assert_equal(res, ref, seed) #? -# Searching functions - would test argmin/argmax with statistical functions, -# but NumPy masked version isn't correct # Set functions -# Sorting functions # __array_namespace__ + +@pytest.mark.parametrize("f_name", ['sort', 'argsort']) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("stable", [False]) # NumPy masked arrays don't support True +@pytest.mark.parametrize('dtype', dtypes_real) +def test_sorting(f_name, descending, stable, dtype, xp=strict, seed=None): + mxp = marray.masked_array(xp) + marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, seed=seed) + f_mxp = getattr(mxp, f_name) + f_xp = getattr(np.ma, f_name) + res = f_mxp(marrays[0], axis=-1, descending=descending, stable=stable) + if descending: + ref = f_xp(-masked_arrays[0], axis=-1, stable=stable) + ref = -ref if f_name=='sort' else ref + else: + ref = f_xp(masked_arrays[0], axis=-1, stable=stable) + + if f_name == 'sort': + assert_equal(res, np.ma.masked_array(ref), seed) + else: + # We can't just compare the indices because sometimes `np.ma.argsort` + # doesn't sort the masked elements the same way. Instead, we use the + # indices to sort the arrays, then compare the sorted masked arrays. + # (The difference is that we don't compare the masked values.) + i_sorted = np.asarray(res.data) + res_data = np.take_along_axis(marrays[0].data, i_sorted, axis=-1) + res_mask = np.take_along_axis(marrays[0].mask, i_sorted, axis=-1) + res = mxp.asarray(res_data, mask=res_mask) + ref_data = np.take_along_axis(masked_arrays[0].data, ref, axis=-1) + ref_mask = np.take_along_axis(masked_arrays[0].mask, ref, axis=-1) + ref = np.ma.MaskedArray(ref_data, mask=ref_mask) + assert_equal(res, ref, seed) + + +def test_import(xp=np): + mxp = marray.masked_array(xp) + from mxp import asarray + asarray(10, mask=True) + + def test_test(): seed = 149020664425889521094089537542803361848 # test_statistical_array('argmin', True, seed=seed) diff --git a/pyproject.toml b/pyproject.toml index d1e4d0e..1fadcdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,11 +14,11 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dynamic = ["version", "description"] -dependencies = ["numpy"] +dependencies = [] requires-python = ">=3.10" [project.optional-dependencies] -test = ["pytest"] +test = ["numpy", "pytest"] [project.urls] Home = "https://github.com/mdhaber/marray"