From 9716a82dd0d583d670b8fd1810713fdf26f7c6ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Thu, 8 Jul 2021 19:55:23 +0200 Subject: [PATCH] add getmask and setmask --- terracotta/expressions.py | 5 ++++- tests/test_expressions.py | 25 +++++++++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/terracotta/expressions.py b/terracotta/expressions.py index 103cf083..5b84f3e8 100644 --- a/terracotta/expressions.py +++ b/terracotta/expressions.py @@ -14,8 +14,10 @@ EXTRA_CALLABLES = { # 'name': (callable, nargs) - # boolean ops + # mask operations 'where': (np.ma.where, 3), + 'getmask': (np.ma.getmaskarray, 1), + 'setmask': (lambda arr, mask: np.ma.masked_array(arr, mask=mask), 2), 'masked_equal': (np.ma.masked_equal, 2), 'masked_greater': (np.ma.masked_greater, 2), 'masked_greater_equal': (np.ma.masked_greater_equal, 2), @@ -54,6 +56,7 @@ 'pi': np.pi, 'nan': np.nan, 'inf': np.inf, + 'nomask': np.ma.nomask, } diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 1f6e3fe4..bbb18ce7 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -7,8 +7,8 @@ OPERANDS = { - 'v1': np.random.rand(5), - 'v2': 2 * np.random.rand(5) + 'v1': np.ma.masked_array(np.arange(1, 6), dtype='float64'), + 'v2': np.ma.masked_array(2 * np.arange(1, 6), dtype='float64', mask=np.array([1, 1, 1, 0, 0])), } @@ -87,6 +87,23 @@ 'sin(pi * v1)', np.sin(np.pi * OPERANDS['v1']) ), + # mask operations + ( + 'setmask(v1, getmask(v2))', np.ma.masked_array(OPERANDS['v1'], mask=OPERANDS['v2'].mask) + ), + + ( + 'setmask(v2, nomask)', np.ma.masked_array(OPERANDS['v2'], mask=np.ma.nomask) + ), + + ( # replaces mask + 'setmask(v2, ~getmask(v2))', np.ma.masked_array(OPERANDS['v2'], mask=~OPERANDS['v2'].mask) + ), + + ( # adds to mask + 'masked_where(~getmask(v2), v2)', np.ma.masked_array(OPERANDS['v2'], mask=True) + ), + # long expression ( '+'.join(['v1'] * 1000), sum(OPERANDS['v1'] for _ in range(1000)) @@ -230,8 +247,8 @@ def test_timeout(): def test_mask_invalid(): from terracotta.expressions import evaluate_expression - res = evaluate_expression('where(v1 + v2 < 1, nan, 0)', OPERANDS) - mask = OPERANDS['v1'] + OPERANDS['v2'] < 1 + res = evaluate_expression('where(v1 + v2 < 10, nan, 0)', OPERANDS) + mask = (OPERANDS['v1'] + OPERANDS['v2'] < 10) | OPERANDS['v1'].mask | OPERANDS['v2'].mask assert isinstance(res, np.ma.MaskedArray) assert np.all(res == 0)