Skip to content

Commit

Permalink
TST: Adding a test that MaskedArrays respect ufunc deferral heirarchy
Browse files Browse the repository at this point in the history
This test makes sure that a MaskedArray defers properly to another
class if it doesn't know how to handle it. See numpy#15200.
  • Loading branch information
greglucas authored and evinism committed Jul 17, 2022
1 parent bebfff1 commit f14e30e
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions numpy/ma/tests/test_subclassing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin
from numpy.testing import assert_, assert_raises
from numpy.ma.testutils import assert_equal
from numpy.ma.core import (
Expand Down Expand Up @@ -147,6 +148,33 @@ def __array_wrap__(self, obj, context=None):
return obj


class WrappedArray(NDArrayOperatorsMixin):
"""
Wrapping a MaskedArray rather than subclassing to test that
ufunc deferrals are commutative.
See: https://github.com/numpy/numpy/issues/15200)
"""
__array_priority__ = 20

def __init__(self, array, **attrs):
self._array = array
self.attrs = attrs

def __repr__(self):
return f"{self.__class__.__name__}(\n{self._array}\n{self.attrs}\n)"

def __array__(self):
return np.asarray(self._array)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method == '__call__':
inputs = [arg._array if isinstance(arg, self.__class__) else arg
for arg in inputs]
return self.__class__(ufunc(*inputs, **kwargs), **self.attrs)
else:
return NotImplemented


class TestSubclassing:
# Test suite for masked subclasses of ndarray.

Expand Down Expand Up @@ -384,3 +412,39 @@ def test_array_no_inheritance():
# Test that the mask is False and not shared when keep_mask=False
assert_(not new_array.mask)
assert_(not new_array.sharedmask)


class TestClassWrapping:
# Test suite for classes that wrap MaskedArrays

def setup(self):
m = np.ma.masked_array([1, 3, 5], mask=[False, True, False])
wm = WrappedArray(m)
self.data = (m, wm)

def test_masked_unary_operations(self):
# Tests masked_unary_operation
(m, wm) = self.data
with np.errstate(divide='ignore'):
assert_(isinstance(np.log(wm), WrappedArray))

def test_masked_binary_operations(self):
# Tests masked_binary_operation
(m, wm) = self.data
# Result should be a WrappedArray
assert_(isinstance(np.add(wm, wm), WrappedArray))
assert_(isinstance(np.add(m, wm), WrappedArray))
assert_(isinstance(np.add(wm, m), WrappedArray))
# add and '+' should call the same ufunc
assert_equal(np.add(m, wm), m + wm)
assert_(isinstance(np.hypot(m, wm), WrappedArray))
assert_(isinstance(np.hypot(wm, m), WrappedArray))
# Test domained binary operations
assert_(isinstance(np.divide(wm, m), WrappedArray))
assert_(isinstance(np.divide(m, wm), WrappedArray))
assert_equal(np.divide(wm, m) * m, np.divide(m, m) * wm)
# Test broadcasting
m2 = np.stack([m, m])
assert_(isinstance(np.divide(wm, m2), WrappedArray))
assert_(isinstance(np.divide(m2, wm), WrappedArray))
assert_equal(np.divide(m2, wm), np.divide(wm, m2))

0 comments on commit f14e30e

Please sign in to comment.