Skip to content

Commit

Permalink
TST: test/fix clip
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Nov 24, 2024
1 parent 8ffe94e commit cd0fcf7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
13 changes: 12 additions & 1 deletion marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def fun(*args, name=name, **kwargs):
elementwise_names = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_or', 'bitwise_right_shift',
'bitwise_xor', 'ceil', 'clip', 'conj', 'copysign', 'cos',
'bitwise_xor', 'ceil', 'conj', 'copysign', 'cos',
'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor',
'floor_divide', 'greater', 'greater_equal', 'hypot',
'imag', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal',
Expand All @@ -253,6 +253,17 @@ def fun(*args, name=name, **kwargs):
return MaskedArray(out, mask=xp.any(masks, axis=0))
setattr(mod, name, fun)


def clip(x, /, min=None, max=None):
args = [x, min, max]
masks = [arg.mask for arg in args if hasattr(arg, 'mask')]
masks = xp.broadcast_arrays(*masks)
mask = xp.any(masks, axis=0)
datas = [getattr(arg, 'data', arg) for arg in args]
data = xp.clip(datas[0], min=datas[1], max=datas[2])
return MaskedArray(data, mask)
mod.clip = clip

## Indexing Functions
# To be written:
# take
Expand Down
8 changes: 8 additions & 0 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,14 @@ def test_astype(dtype_in, dtype_out, copy, xp=np, seed=None):
assert_equal(res, ref, seed)


@pytest.mark.parametrize('dtype', dtypes_real)
def test_clip(dtype, xp=np, seed=None):
mxp = marray.masked_array(xp)
marrays, masked_arrays, seed = get_arrays(3, dtype=dtype, seed=seed)
res = mxp.clip(marrays[0], min=marrays[1], max=marrays[2])
ref = np.ma.clip(*masked_arrays)
assert_equal(res, ref, seed)

#?
# Searching functions - would test argmin/argmax with statistical functions,
# but NumPy masked version isn't correct
Expand Down

0 comments on commit cd0fcf7

Please sign in to comment.