diff --git a/conftest.py b/conftest.py index 4131ebc4399..c043e1f1b46 100644 --- a/conftest.py +++ b/conftest.py @@ -162,3 +162,15 @@ def set_agg_backend(): yield finally: matplotlib.pyplot.switch_backend(prev_backend) + + +@pytest.fixture(params=['dask', 'masked', 'numpy']) +def array_type(request): + """Return an array type for testing calc functions.""" + if request.param == 'dask': + dask_array = pytest.importorskip('dask.array', reason='dask.array is not available') + return dask_array.array + elif request.param == 'masked': + return numpy.ma.array + else: + return numpy.array diff --git a/tests/calc/test_basic.py b/tests/calc/test_basic.py index da33c87b77c..6ea7b879c5a 100644 --- a/tests/calc/test_basic.py +++ b/tests/calc/test_basic.py @@ -19,16 +19,22 @@ from metpy.units import units -def test_wind_comps_basic(): +def test_wind_comps_basic(array_type): """Test the basic wind component calculation.""" - speed = np.array([4, 4, 4, 4, 25, 25, 25, 25, 10.]) * units.mph - dirs = np.array([0, 45, 90, 135, 180, 225, 270, 315, 360]) * units.deg + speed = units.Quantity(array_type([4, 4, 4, 4, 25, 25, 25, 25, 10.]), 'mph') + dirs = units.Quantity(array_type([0, 45, 90, 135, 180, 225, 270, 315, 360]), 'deg') s2 = np.sqrt(2.) u, v = wind_components(speed, dirs) - true_u = np.array([0, -4 / s2, -4, -4 / s2, 0, 25 / s2, 25, 25 / s2, 0]) * units.mph - true_v = np.array([-4, -4 / s2, 0, 4 / s2, 25, 25 / s2, 0, -25 / s2, -10]) * units.mph + true_u = units.Quantity( + array_type([0, -4 / s2, -4, -4 / s2, 0, 25 / s2, 25, 25 / s2, 0]), + 'mph' + ) + true_v = units.Quantity( + array_type([-4, -4 / s2, 0, 4 / s2, 25, 25 / s2, 0, -25 / s2, -10]), + 'mph' + ) assert_array_almost_equal(true_u, u, 4) assert_array_almost_equal(true_v, v, 4)