Skip to content

Commit

Permalink
add getmask and setmask
Browse files Browse the repository at this point in the history
  • Loading branch information
dionhaefner committed Jul 8, 2021
1 parent 10e30f3 commit 9716a82
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
5 changes: 4 additions & 1 deletion terracotta/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
EXTRA_CALLABLES = {
# 'name': (callable, nargs)

# boolean ops
# mask operations
'where': (np.ma.where, 3),
'getmask': (np.ma.getmaskarray, 1),
'setmask': (lambda arr, mask: np.ma.masked_array(arr, mask=mask), 2),
'masked_equal': (np.ma.masked_equal, 2),
'masked_greater': (np.ma.masked_greater, 2),
'masked_greater_equal': (np.ma.masked_greater_equal, 2),
Expand Down Expand Up @@ -54,6 +56,7 @@
'pi': np.pi,
'nan': np.nan,
'inf': np.inf,
'nomask': np.ma.nomask,
}


Expand Down
25 changes: 21 additions & 4 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


OPERANDS = {
'v1': np.random.rand(5),
'v2': 2 * np.random.rand(5)
'v1': np.ma.masked_array(np.arange(1, 6), dtype='float64'),
'v2': np.ma.masked_array(2 * np.arange(1, 6), dtype='float64', mask=np.array([1, 1, 1, 0, 0])),
}


Expand Down Expand Up @@ -87,6 +87,23 @@
'sin(pi * v1)', np.sin(np.pi * OPERANDS['v1'])
),

# mask operations
(
'setmask(v1, getmask(v2))', np.ma.masked_array(OPERANDS['v1'], mask=OPERANDS['v2'].mask)
),

(
'setmask(v2, nomask)', np.ma.masked_array(OPERANDS['v2'], mask=np.ma.nomask)
),

( # replaces mask
'setmask(v2, ~getmask(v2))', np.ma.masked_array(OPERANDS['v2'], mask=~OPERANDS['v2'].mask)
),

( # adds to mask
'masked_where(~getmask(v2), v2)', np.ma.masked_array(OPERANDS['v2'], mask=True)
),

# long expression
(
'+'.join(['v1'] * 1000), sum(OPERANDS['v1'] for _ in range(1000))
Expand Down Expand Up @@ -230,8 +247,8 @@ def test_timeout():

def test_mask_invalid():
from terracotta.expressions import evaluate_expression
res = evaluate_expression('where(v1 + v2 < 1, nan, 0)', OPERANDS)
mask = OPERANDS['v1'] + OPERANDS['v2'] < 1
res = evaluate_expression('where(v1 + v2 < 10, nan, 0)', OPERANDS)
mask = (OPERANDS['v1'] + OPERANDS['v2'] < 10) | OPERANDS['v1'].mask | OPERANDS['v2'].mask

assert isinstance(res, np.ma.MaskedArray)
assert np.all(res == 0)
Expand Down

0 comments on commit 9716a82

Please sign in to comment.