Skip to content

Commit

Permalink
MAINT: rename MaskedArray -> MArray, masked_array -> get_namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Dec 2, 2024
1 parent 0583d44 commit 5511a21
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 66 deletions.
80 changes: 39 additions & 41 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,22 @@
import types, sys
import dataclasses

def masked_array(xp):
"""Returns a masked array namespace for an array API backend
def get_namespace(xp):
"""Returns a masked array namespace for an Array API Standard compatible backend
Examples
--------
>>> from scipy._lib.array_api_compat import numpy as xp
>>> from scipy.stats import masked_array
>>> ma = masked_array(xp)
>>> A = ma.eye(3)
>>> import numpy as xp
>>> from marray import get_namespace
>>> mxp = get_namespace(xp)
>>> A = mxp.eye(3)
>>> A.mask[0, ...] = True
>>> x = ma.asarray([1, 2, 3], mask=[False, False, True])
>>> x = mxp.asarray([1, 2, 3], mask=[False, False, True])
>>> A @ x
masked_array(data=[--, 2.0, 0.0],
mask=[ True, False, False],
fill_value=1e+20)
MArray(array([0., 2., 0.]), array([ True, False, False]))
"""
class MaskedArray:
class MArray:

def __init__(self, data, mask=None):
data = xp.asarray(getattr(data, '_data', data))
Expand Down Expand Up @@ -80,7 +78,7 @@ def call_super_method(self, method_name, *args, **kwargs):

## Indexing ##
def __getitem__(self, key):
return MaskedArray(self.data[key], self.mask[key])
return MArray(self.data[key], self.mask[key])

def __setitem__(self, key, other):
self.mask[key] = getattr(other, 'mask', False)
Expand All @@ -90,9 +88,9 @@ def _data_mask_string(self, fun):
data_str = fun(self.data)
mask_str = fun(self.mask)
if len(data_str) + len(mask_str) <= 66:
return f"MaskedArray({data_str}, {mask_str})"
return f"MArray({data_str}, {mask_str})"
else:
return f"MaskedArray(\n {data_str},\n {mask_str}\n)"
return f"MArray(\n {data_str},\n {mask_str}\n)"

## Visualization ##
def __repr__(self):
Expand All @@ -112,18 +110,18 @@ def __imatmul__(self, other):
return

def __rmatmul__(self, other):
other = MaskedArray(other)
other = MArray(other)
return mod.matmul(other, self)

## Attributes ##

@property
def T(self):
return MaskedArray(self.data.T, self.mask.T)
return MArray(self.data.T, self.mask.T)

@property
def mT(self):
return MaskedArray(self.data.mT, self.mask.mT)
return MArray(self.data.mT, self.mask.mT)

# dlpack
def __dlpack_device__(self):
Expand All @@ -146,15 +144,15 @@ def to_device(self, device, /, *, stream=None):
for name in unary_names:
def fun(self, name=name):
data = self.call_super_method(name)
return MaskedArray(data, self.mask)
setattr(MaskedArray, name, fun)
return MArray(data, self.mask)
setattr(MArray, name, fun)

# Methods that return the result of a unary operation as a Python scalar
unary_names_py = ['__bool__', '__complex__', '__float__', '__index__', '__int__']
for name in unary_names_py:
def fun(self, name=name):
return self.call_super_method(name)
setattr(MaskedArray, name, fun)
setattr(MArray, name, fun)

# Methods that return the result of an elementwise binary operation
binary_names = ['__add__', '__sub__', '__and__', '__eq__', '__ge__', '__gt__',
Expand All @@ -169,8 +167,8 @@ def fun(self, name=name):
def fun(self, other, name=name):
mask = (self.mask | other.mask) if hasattr(other, 'mask') else self.mask
data = self.call_super_method(name, other)
return MaskedArray(data, mask)
setattr(MaskedArray, name, fun)
return MArray(data, mask)
setattr(MArray, name, fun)

# In-place methods
desired_names = ['__iadd__', '__iand__', '__ifloordiv__', '__ilshift__',
Expand All @@ -183,7 +181,7 @@ def fun(self, other, name=name, **kwargs):
self.mask.__ior__(other.mask)
self.call_super_method(name, other)
return self
setattr(MaskedArray, name, fun)
setattr(MArray, name, fun)

def info(x):
xp = x._xp
Expand All @@ -198,7 +196,7 @@ def info(x):
mod = types.ModuleType('mxp')
sys.modules['mxp'] = mod

mod.MaskedArray = MaskedArray
mod.MArray = MArray

## Constants ##
constant_names = ['e', 'inf', 'nan', 'newaxis', 'pi']
Expand All @@ -217,7 +215,7 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None):
if mask is None else mask)
mask = xp.asarray(mask, dtype=xp.bool, device=device, copy=copy)

return MaskedArray(data, mask=mask)
return MArray(data, mask=mask)
mod.asarray = asarray

creation_functions = ['arange', 'empty', 'empty_like', 'eye', 'from_dlpack',
Expand All @@ -228,7 +226,7 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None):
for name in creation_functions:
def fun(*args, name=name, **kwargs):
data = getattr(xp, name)(*args, **kwargs)
return MaskedArray(data)
return MArray(data)
setattr(mod, name, fun)

## Data Type Functions and Data Types ##
Expand All @@ -246,7 +244,7 @@ def astype(x, dtype, /, *, copy=True, device=None):
return x
data = xp.astype(x.data, dtype, copy=copy, device=device)
mask = xp.astype(x.mask, xp.bool, copy=copy, device=device)
return MaskedArray(data, mask=mask)
return MArray(data, mask=mask)
mod.astype = astype

## Elementwise Functions ##
Expand All @@ -269,7 +267,7 @@ def fun(*args, name=name, **kwargs):
masks = xp.broadcast_arrays(*masks)
args = [getattr(arg, 'data', arg) for arg in args]
out = getattr(xp, name)(*args, **kwargs)
return MaskedArray(out, mask=xp.any(xp.stack(masks), axis=0))
return MArray(out, mask=xp.any(xp.stack(masks), axis=0))
setattr(mod, name, fun)


Expand All @@ -280,7 +278,7 @@ def clip(x, /, min=None, max=None):
mask = xp.any(xp.stack(masks), axis=0)
datas = [getattr(arg, 'data', arg) for arg in args]
data = xp.clip(datas[0], min=datas[1], max=datas[2])
return MaskedArray(data, mask)
return MArray(data, mask)
mod.clip = clip

## Indexing Functions
Expand All @@ -289,7 +287,7 @@ def take(x, indices, /, *, axis=None):
indices_mask = getattr(indices, 'mask', False)
data = xp.take(x.data, indices_data, axis=axis)
mask = xp.take(x.mask, indices_data, axis=axis) | indices_mask
return MaskedArray(data, mask=mask)
return MArray(data, mask=mask)
mod.take = take

## Inspection ##
Expand All @@ -311,7 +309,7 @@ def linalg_fun(x1, x2, /, **kwargs):
mask = fun(xp.astype(~x1.mask, xp.uint64),
xp.astype(~x2.mask, xp.uint64))
mask = ~xp.astype(mask, xp.bool)
return MaskedArray(data, mask)
return MArray(data, mask)
return linalg_fun

linalg_names = ['matmul', 'tensordot', 'vecdot']
Expand Down Expand Up @@ -342,8 +340,8 @@ def manip_fun(x, *args, **kwargs):
res = fun(data, *args, **kwargs)
mask = fun(mask, *args, **kwargs)

out = (MaskedArray(res, mask) if name not in output_arrays
else [MaskedArray(resi, maski) for resi, maski in zip(res, mask)])
out = (MArray(res, mask) if name not in output_arrays
else [MArray(resi, maski) for resi, maski in zip(res, mask)])
return out
return manip_fun

Expand All @@ -367,22 +365,22 @@ def searchsorted(x1, x2, /, *, side='left', sorter=None):
count[-1] = count[-2]
i = xp.searchsorted(x1_compressed, x2.data, side=side)
j = i + xp.take(count, i)
return MaskedArray(j, mask=x2.mask)
return MArray(j, mask=x2.mask)

def nonzero(x, /):
x = asarray(x)
data = xp.asarray(x.data, copy=True)
data[x.mask] = 0
res = xp.nonzero(data)
return tuple(MaskedArray(resi) for resi in res)
return tuple(MArray(resi) for resi in res)

def where(condition, x1, x2, /):
condition = asarray(condition)
x1 = asarray(x1)
x2 = asarray(x2)
data = xp.where(condition.data, x1.data, x2.data)
mask = condition.mask | x1.mask | x2.mask
return MaskedArray(data, mask)
return MArray(data, mask)

mod.searchsorted = searchsorted
mod.nonzero = nonzero
Expand All @@ -403,8 +401,8 @@ def set_fun(x, /):
fun = getattr(xp, name)
res = fun(data)
# this sort of works but could be refined
return (MaskedArray(res, res==x._sentinel) if name=='unique_values'
else tuple(MaskedArray(resi, resi==x._sentinel) for resi in res))
return (MArray(res, res==x._sentinel) if name=='unique_values'
else tuple(MArray(resi, resi==x._sentinel) for resi in res))
return set_fun

unique_names = ['unique_values', 'unique_counts', 'unique_inverse', 'unique_all']
Expand All @@ -422,7 +420,7 @@ def sort_fun(x, /, *, axis=-1, descending=False, stable=True):
kwargs = {'descending': True} if descending else {}
res = fun(data, axis=axis, stable=stable, **kwargs)
mask = (res == sentinel) if name=='sort' else None
return MaskedArray(res, mask)
return MArray(res, mask)
return sort_fun

sort_names = ['sort', 'argsort']
Expand All @@ -446,7 +444,7 @@ def statistical_fun(x, *args, axis=None, name=name, **kwargs):
fun = getattr(xp, name)
res = fun(data, *args, axis=axis, **kwargs)
mask = xp.all(x.mask, axis=axis, keepdims=kwargs.get('keepdims', False))
return MaskedArray(res, mask=mask)
return MArray(res, mask=mask)
return statistical_fun

def count(x, axis=None, keepdims=False):
Expand All @@ -459,7 +457,7 @@ def cumulative_sum(x, *args, **kwargs):
data = xp.asarray(x.data, copy=True)
data[x.mask] = 0
res = xp.cumulative_sum(data, *args, **kwargs)
return MaskedArray(res, x.mask)
return MArray(res, x.mask)

def mean(x, axis=None, keepdims=False):
s = mod.sum(x, axis=axis, keepdims=keepdims)
Expand Down
Loading

0 comments on commit 5511a21

Please sign in to comment.