Skip to content

Commit

Permalink
Support first, last with dask arrays (#7562)
Browse files Browse the repository at this point in the history
* Support first, last with dask arrays

Use dask.array.reduction. For this we need to add support
for the `keepdims` kwarg to `nanfirst` and `nanlast`.
Even though the final result is always keepdims=False,
dask runs the intermediate steps with keepdims=True.

* Don't provide meta.

It would need to account for shape change.
  • Loading branch information
dcherian authored Mar 3, 2023
1 parent 43ba095 commit 830ee6d
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 21 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ New Features

- Fix :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex valued arrays (:issue:`7340`, :pull:`7392`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Support dask arrays in ``first`` and ``last`` reductions.
By `Deepak Cherian <https://github.com/dcherian>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
37 changes: 37 additions & 0 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import annotations

from functools import partial

from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]

from xarray.core import dtypes, nputils


Expand Down Expand Up @@ -92,3 +96,36 @@ def _fill_with_last_one(a, b):
axis=axis,
dtype=array.dtype,
)


def _first_last_wrapper(array, *, axis, op, keepdims):
return op(array, axis, keepdims=keepdims)


def _first_or_last(darray, axis, op):
import dask.array

# This will raise the same error message seen for numpy
axis = normalize_axis_index(axis, darray.ndim)

wrapped_op = partial(_first_last_wrapper, op=op)
return dask.array.reduction(
darray,
chunk=wrapped_op,
aggregate=wrapped_op,
axis=axis,
dtype=darray.dtype,
keepdims=False, # match numpy version
)


def nanfirst(darray, axis):
from xarray.core.duck_array_ops import nanfirst

return _first_or_last(darray, axis, op=nanfirst)


def nanlast(darray, axis):
from xarray.core.duck_array_ops import nanlast

return _first_or_last(darray, axis, op=nanlast)
19 changes: 8 additions & 11 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import datetime
import inspect
import warnings
from functools import partial
from importlib import import_module

import numpy as np
Expand Down Expand Up @@ -637,27 +636,25 @@ def cumsum(array, axis=None, **kwargs):
return _nd_cum_func(cumsum_1d, array, axis, **kwargs)


_fail_on_dask_array_input_skipna = partial(
fail_on_dask_array_input,
msg="%r with skipna=True is not yet implemented on dask arrays",
)


def first(values, axis, skipna=None):
"""Return the first non-NA elements in this array along the given axis"""
if (skipna or skipna is None) and values.dtype.kind not in "iSU":
# only bother for dtypes that can hold NaN
_fail_on_dask_array_input_skipna(values)
return nanfirst(values, axis)
if is_duck_dask_array(values):
return dask_array_ops.nanfirst(values, axis)
else:
return nanfirst(values, axis)
return take(values, 0, axis=axis)


def last(values, axis, skipna=None):
"""Return the last non-NA elements in this array along the given axis"""
if (skipna or skipna is None) and values.dtype.kind not in "iSU":
# only bother for dtypes that can hold NaN
_fail_on_dask_array_input_skipna(values)
return nanlast(values, axis)
if is_duck_dask_array(values):
return dask_array_ops.nanlast(values, axis)
else:
return nanlast(values, axis)
return take(values, -1, axis=axis)


Expand Down
20 changes: 16 additions & 4 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,29 @@ def _select_along_axis(values, idx, axis):
return values[sl]


def nanfirst(values, axis):
def nanfirst(values, axis, keepdims=False):
if isinstance(axis, tuple):
(axis,) = axis
axis = normalize_axis_index(axis, values.ndim)
idx_first = np.argmax(~pd.isnull(values), axis=axis)
return _select_along_axis(values, idx_first, axis)
result = _select_along_axis(values, idx_first, axis)
if keepdims:
return np.expand_dims(result, axis=axis)
else:
return result


def nanlast(values, axis):
def nanlast(values, axis, keepdims=False):
if isinstance(axis, tuple):
(axis,) = axis
axis = normalize_axis_index(axis, values.ndim)
rev = (slice(None),) * axis + (slice(None, None, -1),)
idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
return _select_along_axis(values, idx_last, axis)
result = _select_along_axis(values, idx_last, axis)
if keepdims:
return np.expand_dims(result, axis=axis)
else:
return result


def inverse_permutation(indices):
Expand Down
15 changes: 10 additions & 5 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,17 +549,22 @@ def test_rolling(self):
actual = v.rolling(x=2).mean()
self.assertLazyAndAllClose(expected, actual)

def test_groupby_first(self):
@pytest.mark.parametrize("func", ["first", "last"])
def test_groupby_first_last(self, func):
method = operator.methodcaller(func)
u = self.eager_array
v = self.lazy_array

for coords in [u.coords, v.coords]:
coords["ab"] = ("x", ["a", "a", "b", "b"])
with pytest.raises(NotImplementedError, match=r"dask"):
v.groupby("ab").first()
expected = u.groupby("ab").first()
expected = method(u.groupby("ab"))

with raise_if_dask_computes():
actual = method(v.groupby("ab"))
self.assertLazyAndAllClose(expected, actual)

with raise_if_dask_computes():
actual = v.groupby("ab").first(skipna=False)
actual = method(v.groupby("ab"))
self.assertLazyAndAllClose(expected, actual)

def test_reindex(self):
Expand Down
29 changes: 28 additions & 1 deletion xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ class TestOps:
def setUp(self):
self.x = array(
[
[[nan, nan, 2.0, nan], [nan, 5.0, 6.0, nan], [8.0, 9.0, 10.0, nan]],
[
[nan, nan, 2.0, nan],
[nan, 5.0, 6.0, nan],
[8.0, 9.0, 10.0, nan],
],
[
[nan, 13.0, 14.0, 15.0],
[nan, 17.0, 18.0, nan],
Expand Down Expand Up @@ -128,6 +132,29 @@ def test_all_nan_arrays(self):
assert np.isnan(mean([np.nan, np.nan]))


@requires_dask
class TestDaskOps(TestOps):
@pytest.fixture(autouse=True)
def setUp(self):
import dask.array

self.x = dask.array.from_array(
[
[
[nan, nan, 2.0, nan],
[nan, 5.0, 6.0, nan],
[8.0, 9.0, 10.0, nan],
],
[
[nan, 13.0, 14.0, 15.0],
[nan, 17.0, 18.0, nan],
[nan, 21.0, nan, nan],
],
],
chunks=(2, 1, 2),
)


def test_cumsum_1d():
inputs = np.array([0, 1, 2, 3])
expected = np.array([0, 1, 3, 6])
Expand Down

0 comments on commit 830ee6d

Please sign in to comment.