Skip to content

Commit

Permalink
MAINT: argmin/argmax: fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Nov 25, 2024
1 parent 44d46ed commit 1c111f7
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def irshift(x, y): x >>= y
'logaddexp', 'logical_and', 'logical_or', 'logical_xor',
'maximum', 'minimum', 'multiply', 'not_equal', 'pow',
'remainder', 'subtract']
searching_array = ['argmax', 'argmin'] # NumPy masked array funcs not good references
searching_array = ['argmax', 'argmin']
statistical_array = ['cumulative_sum', 'max', 'mean',
'min', 'prod', 'std', 'sum', 'var']
utility_array = ['all', 'any']
Expand Down Expand Up @@ -394,7 +394,7 @@ def test_elementwise_binary(f_name, xp=np, dtype='float64', seed=None):


@pytest.mark.parametrize("keepdims", [False, True])
@pytest.mark.parametrize("f_name", statistical_array + utility_array)
@pytest.mark.parametrize("f_name", statistical_array + utility_array + searching_array)
def test_statistical_array(f_name, keepdims, xp=np, dtype='float64', seed=None):
# TODO: confirm that result should never have mask? Only when all are masked?
mxp = marray.masked_array(xp)
Expand All @@ -410,7 +410,9 @@ def test_statistical_array(f_name, keepdims, xp=np, dtype='float64', seed=None):
f2 = getattr(xp, f_name2)
res = f(marrays[0], axis=axis, **kwargs)
ref = f2(masked_arrays[0], axis=axis, **kwargs)
ref = np.ma.masked_array(ref.data, getattr(ref, 'mask', False))
# `argmin`/`argmax` don't calculate mask correctly
ref_mask = np.all(masked_arrays[0].mask, axis=axis, **kwargs)
ref = np.ma.masked_array(ref.data, getattr(ref, 'mask', ref_mask))
assert_equal(res, ref, seed)

# Test Linear Algebra functions
Expand All @@ -423,8 +425,6 @@ def test_statistical_array(f_name, keepdims, xp=np, dtype='float64', seed=None):
# Manipulation functions (apply to data and mask separately)

#?
# Searching functions - would test argmin/argmax with statistical functions,
# but NumPy masked version isn't correct
# Set functions
# Sorting functions
# __array_namespace__
Expand Down

0 comments on commit 1c111f7

Please sign in to comment.