Skip to content

Commit

Permalink
Implement skipna kwarg in xr.quantile (#3844)
Browse files Browse the repository at this point in the history
* quick fix, no docs, no tests

* added tests

* docstrings

* added whatsnew

* Update doc/whats-new.rst

Co-Authored-By: Maximilian Roos <[email protected]>

* Update doc/whats-new.rst

Co-Authored-By: keewis <[email protected]>

Co-authored-by: Maximilian Roos <[email protected]>
Co-authored-by: keewis <[email protected]>
  • Loading branch information
3 people authored Mar 8, 2020
1 parent 9fbb417 commit cdaac64
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 19 deletions.
7 changes: 6 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ New Features
By `Julia Signell <https://github.com/jsignell>`_.
- :py:meth:`Dataset.where` and :py:meth:`DataArray.where` accept a lambda as a
first argument, which is then called on the input; replicating pandas' behavior.
By `Maximilian Roos <https://github.com/max-sixty>`_
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Implement ``skipna`` in :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile`,
:py:meth:`core.groupby.DatasetGroupBy.quantile`, :py:meth:`core.groupby.DataArrayGroupBy.quantile`
(:issue:`3843`, :pull:`3844`)
By `Aaron Spring <https://github.com/aaronspring>`_.


Bug fixes
~~~~~~~~~
Expand Down
11 changes: 9 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2939,6 +2939,7 @@ def quantile(
dim: Union[Hashable, Sequence[Hashable], None] = None,
interpolation: str = "linear",
keep_attrs: bool = None,
skipna: bool = True,
) -> "DataArray":
"""Compute the qth quantile of the data along the specified dimension.
Expand Down Expand Up @@ -2966,6 +2967,8 @@ def quantile(
If True, the dataset's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
object will be returned without attributes.
skipna : bool, optional
Whether to skip missing values when aggregating.
Returns
-------
Expand All @@ -2978,7 +2981,7 @@ def quantile(
See Also
--------
numpy.nanquantile, pandas.Series.quantile, Dataset.quantile
numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile
Examples
--------
Expand Down Expand Up @@ -3015,7 +3018,11 @@ def quantile(
"""

ds = self._to_temp_dataset().quantile(
q, dim=dim, keep_attrs=keep_attrs, interpolation=interpolation
q,
dim=dim,
keep_attrs=keep_attrs,
interpolation=interpolation,
skipna=skipna,
)
return self._from_temp_dataset(ds)

Expand Down
13 changes: 11 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5140,7 +5140,13 @@ def sortby(self, variables, ascending=True):
return aligned_self.isel(**indices)

def quantile(
self, q, dim=None, interpolation="linear", numeric_only=False, keep_attrs=None
self,
q,
dim=None,
interpolation="linear",
numeric_only=False,
keep_attrs=None,
skipna=True,
):
"""Compute the qth quantile of the data along the specified dimension.
Expand Down Expand Up @@ -5171,6 +5177,8 @@ def quantile(
object will be returned without attributes.
numeric_only : bool, optional
If True, only apply ``func`` to variables with a numeric dtype.
skipna : bool, optional
Whether to skip missing values when aggregating.
Returns
-------
Expand All @@ -5183,7 +5191,7 @@ def quantile(
See Also
--------
numpy.nanquantile, pandas.Series.quantile, DataArray.quantile
numpy.nanquantile, numpy.quantile, pandas.Series.quantile, DataArray.quantile
Examples
--------
Expand Down Expand Up @@ -5258,6 +5266,7 @@ def quantile(
dim=reduce_dims,
interpolation=interpolation,
keep_attrs=keep_attrs,
skipna=skipna,
)

else:
Expand Down
9 changes: 7 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,9 @@ def fillna(self, value):
out = ops.fillna(self, value)
return out

def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
def quantile(
self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True
):
"""Compute the qth quantile over each array in the groups and
concatenate them together into a new array.
Expand All @@ -582,6 +584,8 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
* higher: ``j``.
* nearest: ``i`` or ``j``, whichever is nearest.
* midpoint: ``(i + j) / 2``.
skipna : bool, optional
Whether to skip missing values when aggregating.
Returns
-------
Expand All @@ -595,7 +599,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
See Also
--------
numpy.nanquantile, pandas.Series.quantile, Dataset.quantile,
numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile,
DataArray.quantile
Examples
Expand Down Expand Up @@ -656,6 +660,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
dim=dim,
interpolation=interpolation,
keep_attrs=keep_attrs,
skipna=skipna,
)

return out
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,7 +1678,9 @@ def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
"""
return self.broadcast_equals(other, equiv=equiv)

def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
def quantile(
self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True
):
"""Compute the qth quantile of the data along the specified dimension.
Returns the qth quantiles(s) of the array elements.
Expand Down Expand Up @@ -1725,6 +1727,8 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):

from .computation import apply_ufunc

_quantile_func = np.nanquantile if skipna else np.quantile

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

Expand All @@ -1739,7 +1743,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):

def _wrapper(npa, **kwargs):
# move quantile axis to end. required for apply_ufunc
return np.moveaxis(np.nanquantile(npa, **kwargs), 0, -1)
return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1)

axis = np.arange(-1, -1 * len(dim) - 1, -1)
result = apply_ufunc(
Expand Down
8 changes: 5 additions & 3 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,13 +2368,15 @@ def test_reduce_out(self):
with pytest.raises(TypeError):
orig.mean(out=np.ones(orig.shape))

@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
@pytest.mark.parametrize(
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
)
def test_quantile(self, q, axis, dim):
actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True)
expected = np.nanpercentile(self.dv.values, np.array(q) * 100, axis=axis)
def test_quantile(self, q, axis, dim, skipna):
actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True, skipna=skipna)
_percentile_func = np.nanpercentile if skipna else np.percentile
expected = _percentile_func(self.dv.values, np.array(q) * 100, axis=axis)
np.testing.assert_allclose(actual.values, expected)
if is_scalar(q):
assert "quantile" not in actual.dims
Expand Down
24 changes: 20 additions & 4 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4697,25 +4697,41 @@ def test_reduce_keepdims(self):
)
assert_identical(expected, actual)

@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
def test_quantile(self, q):
def test_quantile(self, q, skipna):
ds = create_test_data(seed=123)

for dim in [None, "dim1", ["dim1"]]:
ds_quantile = ds.quantile(q, dim=dim)
ds_quantile = ds.quantile(q, dim=dim, skipna=skipna)
if is_scalar(q):
assert "quantile" not in ds_quantile.dims
else:
assert "quantile" in ds_quantile.dims

for var, dar in ds.data_vars.items():
assert var in ds_quantile
assert_identical(ds_quantile[var], dar.quantile(q, dim=dim))
assert_identical(
ds_quantile[var], dar.quantile(q, dim=dim, skipna=skipna)
)
dim = ["dim1", "dim2"]
ds_quantile = ds.quantile(q, dim=dim)
ds_quantile = ds.quantile(q, dim=dim, skipna=skipna)
assert "dim3" in ds_quantile.dims
assert all(d not in ds_quantile.dims for d in dim)

@pytest.mark.parametrize("skipna", [True, False])
def test_quantile_skipna(self, skipna):
q = 0.1
dim = "time"
ds = Dataset({"a": ([dim], np.arange(0, 11))})
ds = ds.where(ds >= 1)

result = ds.quantile(q=q, dim=dim, skipna=skipna)

value = 1.9 if skipna else np.nan
expected = Dataset({"a": value}, coords={"quantile": q})
assert_identical(result, expected)

@requires_bottleneck
def test_rank(self):
ds = create_test_data(seed=1234)
Expand Down
8 changes: 5 additions & 3 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,14 +1511,16 @@ def test_reduce(self):
with pytest.warns(DeprecationWarning, match="allow_lazy is deprecated"):
v.mean(dim="x", allow_lazy=False)

@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
@pytest.mark.parametrize(
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
)
def test_quantile(self, q, axis, dim):
def test_quantile(self, q, axis, dim, skipna):
v = Variable(["x", "y"], self.d)
actual = v.quantile(q, dim=dim)
expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis)
actual = v.quantile(q, dim=dim, skipna=skipna)
_percentile_func = np.nanpercentile if skipna else np.percentile
expected = _percentile_func(self.d, np.array(q) * 100, axis=axis)
np.testing.assert_allclose(actual.values, expected)

@requires_dask
Expand Down

0 comments on commit cdaac64

Please sign in to comment.