Skip to content

Commit

Permalink
Merge branch 'main' into test_indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber authored Dec 1, 2024
2 parents 185eddd + ee51465 commit 715fb2b
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 68 deletions.
113 changes: 59 additions & 54 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

__version__ = "0.0.4"

import numpy as np # temporarily used in __repr__ and __str__

import dataclasses

def masked_array(xp):
"""Returns a masked array namespace for an array API backend
Expand All @@ -30,7 +29,8 @@ def __init__(self, data, mask=None):
data = xp.asarray(getattr(data, '_data', data))
mask = (xp.zeros(data.shape, dtype=xp.bool) if mask is None
else xp.asarray(mask, dtype=xp.bool))
mask = xp.asarray(xp.broadcast_to(mask, data.shape), copy=True)
if mask.shape != data.shape: # avoid copy if possible
mask = xp.asarray(xp.broadcast_to(mask, data.shape), copy=True)
self._data = data
self._dtype = data.dtype
self._device = data.device
Expand Down Expand Up @@ -86,20 +86,20 @@ def __setitem__(self, key, other):
self.mask[key] = getattr(other, 'mask', False)
return self.data.__setitem__(key, getattr(other, 'data', other))

def _data_mask_string(self, fun):
data_str = fun(self.data)
mask_str = fun(self.mask)
if len(data_str) + len(mask_str) <= 66:
return f"MaskedArray({data_str}, {mask_str})"
else:
return f"MaskedArray(\n {data_str},\n {mask_str}\n)"

## Visualization ##
def __repr__(self):
# temporary: fix for CuPy
# eventually: rewrite to avoid masked array
data = np.asarray(self.data)
mask = np.asarray(self.mask)
return np.ma.masked_array(data, mask).__repr__()
return self._data_mask_string(repr)

def __str__(self):
# temporary: fix for CuPy
# eventually: rewrite to avoid masked array
data = np.asarray(self.data)
mask = np.asarray(self.mask)
return np.ma.masked_array(data, mask).__str__()
return self._data_mask_string(str)

## Linear Algebra Methods ##
def __matmul__(self, other):
Expand Down Expand Up @@ -189,6 +189,9 @@ def info(x):
xp = x._xp
if xp.isdtype(x.dtype, 'integral'):
return xp.iinfo(x.dtype)
elif xp.isdtype(x.dtype, 'bool'):
binfo = dataclasses.make_dataclass("binfo", ['min', 'max'])
return binfo(min=False, max=True)
else:
return xp.finfo(x.dtype)

Expand All @@ -197,6 +200,8 @@ class module:

mod = module()

mod.MaskedArray = MaskedArray

## Constants ##
constant_names = ['e', 'inf', 'nan', 'newaxis', 'pi']
for name in constant_names:
Expand All @@ -212,14 +217,16 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None):

mask = (getattr(obj, 'mask', xp.full(data.shape, False))
if mask is None else mask)
mask = xp.asarray(mask, dtype=dtype, device=device, copy=copy)
mask = xp.asarray(mask, dtype=xp.bool, device=device, copy=copy)

return MaskedArray(data, mask=mask)
mod.asarray = asarray

creation_functions = ['arange', 'empty', 'empty_like', 'eye', 'from_dlpack',
'full', 'full_like', 'linspace', 'meshgrid', 'ones',
'ones_like', 'tril', 'triu', 'zeros', 'zeros_like']
'full', 'full_like', 'linspace', 'ones', 'ones_like',
'zeros', 'zeros_like']
# handled with array manipulation functions
creation_manip_functions = ['tril', 'triu', 'meshgrid']
for name in creation_functions:
def fun(*args, name=name, **kwargs):
data = getattr(xp, name)(*args, **kwargs)
Expand All @@ -243,7 +250,7 @@ def fun(*args, name=name, **kwargs):
elementwise_names = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_or', 'bitwise_right_shift',
'bitwise_xor', 'ceil', 'clip', 'conj', 'copysign', 'cos',
'bitwise_xor', 'ceil', 'conj', 'copysign', 'cos',
'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor',
'floor_divide', 'greater', 'greater_equal', 'hypot',
'imag', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal',
Expand All @@ -262,32 +269,25 @@ def fun(*args, name=name, **kwargs):
return MaskedArray(out, mask=xp.any(masks, axis=0))
setattr(mod, name, fun)

## Indexing Functions
# To be written:
# take

def xp_take_along_axis(arr, indices, axis):
# This is just for regular arrays; not masked arrays
arr = xp_swapaxes(arr, axis, -1)
indices = xp_swapaxes(indices, axis, -1)

m = arr.shape[-1]
n = indices.shape[-1]

shape = list(arr.shape)
shape.pop(-1)
shape = shape + [n,]

arr = xp.reshape(arr, (-1,))
indices = xp.reshape(indices, (-1, n))

offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
indices = xp.reshape(offset + indices, (-1,))
def clip(x, /, min=None, max=None):
args = [x, min, max]
masks = [arg.mask for arg in args if hasattr(arg, 'mask')]
masks = xp.broadcast_arrays(*masks)
mask = xp.any(masks, axis=0)
datas = [getattr(arg, 'data', arg) for arg in args]
data = xp.clip(datas[0], min=datas[1], max=datas[2])
return MaskedArray(data, mask)
mod.clip = clip

out = arr[indices]
out = xp.reshape(out, shape)
return xp_swapaxes(out, axis, -1)
mod._xp_take_along_axis = xp_take_along_axis
## Indexing Functions
def take(x, indices, /, *, axis=None):
indices_data = getattr(indices, 'data', indices)
indices_mask = getattr(indices, 'mask', False)
data = xp.take(x.data, indices_data, axis=axis)
mask = xp.take(x.mask, indices_data, axis=axis) | indices_mask
return MaskedArray(data, mask=mask)
mod.take = take

## Inspection ##
# Included with dtype functions above
Expand Down Expand Up @@ -318,8 +318,8 @@ def linalg_fun(x1, x2, /, **kwargs):
mod.matrix_transpose = lambda x: x.mT

## Manipulation Functions ##
first_arg_arrays = {'broadcast_arrays', 'concat', 'stack'}
output_arrays = {'broadcast_arrays', 'unstack'}
first_arg_arrays = {'broadcast_arrays', 'concat', 'stack', 'meshgrid'}
output_arrays = {'broadcast_arrays', 'unstack', 'meshgrid'}

def get_manip_fun(name):
def manip_fun(x, *args, **kwargs):
Expand All @@ -332,7 +332,7 @@ def manip_fun(x, *args, **kwargs):

fun = getattr(xp, name)

if name == 'broadcast_arrays':
if name in {'broadcast_arrays', 'meshgrid'}:
res = fun(*data, *args, **kwargs)
mask = fun(*mask, *args, **kwargs)
else:
Expand All @@ -347,20 +347,24 @@ def manip_fun(x, *args, **kwargs):
manip_names = ['broadcast_arrays', 'broadcast_to', 'concat', 'expand_dims',
'flip', 'moveaxis', 'permute_dims', 'repeat', 'reshape',
'roll', 'squeeze', 'stack', 'tile', 'unstack']
for name in manip_names:
for name in manip_names + creation_manip_functions:
setattr(mod, name, get_manip_fun(name))
mod.broadcast_arrays = lambda *arrays: get_manip_fun('broadcast_arrays')(arrays)

# This is just for regular arrays; not masked arrays
def xp_swapaxes(arr, axis1, axis2):
axes = list(range(arr.ndim))
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
return xp.permute_dims(arr, axes)
mod.xp_swapaxes = xp_swapaxes
mod.meshgrid = lambda *arrays, **kwargs: get_manip_fun('meshgrid')(arrays, **kwargs)

## Searching Functions
# To be added
# searchsorted
def searchsorted(x1, x2, /, *, side='left', sorter=None):
if sorter is not None:
x1 = take(x1, sorter)

mask_count = xp.cumulative_sum(xp.astype(x1.mask, xp.int64))
x1_compressed = x1.data[~x1.mask]
count = xp.zeros(x1_compressed.size+1, dtype=xp.int64)
count[:-1] = mask_count[~x1.mask]
count[-1] = count[-2]
i = xp.searchsorted(x1_compressed, x2.data, side=side)
j = i + xp.take(count, i)
return MaskedArray(j, mask=x2.mask)

def nonzero(x, /):
x = asarray(x)
Expand All @@ -377,6 +381,7 @@ def where(condition, x1, x2, /):
mask = condition.mask | x1.mask | x2.mask
return MaskedArray(data, mask)

mod.searchsorted = searchsorted
mod.nonzero = nonzero
mod.where = where

Expand Down
Loading

0 comments on commit 715fb2b

Please sign in to comment.