Skip to content

Commit

Permalink
rolling.construct: Add sliding_window_kwargs to pipe arguments do…
Browse files Browse the repository at this point in the history
…wn to `sliding_window_view` (#9720)

* sliding_window_view: add new `automatic_rechunk` kwarg

Closes #9550
xref #4325

* Switch to ``sliding_window_kwargs``

* Add one more

* better docstring

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rename to sliding_window_view_kwargs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dcherian and pre-commit-ci[bot] authored Nov 18, 2024
1 parent 3f0ddc1 commit 993300b
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 27 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ New Features
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
(:issue:`2852`, :issue:`757`).
By `Deepak Cherian <https://github.com/dcherian>`_.
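- Add new ``sliding_window_view_kwargs`` argument to :py:meth:`DataArrayRolling.construct` and
:py:meth:`DatasetRolling.construct` to pass keyword arguments such as ``automatic_rechunk``
down to ``sliding_window_view``. ``automatic_rechunk`` is only useful on ``dask>=2024.11.0``
(:issue:`9550`). By `Deepak Cherian <https://github.com/dcherian>`_.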
- Optimize ffill, bfill with dask when limit is specified
(:pull:`9771`).
By `Joseph Nowak <https://github.com/josephnowak>`_, and
Expand Down
16 changes: 16 additions & 0 deletions xarray/core/dask_array_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,19 @@ def reshape_blockwise(
return reshape_blockwise(x, shape=shape, chunks=chunks)
else:
return x.reshape(shape)


def sliding_window_view(
    x, window_shape, axis=None, *, automatic_rechunk=True, **kwargs
):
    """Call dask's ``sliding_window_view``, passing ``automatic_rechunk`` when possible.

    Backward-compatibility shim: ``automatic_rechunk`` is forwarded only when
    ``module_available("dask", "2024.11.0")`` is true; otherwise it is dropped.
    This wrapper can be deleted once dask>=2024.11.0 is required.

    Any additional ``**kwargs`` are accepted and ignored; ``subok`` and
    ``writeable`` are unsupported by dask.
    """
    from dask.array.lib.stride_tricks import sliding_window_view as _dask_view

    forwarded = {}
    if module_available("dask", "2024.11.0"):
        forwarded["automatic_rechunk"] = automatic_rechunk
    # Without the version gate above, automatic_rechunk is not supported and
    # is left out of the call entirely.
    return _dask_view(x, window_shape=window_shape, axis=axis, **forwarded)
33 changes: 24 additions & 9 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
transpose,
unravel_index,
)
from numpy.lib.stride_tricks import sliding_window_view # noqa: F401
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype

from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.namedarray import pycompat
Expand Down Expand Up @@ -92,19 +91,25 @@ def _dask_or_eager_func(
name,
eager_module=np,
dask_module="dask.array",
dask_only_kwargs=tuple(),
numpy_only_kwargs=tuple(),
):
"""Create a function that dispatches to dask for dask array inputs."""

def f(*args, **kwargs):
if any(is_duck_dask_array(a) for a in args):
if dask_available and any(is_duck_dask_array(a) for a in args):
mod = (
import_module(dask_module)
if isinstance(dask_module, str)
else dask_module
)
wrapped = getattr(mod, name)
for kwarg in numpy_only_kwargs:
kwargs.pop(kwarg, None)
else:
wrapped = getattr(eager_module, name)
for kwarg in dask_only_kwargs:
kwargs.pop(kwarg, None)
return wrapped(*args, **kwargs)

return f
Expand All @@ -122,6 +127,22 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module="dask.array")

# TODO: replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed.
# TODO: replacing currently breaks the iris + dask tests.
masked_invalid = _dask_or_eager_func(
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
)

# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk),
# so the dispatch is hand-coded here: dask inputs go through the
# dask_array_compat wrapper, everything else through numpy's stride_tricks.
# ``automatic_rechunk`` is dropped for eager inputs, while ``subok`` and
# ``writeable`` are dropped for dask inputs.
sliding_window_view = _dask_or_eager_func(
"sliding_window_view",
eager_module=np.lib.stride_tricks,
dask_module=dask_array_compat,
dask_only_kwargs=("automatic_rechunk",),
numpy_only_kwargs=("subok", "writeable"),
)


def round(array):
xp = get_array_namespace(array)
Expand Down Expand Up @@ -170,12 +191,6 @@ def notnull(data):
return ~isnull(data)


# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
masked_invalid = _dask_or_eager_func(
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
)


def trapz(y, x, axis):
if axis < 0:
axis = y.ndim + axis
Expand Down
Loading

0 comments on commit 993300b

Please sign in to comment.