Skip to content

Commit

Permalink
Merge pull request #13 from mdhaber/fix_tests
Browse files Browse the repository at this point in the history
TST: fix tests that were commented out
  • Loading branch information
mdhaber authored Nov 22, 2024
2 parents b5da7a5 + cf88711 commit b2cea15
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 111 deletions.
53 changes: 33 additions & 20 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ def masked_array(xp):
class MaskedArray:

def __init__(self, data, mask=None):
data = getattr(data, '_data', data)
mask = (xp.zeros(data.shape, dtype=bool) if mask is None
else xp.asarray(mask, dtype=bool))
data = xp.asarray(getattr(data, '_data', data))
mask = (xp.zeros(data.shape, dtype=xp.bool) if mask is None
else xp.asarray(mask, dtype=xp.bool))
mask = xp.asarray(xp.broadcast_to(mask, data.shape), copy=True)
self._data = data
self._dtype = data.dtype
self._device = data.device
# assert data.device == mask.device
self._ndim = data.ndim
self._shape = data.shape
self._size = data.size
Expand All @@ -50,6 +52,10 @@ def data(self):
def dtype(self):
return self._dtype

@property
def device(self):
return self._device

@property
def ndim(self):
return self._ndim
Expand Down Expand Up @@ -98,8 +104,15 @@ def __str__(self):
def __matmul__(self, other):
return mod.matmul(self, other)

def __imatmul__(self, other):
res = mod.matmul(self, other)
self.data[...] = res.data[...]
self.mask[...] = res.mask[...]
return

def __rmatmul__(self, other):
return mod.matmul(self, other)
other = MaskedArray(other)
return mod.matmul(other, self)

## Attributes ##

Expand Down Expand Up @@ -137,8 +150,7 @@ def fun(self, name=name):
# Methods that return the result of an elementwise binary operation (reflected)
rbinary_names = ['__radd__', '__rand__', '__rdivmod__', '__rfloordiv__',
'__rlshift__', '__rmod__', '__rmul__', '__ror__', '__rpow__',
'__rrshift__', '__rshift__', '__rsub__', '__rtruediv__',
'__rxor__']
'__rrshift__', '__rsub__', '__rtruediv__', '__rxor__']
for name in binary_names + rbinary_names:
def fun(self, other, name=name):
mask = (self.mask | other.mask) if hasattr(other, 'mask') else self.mask
Expand All @@ -159,8 +171,6 @@ def fun(self, other, name=name, **kwargs):
return self
setattr(MaskedArray, name, fun)

# Inherited

# To be added
# __array_namespace__
# __dlpack__, __dlpack_device__
Expand Down Expand Up @@ -210,7 +220,10 @@ def fun(*args, name=name, **kwargs):
dtype_fun_names = ['can_cast', 'finfo', 'iinfo', 'isdtype', 'result_type']
dtype_names = ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16',
'uint32', 'uint64', 'float32', 'float64', 'complex64', 'complex128']
for name in dtype_fun_names + dtype_names:
inspection_fun_names = ['__array_namespace_info__']
version_attribute_names = ['__array_api_version__']
for name in (dtype_fun_names + dtype_names + inspection_fun_names
+ version_attribute_names):
setattr(mod, name, getattr(xp, name))

mod.astype = (lambda x, dtype, /, *, copy=True, **kwargs:
Expand Down Expand Up @@ -267,13 +280,7 @@ def xp_take_along_axis(arr, indices, axis):
mod._xp_take_along_axis = xp_take_along_axis

## Inspection ##
# To be written
# __array_namespace_info
# capabilities
# default_device
# default_dtypes
# devices
# dtypes
# Included with dtype functions above

## Linear Algebra Functions ##
def get_linalg_fun(name):
Expand All @@ -286,14 +293,20 @@ def linalg_fun(x1, x2, /, **kwargs):
data2[x2.mask] = 0
fun = getattr(xp, name)
data = fun(data1, data2)
mask = ~fun(~x1.mask, ~x2.mask)
# Strict array can't do arithmetic with booleans
# mask = ~fun(~x1.mask, ~x2.mask)
mask = fun(xp.astype(~x1.mask, xp.uint64),
xp.astype(~x2.mask, xp.uint64))
mask = ~xp.astype(mask, xp.bool)
return MaskedArray(data, mask)
return linalg_fun

linalg_names = ['matmul', 'tensordot', 'vecdot']
for name in linalg_names:
setattr(mod, name, get_linalg_fun(name))

mod.matrix_transpose = lambda x: x.mT

## Manipulation Functions ##
first_arg_arrays = {'broadcast_arrays', 'concat', 'stack'}
output_arrays = {'broadcast_arrays', 'unstack'}
Expand Down Expand Up @@ -413,13 +426,13 @@ def statistical_fun(x, *args, axis=None, name=name, **kwargs):
data[x.mask] = replacements[name]
fun = getattr(xp, name)
res = fun(data, *args, axis=axis, **kwargs)
mask = xp.all(x.mask, axis=axis)
mask = xp.all(x.mask, axis=axis, keepdims=kwargs.get('keepdims', False))
return MaskedArray(res, mask=mask)
return statistical_fun

def count(x, axis=None, keepdims=False):
x = asarray(x)
return xp.sum(x.mask, axis=axis, keepdims=keepdims)
return xp.sum(~x.mask, axis=axis, keepdims=keepdims)

def cumulative_sum(x, *args, **kwargs):
x = asarray(x)
Expand All @@ -444,7 +457,7 @@ def var(x, axis=None, correction=0, keepdims=False):
mod.count = count
mod.mean = mean
mod.var = var
mod.std = lambda *args, **kwargs: np.sqrt(mod.var(*args, **kwargs))
mod.std = lambda *args, **kwargs: mod.var(*args, **kwargs)**0.5

search_names = ['argmax', 'argmin']
statfun_names = ['max', 'min', 'sum', 'prod']
Expand Down
Loading

0 comments on commit b2cea15

Please sign in to comment.