Skip to content

Commit

Permalink
TST: test scalar conversions, indexing, and dlpack methods (#26)
Browse files Browse the repository at this point in the history
* TST: test indexing and scalar conversion
* TST: add/test dlpack methods
  • Loading branch information
mdhaber authored Dec 1, 2024
1 parent ee51465 commit f518a0f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 15 deletions.
17 changes: 13 additions & 4 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,19 @@ def T(self):
def mT(self):
return MaskedArray(self.data.mT, self.mask.mT)

# dlpack
def __dlpack_device__(self):
return self.data.__dlpack_device__()

def __dlpack__(self):
# really not sure how to define this
return self.data.__dlpack__()

def to_device(self, device, /, *, stream=None):
self._data = self._data.to_device(device, stream=stream)
self._mask = self._mask.to_device(device, stream=stream)


## Methods ##

# Methods that return the result of a unary operation as an array
Expand Down Expand Up @@ -172,10 +185,6 @@ def fun(self, other, name=name, **kwargs):
return self
setattr(MaskedArray, name, fun)

# To be added
# __dlpack__, __dlpack_device__
# to_device?

def info(x):
xp = x._xp
if xp.isdtype(x.dtype, 'integral'):
Expand Down
69 changes: 58 additions & 11 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,13 @@ def assert_allclose(res, ref, seed, **kwargs):
'bitwise_right_shift': lambda x, y: x.__rshift__(y),
'bitwise_xor': lambda x, y: x.__xor__(y)}

# __array_namespace__
# __bool__
# __complex__
# __dlpack__
# __dlpack_device__
# __float__
# __getitem__
# __index__
# __int__
# __setitem__
# __to_device__

scalar_conversions = {bool: True, int: 10, float: 1.5, complex: 1.5 + 2.5j}

# tested in test_dlpack
# __dlpack__, __dlpack_device__, to_device
# tested in test_indexing
# __getitem__, __index__, __setitem__,

comparison_binary = [lambda x, y: x < y, lambda x, y: x <= y, lambda x, y: x > y,
lambda x, y: x >= y, lambda x, y: x == y , lambda x, y: x != y]
Expand Down Expand Up @@ -223,6 +219,57 @@ def test_bitwise_binary(f_name_fun, dtype, xp=np, seed=None):
assert_equal(res, ref, seed)


@pytest.mark.parametrize('type_val', scalar_conversions.items())
@pytest.mark.parametrize('mask', [False, True])
def test_scalar_conversion(type_val, mask, xp=np):
mxp = marray.masked_array(xp)
type, val = type_val
x = mxp.asarray(val)
assert type(x) == val
assert isinstance(type(x), type)

method = getattr(x, f"__{type.__name__}__")
assert method() == val
assert isinstance(method(), type)


def test_indexing(xp=strict):
# The implementations of `__getitem__` and `__setitem__` are trivial.
# This does not make them easy to test exhaustively, but it does make
# them easy to fix if a shortcoming is identified. Include a very basic
# test for now, and improve as needed.
mxp = marray.masked_array(xp)
x = mxp.asarray(xp.arange(3), mask=[False, True, False])

# Test `__setitem__`/`__getitem__` roundtrip
x[1] = 10
assert x[1] == 10
assert isinstance(x[1], type(x))

# Test `__setitem__`/`__getitem__` roundtrip with masked array as index
i = mxp.asarray(1, mask=True)
x[i.__index__()] = 20
assert x[i.__index__()] == 20
assert isinstance(x[i.__index__()], type(x))

# `__setitem__` can change mask
x[1] = mxp.asarray(30, mask=False)
assert x[1].data == 30
assert x[1].mask == False
x[2] = mxp.asarray(40, mask=True)
assert x[2].data == 40
assert x[2].mask == True


def test_dlpack(xp=strict, seed=None):
# This is a placeholder for a real test when there is a real implementation
mxp = marray.masked_array(xp)
marrays, _, seed = get_arrays(1, seed=seed)
assert isinstance(marrays[0].__dlpack__(), type(marrays[0].data.__dlpack__()))
assert marrays[0].__dlpack_device__() == marrays[0].data.__dlpack_device__()
marrays[0].to_device('cpu')


@pytest.mark.parametrize("dtype", dtypes_integral + dtypes_real)
@pytest.mark.parametrize("f", comparison_binary + comparison_methods_binary)
def test_comparison_binary(f, dtype, seed=None):
Expand Down

0 comments on commit f518a0f

Please sign in to comment.