Skip to content

Commit

Permalink
Merge pull request #1935 from rpmanser/dask_tests_basic
Browse files Browse the repository at this point in the history
Add a fixture that provides numpy, masked, and dask array functions
  • Loading branch information
rpmanser authored Aug 19, 2021
1 parent 4ad37b6 commit dd8e492
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 170 deletions.
1 change: 1 addition & 0 deletions ci/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytest==6.2.4
pytest-mpl==0.13
netCDF4==1.5.7
coverage==5.5
dask==2021.2.0
20 changes: 20 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,23 @@ def set_agg_backend():
yield
finally:
matplotlib.pyplot.switch_backend(prev_backend)


@pytest.fixture(params=['dask', 'xarray', 'masked', 'numpy'])
def array_type(request):
"""Return an array type for testing calc functions."""
quantity = metpy.units.units.Quantity
if request.param == 'dask':
dask_array = pytest.importorskip('dask.array', reason='dask.array is not available')
marker = request.node.get_closest_marker('xfail_dask')
if marker is not None:
request.applymarker(pytest.mark.xfail(reason=marker.args[0]))
return lambda d, u, *, mask=None: quantity(dask_array.array(d), u)
elif request.param == 'xarray':
return lambda d, u, *, mask=None: xarray.DataArray(d, attrs={'units': u})
elif request.param == 'masked':
return lambda d, u, *, mask=None: quantity(numpy.ma.array(d, mask=mask), u)
elif request.param == 'numpy':
return lambda d, u, *, mask=None: quantity(numpy.array(d), u)
else:
raise ValueError(f'Unsupported array_type option {request.param}')
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ extension = MPY = flake8_metpy:MetPyChecker
paths = ./tools/flake8-metpy

[tool:pytest]
# https://github.com/matplotlib/pytest-mpl/issues/69
markers = mpl_image_compare
markers = xfail_dask: marks tests as expected to fail with Dask arrays
norecursedirs = build docs .idea
doctest_optionflags = NORMALIZE_WHITESPACE
mpl-results-path = test_output
Expand Down
11 changes: 8 additions & 3 deletions src/metpy/calc/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def wind_direction(u, v, convention='from'):
if np.any(mask):
wdir[mask] += units.Quantity(360., 'deg')
# avoid unintended modification of `pint.Quantity` by direct use of magnitude
calm_mask = (np.asarray(u.magnitude) == 0.) & (np.asarray(v.magnitude) == 0.)
calm_mask = (np.asanyarray(u.magnitude) == 0.) & (np.asanyarray(v.magnitude) == 0.)

# np.any check required for legacy numpy which treats 0-d False boolean index as zero
if np.any(calm_mask):
wdir[calm_mask] = units.Quantity(0., 'deg')
Expand Down Expand Up @@ -799,8 +800,12 @@ def smooth_gaussian(scalar_grid, n):
# Assume the last two axes represent the horizontal directions
sgma_seq = [sgma if i > num_ax - 3 else 0 for i in range(num_ax)]

# Compute smoothed field
return gaussian_filter(scalar_grid, sgma_seq, truncate=2 * np.sqrt(2))
filter_args = {'sigma': sgma_seq, 'truncate': 2 * np.sqrt(2)}
if hasattr(scalar_grid, 'mask'):
smoothed = gaussian_filter(scalar_grid.data, **filter_args)
return np.ma.array(smoothed, mask=scalar_grid.mask)
else:
return gaussian_filter(scalar_grid, **filter_args)


@exporter.export
Expand Down
Loading

0 comments on commit dd8e492

Please sign in to comment.