diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1d239e18fcd..f811c7d15f0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,9 @@ Breaking changes New Features ~~~~~~~~~~~~ +- :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile`` + now work with dask Variables. + By `Deepak Cherian `_. Bug fixes diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index fdddde773c1..1793dd2d94d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5166,11 +5166,7 @@ def quantile( new = self._replace_with_new_dims( variables, coord_names=coord_names, attrs=attrs, indexes=indexes ) - if "quantile" in new.dims: - new.coords["quantile"] = Variable("quantile", q) - else: - new.coords["quantile"] = q - return new + return new.assign_coords(quantile=q) def rank(self, dim, pct=False, keep_attrs=None): """Ranks the data. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 55e8f64d56c..041c303dd3a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1716,40 +1716,45 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile """ - if isinstance(self.data, dask_array_type): - raise TypeError( - "quantile does not work for arrays stored as dask " - "arrays. Load the data via .compute() or .load() " - "prior to calling this method." - ) - q = np.asarray(q, dtype=np.float64) - - new_dims = list(self.dims) - if dim is not None: - axis = self.get_axis_num(dim) - if utils.is_scalar(dim): - new_dims.remove(dim) - else: - for d in dim: - new_dims.remove(d) - else: - axis = None - new_dims = [] - - # Only add the quantile dimension if q is array-like - if q.ndim != 0: - new_dims = ["quantile"] + new_dims - - qs = np.nanpercentile( - self.data, q * 100.0, axis=axis, interpolation=interpolation - ) + from .computation import apply_ufunc if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) - attrs = self._attrs if keep_attrs else None - return Variable(new_dims, qs, attrs) + scalar = utils.is_scalar(q) + q = np.atleast_1d(np.asarray(q, dtype=np.float64)) + + if dim is None: + dim = self.dims + + if utils.is_scalar(dim): + dim = [dim] + + def _wrapper(npa, **kwargs): + # move quantile axis to end. required for apply_ufunc + return np.moveaxis(np.nanpercentile(npa, **kwargs), 0, -1) + + axis = np.arange(-1, -1 * len(dim) - 1, -1) + result = apply_ufunc( + _wrapper, + self, + input_core_dims=[dim], + exclude_dims=set(dim), + output_core_dims=[["quantile"]], + output_dtypes=[np.float64], + output_sizes={"quantile": len(q)}, + dask="parallelized", + kwargs={"q": q * 100, "axis": axis, "interpolation": interpolation}, + ) + + # for backward compatibility + result = result.transpose("quantile", ...) + if scalar: + result = result.squeeze("quantile") + if keep_attrs: + result.attrs = self._attrs + return result def rank(self, dim, pct=False): """Ranks the data. diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index ad98792372e..a1e34abd0d5 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -15,6 +15,8 @@ from xarray.core import dtypes from xarray.core.common import full_like from xarray.core.indexes import propagate_indexes +from xarray.core.utils import is_scalar + from xarray.tests import ( LooseVersion, ReturnItem, @@ -2330,17 +2332,20 @@ def test_reduce_out(self): with pytest.raises(TypeError): orig.mean(out=np.ones(orig.shape)) - def test_quantile(self): - for q in [0.25, [0.50], [0.25, 0.75]]: - for axis, dim in zip( - [None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]] - ): - actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True) - expected = np.nanpercentile( - self.dv.values, np.array(q) * 100, axis=axis - ) - np.testing.assert_allclose(actual.values, expected) - assert actual.attrs == self.attrs + @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) + @pytest.mark.parametrize( + "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]) + ) + def test_quantile(self, q, axis, dim): + actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True) + expected = np.nanpercentile(self.dv.values, np.array(q) * 100, axis=axis) + np.testing.assert_allclose(actual.values, expected) + if is_scalar(q): + assert "quantile" not in actual.dims + else: + assert "quantile" in actual.dims + + assert actual.attrs == self.attrs def test_reduce_keep_attrs(self): # Test dropped attrs diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index e8fe768b783..d8282f58051 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -28,6 +28,7 @@ from xarray.core.common import duck_array_ops, full_like from xarray.core.npcompat import IS_NEP18_ACTIVE from xarray.core.pycompat import integer_types +from xarray.core.utils import is_scalar from . import ( InaccessibleArray, @@ -4575,21 +4576,24 @@ def test_reduce_keepdims(self): ) assert_identical(expected, actual) - def test_quantile(self): - + @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) + def test_quantile(self, q): ds = create_test_data(seed=123) - for q in [0.25, [0.50], [0.25, 0.75]]: - for dim in [None, "dim1", ["dim1"]]: - ds_quantile = ds.quantile(q, dim=dim) - assert "quantile" in ds_quantile - for var, dar in ds.data_vars.items(): - assert var in ds_quantile - assert_identical(ds_quantile[var], dar.quantile(q, dim=dim)) - dim = ["dim1", "dim2"] + for dim in [None, "dim1", ["dim1"]]: ds_quantile = ds.quantile(q, dim=dim) - assert "dim3" in ds_quantile.dims - assert all(d not in ds_quantile.dims for d in dim) + if is_scalar(q): + assert "quantile" not in ds_quantile.dims + else: + assert "quantile" in ds_quantile.dims + + for var, dar in ds.data_vars.items(): + assert var in ds_quantile + assert_identical(ds_quantile[var], dar.quantile(q, dim=dim)) + dim = ["dim1", "dim2"] + ds_quantile = ds.quantile(q, dim=dim) + assert "dim3" in ds_quantile.dims + assert all(d not in ds_quantile.dims for d in dim) @requires_bottleneck def test_rank(self): diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index ee8d54e567e..245dc1acc42 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -22,6 +22,7 @@ PandasIndexAdapter, VectorizedIndexer, ) +from xarray.core.pycompat import dask_array_type from xarray.core.utils import NDArrayMixin from xarray.core.variable import as_compatible_data, as_variable from xarray.tests import requires_bottleneck @@ -1492,23 +1493,31 @@ def test_reduce(self): with pytest.warns(DeprecationWarning, match="allow_lazy is deprecated"): v.mean(dim="x", allow_lazy=False) - def test_quantile(self): + @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) + @pytest.mark.parametrize( + "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]) + ) + def test_quantile(self, q, axis, dim): v = Variable(["x", "y"], self.d) - for q in [0.25, [0.50], [0.25, 0.75]]: - for axis, dim in zip( - [None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]] - ): - actual = v.quantile(q, dim=dim) + actual = v.quantile(q, dim=dim) + expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis) + np.testing.assert_allclose(actual.values, expected) - expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis) - np.testing.assert_allclose(actual.values, expected) + @requires_dask + @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) + @pytest.mark.parametrize("axis, dim", [[1, "y"], [[1], ["y"]]]) + def test_quantile_dask(self, q, axis, dim): + v = Variable(["x", "y"], self.d).chunk({"x": 2}) + actual = v.quantile(q, dim=dim) + assert isinstance(actual.data, dask_array_type) + expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis) + np.testing.assert_allclose(actual.values, expected) @requires_dask - def test_quantile_dask_raises(self): - # regression for GH1524 - v = Variable(["x", "y"], self.d).chunk(2) + def test_quantile_chunked_dim_error(self): + v = Variable(["x", "y"], self.d).chunk({"x": 2}) - with raises_regex(TypeError, "arrays stored as dask"): + with raises_regex(ValueError, "dimension 'x'"): v.quantile(0.5, dim="x") @requires_dask