Skip to content

Commit

Permalink
TST: parametrize over dtypes, backends
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Dec 2, 2024
1 parent c9c8388 commit 0583d44
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 184 deletions.
22 changes: 14 additions & 8 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(self, data, mask=None):
self._xp = xp
self._sentinel = (info(self).max if not xp.isdtype(self.dtype, 'bool')
else None)
self.__array_namespace__ = mod

@property
def data(self):
Expand Down Expand Up @@ -242,8 +241,13 @@ def fun(*args, name=name, **kwargs):
+ version_attribute_names):
setattr(mod, name, getattr(xp, name))

mod.astype = (lambda x, dtype, /, *, copy=True, **kwargs:
asarray(x, copy=copy or (dtype != x.dtype), dtype=dtype, **kwargs))
def astype(x, dtype, /, *, copy=True, device=None):
if device is None and not copy and dtype == x.dtype:
return x
data = xp.astype(x.data, dtype, copy=copy, device=device)
mask = xp.astype(x.mask, xp.bool, copy=copy, device=device)
return MaskedArray(data, mask=mask)
mod.astype = astype

## Elementwise Functions ##
elementwise_names = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan',
Expand All @@ -265,15 +269,15 @@ def fun(*args, name=name, **kwargs):
masks = xp.broadcast_arrays(*masks)
args = [getattr(arg, 'data', arg) for arg in args]
out = getattr(xp, name)(*args, **kwargs)
return MaskedArray(out, mask=xp.any(masks, axis=0))
return MaskedArray(out, mask=xp.any(xp.stack(masks), axis=0))
setattr(mod, name, fun)


def clip(x, /, min=None, max=None):
args = [x, min, max]
masks = [arg.mask for arg in args if hasattr(arg, 'mask')]
masks = xp.broadcast_arrays(*masks)
mask = xp.any(masks, axis=0)
mask = xp.any(xp.stack(masks), axis=0)
datas = [getattr(arg, 'data', arg) for arg in args]
data = xp.clip(datas[0], min=datas[1], max=datas[2])
return MaskedArray(data, mask)
Expand Down Expand Up @@ -415,7 +419,8 @@ def sort_fun(x, /, *, axis=-1, descending=False, stable=True):
sentinel = info(x).min if descending else info(x).max
data[x.mask] = sentinel
fun = getattr(xp, name)
res = fun(data, axis=axis, descending=descending, stable=stable)
kwargs = {'descending': True} if descending else {}
res = fun(data, axis=axis, stable=stable, **kwargs)
mask = (res == sentinel) if name=='sort' else None
return MaskedArray(res, mask)
return sort_fun
Expand Down Expand Up @@ -446,7 +451,8 @@ def statistical_fun(x, *args, axis=None, name=name, **kwargs):

def count(x, axis=None, keepdims=False):
x = asarray(x)
return xp.sum(~x.mask, axis=axis, keepdims=keepdims)
not_mask = xp.astype(~x.mask, xp.uint64)
return xp.sum(not_mask, axis=axis, keepdims=keepdims, dtype=xp.uint64)

def cumulative_sum(x, *args, **kwargs):
x = asarray(x)
Expand All @@ -464,8 +470,8 @@ def var(x, axis=None, correction=0, keepdims=False):
# rewrite this to use xp.var but replace masked entries with mean.
m = mod.mean(x, axis=axis, keepdims=True)
xm = x - m
n = mod.count(x, axis=axis, keepdims=keepdims)
s = mod.sum(xm**2, axis=axis, keepdims=keepdims)
n = mod.count(x, axis=axis, keepdims=keepdims)
return s / (n - correction)

mod.count = count
Expand Down
Loading

0 comments on commit 0583d44

Please sign in to comment.