diff --git a/marray/__init__.py b/marray/__init__.py index 5ecaee9..d98ed8a 100644 --- a/marray/__init__.py +++ b/marray/__init__.py @@ -43,6 +43,7 @@ def __init__(self, data, mask=None): self._xp = xp self._sentinel = (info(self).max if not xp.isdtype(self.dtype, 'bool') else None) + self.__array_namespace__ = mod @property def data(self): @@ -172,7 +173,6 @@ def fun(self, other, name=name, **kwargs): setattr(MaskedArray, name, fun) # To be added - # __array_namespace__ # __dlpack__, __dlpack_device__ # to_device? diff --git a/marray/tests/test_marray.py b/marray/tests/test_marray.py index 661485e..eac4a5d 100644 --- a/marray/tests/test_marray.py +++ b/marray/tests/test_marray.py @@ -146,7 +146,7 @@ def irshift(x, y): x >>= y 'logaddexp', 'logical_and', 'logical_or', 'logical_xor', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract'] -searching_array = ['argmax', 'argmin'] +searching_array = ['argmax', 'argmin'] # NumPy masked array funcs not good references statistical_array = ['cumulative_sum', 'max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] utility_array = ['all', 'any'] @@ -423,10 +423,13 @@ def test_statistical_array(f_name, keepdims, xp=np, dtype='float64', seed=None): # Manipulation functions (apply to data and mask separately) #? -# Searching functions - finish testing argmin/argmax with above +# Searching functions - would test argmin/argmax with statistical functions, +# but NumPy masked version isn't correct # Set functions # Sorting functions +# __array_namespace__ def test_test(): - seed = 8377009968503871097350278305436713931 + seed = 149020664425889521094089537542803361848 + # test_statistical_array('argmin', True, seed=seed) test_rarithmetic_binary(arithmetic_binary[0], 'float32', seed=seed)