Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use numbagg for ffill by default #8389

Merged
merged 16 commits into from
Nov 25, 2023
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~

- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use numbagg by
default, which is up to 5x faster where parallelization is possible. (:pull:`8389`)
By `Maximilian Roos <https://github.com/max-sixty>`_.

.. _whats-new.2023.11.0:

v2023.11.0 (Nov 16, 2023)
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
# DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
# this avoids the need to get involved in zarr synchronization / locking
# From zarr docs:
# "If each worker in a parallel computation is writing to a separate
# region of the array, and if region boundaries are perfectly aligned
# "If each worker in a parallel computation is writing to a
# separate region of the array, and if region boundaries are perfectly aligned
# with chunk boundaries, then no synchronization is required."
# TODO: incorporate synchronizer to allow writes from multiple dask
# threads
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def push(array, n, axis):
"""
Dask-aware bottleneck.push
"""
import bottleneck
import dask.array as da
import numpy as np

from xarray.core.duck_array_ops import _push

def _fill_with_last_one(a, b):
# cumreduction apply the push func over all the blocks first so, the only missing part is filling
# the missing values using the last data of the previous chunk
Expand All @@ -85,7 +86,7 @@ def _fill_with_last_one(a, b):

# The method parameter makes that the tests for python 3.7 fails.
return da.reductions.cumreduction(
func=bottleneck.push,
func=_push,
binop=_fill_with_last_one,
ident=np.nan,
x=array,
Expand Down
41 changes: 37 additions & 4 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
from numpy import concatenate as _concatenate
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
from numpy.lib.stride_tricks import sliding_window_view # noqa
from packaging.version import Version

from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core import dask_array_ops, dtypes, nputils, pycompat
from xarray.core.options import OPTIONS
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.core.pycompat import array_type, is_duck_dask_array
from xarray.core.utils import is_duck_array, module_available
Expand Down Expand Up @@ -688,13 +690,44 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)


def push(array, n, axis):
from bottleneck import push
def _push(array, n: int | None = None, axis: int = -1):
    """
    Forward-fill ``array`` along ``axis`` using numbagg or bottleneck.

    numbagg is used when the ``use_numbagg`` option is set, the module is
    available and its version is at least 0.6.2; in every other case the
    call is handed to ``bottleneck.push``.

    Parameters
    ----------
    array
        Array to fill; ``array.shape[axis]`` is read when ``n`` is None.
    n : int or None
        Fill limit, passed as ``limit`` to the backend.
    axis : int
        Axis along which to fill.

    Raises
    ------
    RuntimeError
        If both the ``use_bottleneck`` and ``use_numbagg`` options are off.
    """
    numbagg_enabled = OPTIONS["use_numbagg"]
    if not (OPTIONS["use_bottleneck"] or numbagg_enabled):
        raise RuntimeError(
            "ffill & bfill requires bottleneck or numbagg to be enabled."
            " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
        )

    if numbagg_enabled and module_available("numbagg"):
        import numbagg

        installed = pycompat.mod_version("numbagg")
        if installed >= Version("0.6.2"):
            return numbagg.ffill(array, limit=n, axis=axis)
        # Too old for ffill: warn, then fall through to the bottleneck path.
        warnings.warn(
            f"numbagg >= 0.6.2 is required for bfill & ffill; {installed} is installed. We'll attempt with bottleneck instead."
        )

    # work around for bottleneck 178: always pass an explicit limit,
    # using the full axis length when no limit was requested
    limit = array.shape[axis] if n is None else n

    import bottleneck as bn

    return bn.push(array, limit, axis)


def push(array, n, axis):
    """
    Forward-fill ``array`` along ``axis``, dispatching on the array type.

    Dask arrays (as judged by ``is_duck_dask_array``) go to
    ``dask_array_ops.push``; everything else goes to ``_push``.

    Parameters
    ----------
    array
        Array to fill.
    n
        Fill limit, forwarded unchanged to the selected implementation.
    axis
        Axis along which to fill.

    Raises
    ------
    RuntimeError
        If both the ``use_bottleneck`` and ``use_numbagg`` options are off.
    """
    if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
        raise RuntimeError(
            "ffill & bfill requires bottleneck or numbagg to be enabled."
            " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
        )
    if is_duck_dask_array(array):
        return dask_array_ops.push(array, n, axis)
    # Delegate to the backend selector; calling ``push`` here would recurse
    # into this function forever.
    return _push(array, n, axis)


def _first_last_wrapper(array, *, axis, op, keepdims):
Expand Down
12 changes: 1 addition & 11 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from xarray.core.common import _contains_datetime_like_objects, ones_like
from xarray.core.computation import apply_ufunc
from xarray.core.duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.options import _get_keep_attrs
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.core.types import Interp1dOptions, InterpOptions
from xarray.core.utils import OrderedSet, is_scalar
Expand Down Expand Up @@ -413,11 +413,6 @@ def _bfill(arr, n=None, axis=-1):

def ffill(arr, dim=None, limit=None):
"""forward fill missing values"""
if not OPTIONS["use_bottleneck"]:
raise RuntimeError(
"ffill requires bottleneck to be enabled."
" Call `xr.set_options(use_bottleneck=True)` to enable it."
)

axis = arr.get_axis_num(dim)

Expand All @@ -436,11 +431,6 @@ def ffill(arr, dim=None, limit=None):

def bfill(arr, dim=None, limit=None):
"""backfill missing values"""
if not OPTIONS["use_bottleneck"]:
raise RuntimeError(
"bfill requires bottleneck to be enabled."
" Call `xr.set_options(use_bottleneck=True)` to enable it."
)

axis = arr.get_axis_num(dim)

Expand Down
34 changes: 16 additions & 18 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import warnings
from typing import Callable

import numpy as np
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
from packaging.version import Version

from xarray.core import pycompat
from xarray.core.utils import module_available

# remove once numpy 2.0 is the oldest supported version
try:
from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore]
Expand All @@ -25,15 +29,6 @@
bn = np
_BOTTLENECK_AVAILABLE = False

try:
import numbagg

_HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0")
except ImportError:
# use numpy methods instead
numbagg = np # type: ignore
_HAS_NUMBAGG = False


def _select_along_axis(values, idx, axis):
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
Expand Down Expand Up @@ -171,29 +166,32 @@ def __setitem__(self, key, value):
self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions)


def _create_method(name, npmodule=np):
def _create_method(name, npmodule=np) -> Callable:
def f(values, axis=None, **kwargs):
dtype = kwargs.get("dtype", None)
bn_func = getattr(bn, name, None)
nba_func = getattr(numbagg, name, None)

if (
_HAS_NUMBAGG
module_available("numbagg")
and pycompat.mod_version("numbagg") >= Version("0.5.0")
and OPTIONS["use_numbagg"]
and isinstance(values, np.ndarray)
and nba_func is not None
# numbagg uses ddof=1 only, but numpy uses ddof=0 by default
and (("var" in name or "std" in name) and kwargs.get("ddof", 0) == 1)
# TODO: bool?
and values.dtype.kind in "uifc"
# and values.dtype.isnative
and (dtype is None or np.dtype(dtype) == values.dtype)
):
# numbagg does not take care dtype, ddof
kwargs.pop("dtype", None)
kwargs.pop("ddof", None)
result = nba_func(values, axis=axis, **kwargs)
elif (
import numbagg

nba_func = getattr(numbagg, name, None)
if nba_func is not None:
# numbagg does not take care dtype, ddof
kwargs.pop("dtype", None)
kwargs.pop("ddof", None)
return nba_func(values, axis=axis, **kwargs)
if (
_BOTTLENECK_AVAILABLE
and OPTIONS["use_bottleneck"]
and isinstance(values, np.ndarray)
Expand Down
5 changes: 4 additions & 1 deletion xarray/core/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
integer_types = (int, np.integer)

if TYPE_CHECKING:
ModType = Literal["dask", "pint", "cupy", "sparse", "cubed"]
ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"]
DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic


Expand Down Expand Up @@ -47,6 +47,9 @@ def __init__(self, mod: ModType) -> None:
duck_array_type = (duck_array_module.SparseArray,)
elif mod == "cubed":
duck_array_type = (duck_array_module.Array,)
# Not a duck array module, but using this system regardless, to get lazy imports
elif mod == "numbagg":
duck_array_type = ()
else:
raise NotImplementedError

Expand Down
50 changes: 26 additions & 24 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@
import numpy as np
from packaging.version import Version

from xarray.core import pycompat
from xarray.core.computation import apply_ufunc
from xarray.core.options import _get_keep_attrs
from xarray.core.pdcompat import count_not_none
from xarray.core.types import T_DataWithCoords

try:
import numbagg
from numbagg import move_exp_nanmean, move_exp_nansum

_NUMBAGG_VERSION: Version | None = Version(numbagg.__version__)
except ImportError:
_NUMBAGG_VERSION = None
from xarray.core.utils import module_available


def _get_alpha(
Expand Down Expand Up @@ -83,17 +77,17 @@ def __init__(
window_type: str = "span",
min_weight: float = 0.0,
):
if _NUMBAGG_VERSION is None:
if not module_available("numbagg"):
raise ImportError(
"numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
)
elif _NUMBAGG_VERSION < Version("0.2.1"):
elif pycompat.mod_version("numbagg") < Version("0.2.1"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we boost min supported numbagg version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we basically keep the 12 month cycle — it's some good infra, and I would like to be a good citizen for that process rather than deviate to save myself some boilerplate...

raise ImportError(
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {pycompat.mod_version('numbagg')} is installed"
)
elif _NUMBAGG_VERSION < Version("0.3.1") and min_weight > 0:
elif pycompat.mod_version("numbagg") < Version("0.3.1") and min_weight > 0:
raise ImportError(
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {pycompat.mod_version('numbagg')} is installed"
)

self.obj: T_DataWithCoords = obj
Expand Down Expand Up @@ -127,13 +121,15 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

import numbagg

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

dim_order = self.obj.dims

return apply_ufunc(
move_exp_nanmean,
numbagg.move_exp_nanmean,
self.obj,
input_core_dims=[[self.dim]],
kwargs=self.kwargs,
Expand Down Expand Up @@ -163,13 +159,15 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

import numbagg

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

dim_order = self.obj.dims

return apply_ufunc(
move_exp_nansum,
numbagg.move_exp_nansum,
self.obj,
input_core_dims=[[self.dim]],
kwargs=self.kwargs,
Expand All @@ -194,10 +192,12 @@ def std(self) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
)
import numbagg

dim_order = self.obj.dims

return apply_ufunc(
Expand Down Expand Up @@ -225,12 +225,12 @@ def var(self) -> T_DataWithCoords:
array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281])
Dimensions without coordinates: x
"""

if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
import numbagg

return apply_ufunc(
numbagg.move_exp_nanvar,
Expand Down Expand Up @@ -258,11 +258,12 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
import numbagg

return apply_ufunc(
numbagg.move_exp_nancov,
Expand Down Expand Up @@ -291,11 +292,12 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.4.0 is required for rolling_exp().corr(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
import numbagg

return apply_ufunc(
numbagg.move_exp_nancorr,
Expand Down
7 changes: 6 additions & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def _importorskip(
mod = importlib.import_module(modname)
has = True
if minversion is not None:
if Version(mod.__version__) < Version(minversion):
v = getattr(mod, "__version__", "999")
if Version(v) < Version(minversion):
raise ImportError("Minimum version not satisfied")
except ImportError:
has = False
Expand Down Expand Up @@ -96,6 +97,10 @@ def _importorskip(
requires_scipy_or_netCDF4 = pytest.mark.skipif(
not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
)
# True when at least one forward/backward-fill backend is importable.
has_numbagg_or_bottleneck = has_numbagg or has_bottleneck
# Skip marker for tests that need either backend; it must key off the flag
# above, not the unrelated scipy/netCDF4 flag.
requires_numbagg_or_bottleneck = pytest.mark.skipif(
    not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck"
)
# _importorskip does not work for development versions
has_pandas_version_two = Version(pd.__version__).major >= 2
requires_pandas_version_two = pytest.mark.skipif(
Expand Down
Loading
Loading