diff --git a/marray/__init__.py b/marray/__init__.py index 42987b0..5ad4f71 100644 --- a/marray/__init__.py +++ b/marray/__init__.py @@ -7,24 +7,22 @@ import types, sys import dataclasses -def masked_array(xp): - """Returns a masked array namespace for an array API backend +def get_namespace(xp): + """Returns a masked array namespace for an Array API Standard compatible backend Examples -------- - >>> from scipy._lib.array_api_compat import numpy as xp - >>> from scipy.stats import masked_array - >>> ma = masked_array(xp) - >>> A = ma.eye(3) + >>> import numpy as xp + >>> from marray import get_namespace + >>> mxp = get_namespace(xp) + >>> A = mxp.eye(3) >>> A.mask[0, ...] = True - >>> x = ma.asarray([1, 2, 3], mask=[False, False, True]) + >>> x = mxp.asarray([1, 2, 3], mask=[False, False, True]) >>> A @ x - masked_array(data=[--, 2.0, 0.0], - mask=[ True, False, False], - fill_value=1e+20) + MArray(array([0., 2., 0.]), array([ True, False, False])) """ - class MaskedArray: + class MArray: def __init__(self, data, mask=None): data = xp.asarray(getattr(data, '_data', data)) @@ -80,7 +78,7 @@ def call_super_method(self, method_name, *args, **kwargs): ## Indexing ## def __getitem__(self, key): - return MaskedArray(self.data[key], self.mask[key]) + return MArray(self.data[key], self.mask[key]) def __setitem__(self, key, other): self.mask[key] = getattr(other, 'mask', False) @@ -90,9 +88,9 @@ def _data_mask_string(self, fun): data_str = fun(self.data) mask_str = fun(self.mask) if len(data_str) + len(mask_str) <= 66: - return f"MaskedArray({data_str}, {mask_str})" + return f"MArray({data_str}, {mask_str})" else: - return f"MaskedArray(\n {data_str},\n {mask_str}\n)" + return f"MArray(\n {data_str},\n {mask_str}\n)" ## Visualization ## def __repr__(self): @@ -112,18 +110,18 @@ def __imatmul__(self, other): return def __rmatmul__(self, other): - other = MaskedArray(other) + other = MArray(other) return mod.matmul(other, self) ## Attributes ## @property def T(self): - return MaskedArray(self.data.T, self.mask.T) + return MArray(self.data.T, self.mask.T) @property def mT(self): - return MaskedArray(self.data.mT, self.mask.mT) + return MArray(self.data.mT, self.mask.mT) # dlpack def __dlpack_device__(self): @@ -146,15 +144,15 @@ def to_device(self, device, /, *, stream=None): for name in unary_names: def fun(self, name=name): data = self.call_super_method(name) - return MaskedArray(data, self.mask) - setattr(MaskedArray, name, fun) + return MArray(data, self.mask) + setattr(MArray, name, fun) # Methods that return the result of a unary operation as a Python scalar unary_names_py = ['__bool__', '__complex__', '__float__', '__index__', '__int__'] for name in unary_names_py: def fun(self, name=name): return self.call_super_method(name) - setattr(MaskedArray, name, fun) + setattr(MArray, name, fun) # Methods that return the result of an elementwise binary operation binary_names = ['__add__', '__sub__', '__and__', '__eq__', '__ge__', '__gt__', @@ -169,8 +167,8 @@ def fun(self, name=name): def fun(self, other, name=name): mask = (self.mask | other.mask) if hasattr(other, 'mask') else self.mask data = self.call_super_method(name, other) - return MaskedArray(data, mask) - setattr(MaskedArray, name, fun) + return MArray(data, mask) + setattr(MArray, name, fun) # In-place methods desired_names = ['__iadd__', '__iand__', '__ifloordiv__', '__ilshift__', @@ -183,7 +181,7 @@ def fun(self, other, name=name, **kwargs): self.mask.__ior__(other.mask) self.call_super_method(name, other) return self - setattr(MaskedArray, name, fun) + setattr(MArray, name, fun) def info(x): xp = x._xp @@ -198,7 +196,7 @@ def info(x): mod = types.ModuleType('mxp') sys.modules['mxp'] = mod - mod.MaskedArray = MaskedArray + mod.MArray = MArray ## Constants ## constant_names = ['e', 'inf', 'nan', 'newaxis', 'pi'] @@ -217,7 +215,7 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None): if mask is None else mask) mask = xp.asarray(mask, dtype=xp.bool, device=device, copy=copy) - return MaskedArray(data, mask=mask) + return MArray(data, mask=mask) mod.asarray = asarray creation_functions = ['arange', 'empty', 'empty_like', 'eye', 'from_dlpack', @@ -228,7 +226,7 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None): for name in creation_functions: def fun(*args, name=name, **kwargs): data = getattr(xp, name)(*args, **kwargs) - return MaskedArray(data) + return MArray(data) setattr(mod, name, fun) ## Data Type Functions and Data Types ## @@ -246,7 +244,7 @@ def astype(x, dtype, /, *, copy=True, device=None): return x data = xp.astype(x.data, dtype, copy=copy, device=device) mask = xp.astype(x.mask, xp.bool, copy=copy, device=device) - return MaskedArray(data, mask=mask) + return MArray(data, mask=mask) mod.astype = astype ## Elementwise Functions ## @@ -269,7 +267,7 @@ def fun(*args, name=name, **kwargs): masks = xp.broadcast_arrays(*masks) args = [getattr(arg, 'data', arg) for arg in args] out = getattr(xp, name)(*args, **kwargs) - return MaskedArray(out, mask=xp.any(xp.stack(masks), axis=0)) + return MArray(out, mask=xp.any(xp.stack(masks), axis=0)) setattr(mod, name, fun) @@ -280,7 +278,7 @@ def clip(x, /, min=None, max=None): mask = xp.any(xp.stack(masks), axis=0) datas = [getattr(arg, 'data', arg) for arg in args] data = xp.clip(datas[0], min=datas[1], max=datas[2]) - return MaskedArray(data, mask) + return MArray(data, mask) mod.clip = clip ## Indexing Functions @@ -289,7 +287,7 @@ def take(x, indices, /, *, axis=None): indices_mask = getattr(indices, 'mask', False) data = xp.take(x.data, indices_data, axis=axis) mask = xp.take(x.mask, indices_data, axis=axis) | indices_mask - return MaskedArray(data, mask=mask) + return MArray(data, mask=mask) mod.take = take ## Inspection ## @@ -311,7 +309,7 @@ def linalg_fun(x1, x2, /, **kwargs): mask = fun(xp.astype(~x1.mask, xp.uint64), xp.astype(~x2.mask, xp.uint64)) mask = ~xp.astype(mask, xp.bool) - return MaskedArray(data, mask) + return MArray(data, mask) return linalg_fun linalg_names = ['matmul', 'tensordot', 'vecdot'] @@ -342,8 +340,8 @@ def manip_fun(x, *args, **kwargs): res = fun(data, *args, **kwargs) mask = fun(mask, *args, **kwargs) - out = (MaskedArray(res, mask) if name not in output_arrays - else [MaskedArray(resi, maski) for resi, maski in zip(res, mask)]) + out = (MArray(res, mask) if name not in output_arrays + else [MArray(resi, maski) for resi, maski in zip(res, mask)]) return out return manip_fun @@ -367,14 +365,14 @@ def searchsorted(x1, x2, /, *, side='left', sorter=None): count[-1] = count[-2] i = xp.searchsorted(x1_compressed, x2.data, side=side) j = i + xp.take(count, i) - return MaskedArray(j, mask=x2.mask) + return MArray(j, mask=x2.mask) def nonzero(x, /): x = asarray(x) data = xp.asarray(x.data, copy=True) data[x.mask] = 0 res = xp.nonzero(data) - return tuple(MaskedArray(resi) for resi in res) + return tuple(MArray(resi) for resi in res) def where(condition, x1, x2, /): condition = asarray(condition) @@ -382,7 +380,7 @@ def where(condition, x1, x2, /): x2 = asarray(x2) data = xp.where(condition.data, x1.data, x2.data) mask = condition.mask | x1.mask | x2.mask - return MaskedArray(data, mask) + return MArray(data, mask) mod.searchsorted = searchsorted mod.nonzero = nonzero @@ -403,8 +401,8 @@ def set_fun(x, /): fun = getattr(xp, name) res = fun(data) # this sort of works but could be refined - return (MaskedArray(res, res==x._sentinel) if name=='unique_values' - else tuple(MaskedArray(resi, resi==x._sentinel) for resi in res)) + return (MArray(res, res==x._sentinel) if name=='unique_values' + else tuple(MArray(resi, resi==x._sentinel) for resi in res)) return set_fun unique_names = ['unique_values', 'unique_counts', 'unique_inverse', 'unique_all'] @@ -422,7 +420,7 @@ def sort_fun(x, /, *, axis=-1, descending=False, stable=True): kwargs = {'descending': True} if descending else {} res = fun(data, axis=axis, stable=stable, **kwargs) mask = (res == sentinel) if name=='sort' else None - return MaskedArray(res, mask) + return MArray(res, mask) return sort_fun sort_names = ['sort', 'argsort'] @@ -446,7 +444,7 @@ def statistical_fun(x, *args, axis=None, name=name, **kwargs): fun = getattr(xp, name) res = fun(data, *args, axis=axis, **kwargs) mask = xp.all(x.mask, axis=axis, keepdims=kwargs.get('keepdims', False)) - return MaskedArray(res, mask=mask) + return MArray(res, mask=mask) return statistical_fun def count(x, axis=None, keepdims=False): @@ -459,7 +457,7 @@ def cumulative_sum(x, *args, **kwargs): data = xp.asarray(x.data, copy=True) data[x.mask] = 0 res = xp.cumulative_sum(data, *args, **kwargs) - return MaskedArray(res, x.mask) + return MArray(res, x.mask) def mean(x, axis=None, keepdims=False): s = mod.sum(x, axis=axis, keepdims=keepdims) diff --git a/marray/tests/test_marray.py b/marray/tests/test_marray.py index 77059ee..757ce38 100644 --- a/marray/tests/test_marray.py +++ b/marray/tests/test_marray.py @@ -20,7 +20,7 @@ def get_arrays(n_arrays, *, dtype, xp, ndim=(1, 4), seed=None): - xpm = marray.masked_array(xp) + xpm = marray.get_namespace(xp) entropy = np.random.SeedSequence(seed).entropy rng = np.random.default_rng(entropy) @@ -238,7 +238,7 @@ def test_array_binary(f, dtype, xp, seed=None): @pytest.mark.parametrize('xp', xps) def test_bitwise_unary(f_name_fun, dtype, xp, seed=None): f_name, f = f_name_fun - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) res = f(~marrays[0]) @@ -258,7 +258,7 @@ def test_bitwise_unary(f_name_fun, dtype, xp, seed=None): "Only integer dtypes are allowed in "]) def test_bitwise_binary(f_name_fun, dtype, xp, seed=None): f_name, f = f_name_fun - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed) res = f(marrays[0], marrays[1]) @@ -274,7 +274,7 @@ def test_bitwise_binary(f_name_fun, dtype, xp, seed=None): @pytest.mark.parametrize('mask', [False, True]) @pytest.mark.parametrize('xp', xps) def test_scalar_conversion(type_val, mask, xp): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) type, val = type_val x = mxp.asarray(val) assert type(x) == val @@ -291,7 +291,7 @@ def test_indexing(xp): # This does not make them easy to test exhaustively, but it does make # them easy to fix if a shortcoming is identified. Include a very basic # test for now, and improve as needed. - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) x = mxp.asarray(xp.arange(3), mask=[False, True, False]) # Test `__setitem__`/`__getitem__` roundtrip @@ -318,7 +318,7 @@ def test_indexing(xp): @pytest.mark.parametrize('xp', xps) def test_dlpack(dtype, xp, seed=None): # This is a placeholder for a real test when there is a real implementation - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, _, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) assert isinstance(marrays[0].__dlpack__(), type(marrays[0].data.__dlpack__())) assert marrays[0].__dlpack_device__() == marrays[0].data.__dlpack_device__() @@ -371,7 +371,7 @@ def test_inplace(f, dtype, xp, seed=None): @pass_exceptions(allowed=["Only numeric dtypes are allowed in matmul"]) def test_inplace_array_binary(f, dtype, xp, seed=None): # very restrictive operator -> limited test - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) rng = np.random.default_rng(seed) data = (rng.random((3, 10, 10))*10).astype(dtype) mask = rng.random((3, 10, 10)) > 0.5 @@ -411,9 +411,9 @@ def test_rarithmetic_binary(f, dtype, xp, type_, seed=None): assert_equal(res, ref, xp=xp, seed=seed) -def test_rarray_binary(xp=np, seed=None): +def test_rarray_binary(xp=strict, seed=None): # very restrictive operator -> limited test - mxp = marray.masked_array(strict) + mxp = marray.get_namespace(xp) rng = np.random.default_rng(seed) data = rng.random((3, 10, 10)) mask = rng.random((3, 10, 10)) > 0.5 @@ -459,7 +459,7 @@ def test_attributes(dtype, xp, seed=None): @pytest.mark.parametrize('xp', xps) def test_constants(xp): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) assert mxp.e == xp.e assert mxp.inf == xp.inf assert np.isnan(mxp.nan) == np.isnan(xp.nan) @@ -470,7 +470,7 @@ def test_constants(xp): @pytest.mark.parametrize("f", data_type + inspection + version) @pytest.mark.parametrize('xp', xps) def test_dtype_funcs_inspection(f, xp): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) getattr(mxp, f) is getattr(xp, f) @@ -479,7 +479,7 @@ def test_dtype_funcs_inspection(f, xp): def test_dtypes(dtype, xp): if xp == np: pytest.xfail("NumPy fails... unclear whether NumPy follows standard here.") - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) getattr(mxp, dtype).__eq__(getattr(xp, dtype)) @@ -496,7 +496,7 @@ def test_dtypes(dtype, xp): "Only boolean dtypes are allowed", "Only complex floating-point dtypes are allowed"]) def test_elementwise_unary(f_name, dtype, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) f = getattr(mxp, f_name) f2 = getattr(xp, f_name) @@ -519,7 +519,7 @@ def test_elementwise_unary(f_name, dtype, xp, seed=None): "Only numeric dtypes are allowed", "Only boolean dtypes are allowed",]) def test_elementwise_binary(f_name, dtype, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed) f = getattr(mxp, f_name) f2 = getattr(np, f_name) @@ -538,7 +538,7 @@ def test_elementwise_binary(f_name, dtype, xp, seed=None): "Only numeric dtypes are allowed", "Only real numeric dtypes are allowed"]) def test_statistical_array(f_name, keepdims, xp, dtype, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) rng = np.random.default_rng(seed) axes = list(range(marrays[0].ndim)) @@ -578,7 +578,7 @@ def test_statistical_array(f_name, keepdims, xp, dtype, seed=None): @pass_exceptions(allowed=[r"arange() is only supported for booleans when"]) def test_creation(f_name, args, kwargs, dtype, xp, seed=None): dtype = getattr(xp, dtype) - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) f_xp = getattr(xp, f_name) f_mxp = getattr(mxp, f_name) if f_name.endswith('like'): @@ -596,7 +596,7 @@ def test_creation(f_name, args, kwargs, dtype, xp, seed=None): @pytest.mark.parametrize('dtype', dtypes_all) @pytest.mark.parametrize('xp', xps) def test_tri(f_name, dtype, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) f_xp = getattr(xp, f_name) f_mxp = getattr(mxp, f_name) marrays, _, seed = get_arrays(1, ndim=(2, 4), dtype=dtype, xp=xp, seed=seed) @@ -612,7 +612,7 @@ def test_tri(f_name, dtype, xp, seed=None): @pytest.mark.parametrize('dtype', dtypes_all) @pytest.mark.parametrize('xp', [np]) def test_meshgrid(indexing, dtype, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, _, seed = get_arrays(1, ndim=1, dtype=dtype, xp=xp, seed=seed) res = mxp.meshgrid(*marrays, indexing=indexing) @@ -626,7 +626,7 @@ def test_meshgrid(indexing, dtype, xp, seed=None): @pytest.mark.parametrize('dtype', dtypes_integral + dtypes_real) @pytest.mark.parametrize('xp', xps) def test_searchsorted(side, dtype, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) rng = np.random.default_rng(seed) n = 20 @@ -688,7 +688,7 @@ def test_searchsorted(side, dtype, xp, seed=None): @pytest.mark.parametrize('dtype', dtypes_all) @pytest.mark.parametrize('xp', xps) def test_manipulation(f_name, n_arrays, n_dims, args, kwargs, dtype, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, _, seed = get_arrays(n_arrays, ndim=n_dims, dtype=dtype, xp=xp, seed=seed) if f_name in {'broadcast_to', 'squeeze'}: original_shape = marrays[0].shape @@ -723,7 +723,7 @@ def test_manipulation(f_name, n_arrays, n_dims, args, kwargs, dtype, xp, seed=No @pytest.mark.parametrize('copy', [False, True]) @pytest.mark.parametrize('xp', xps) def test_astype(dtype_in, dtype_out, copy, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, masked_arrays, seed = get_arrays(1, dtype=dtype_in, xp=xp, seed=seed) res = mxp.astype(marrays[0], getattr(xp, dtype_out), copy=copy) @@ -742,7 +742,7 @@ def test_astype(dtype_in, dtype_out, copy, xp, seed=None): @pytest.mark.parametrize('xp', xps) @pass_exceptions(allowed=["Only real numeric dtypes are allowed"]) def test_clip(dtype, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, masked_arrays, seed = get_arrays(3, dtype=dtype, xp=xp, seed=seed) min = mxp.minimum(marrays[1], marrays[2]) max = mxp.maximum(marrays[1], marrays[2]) @@ -759,7 +759,7 @@ def test_clip(dtype, xp, seed=None): @pytest.mark.parametrize('dtype', dtypes_real + dtypes_int) @pytest.mark.parametrize('xp', xps) def test_sorting(f_name, descending, stable, dtype, xp, seed=None): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) f_mxp = getattr(mxp, f_name) f_xp = getattr(np.ma, f_name) @@ -791,13 +791,13 @@ def test_sorting(f_name, descending, stable, dtype, xp, seed=None): res = mxp.asarray(res_data, mask=res_mask) ref_data = np.take_along_axis(masked_arrays[0].data, ref, axis=-1) ref_mask = np.take_along_axis(masked_arrays[0].mask, ref, axis=-1) - ref = np.ma.MaskedArray(ref_data, mask=ref_mask) + ref = np.ma.masked_array(ref_data, mask=ref_mask) assert_equal(res, ref, xp=xp, seed=seed) @pytest.mark.parametrize('xp', xps) def test_import(xp): - mxp = marray.masked_array(xp) + mxp = marray.get_namespace(xp) from mxp import asarray asarray(10, mask=True)