From 1fc24f82debe1aa4c90a33d058ab5b905654f234 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Tue, 17 Oct 2023 11:48:47 -0700 Subject: [PATCH] --- xarray/core/rolling_exp.py | 19 ++++++++++--------- xarray/tests/__init__.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 7c747201856..d1e613e98a2 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -4,6 +4,7 @@ from typing import Any, Generic import numpy as np +from packaging.version import Version from xarray.core.computation import apply_ufunc from xarray.core.options import _get_keep_attrs @@ -14,9 +15,9 @@ import numbagg from numbagg import move_exp_nanmean, move_exp_nansum - has_numbagg = numbagg.__version__ + has_numbagg: Version | None = Version(numbagg.__version__) except ImportError: - has_numbagg = False + has_numbagg = None def _get_alpha( @@ -99,15 +100,15 @@ def __init__( window_type: str = "span", min_weight: float = 0.0, ): - if has_numbagg is False: + if has_numbagg is None: raise ImportError( "numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed" ) - elif has_numbagg < "0.2.1": + elif has_numbagg < Version("0.2.1"): raise ImportError( f"numbagg >= 0.2.1 is required for `rolling_exp` but currently version {has_numbagg} is installed" ) - elif has_numbagg < "0.3.1" and min_weight > 0: + elif has_numbagg < Version("0.3.1") and min_weight > 0: raise ImportError( f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {has_numbagg} is installed" ) @@ -210,7 +211,7 @@ def std(self) -> T_DataWithCoords: Dimensions without coordinates: x """ - if has_numbagg is False or has_numbagg < "0.4.0": + if has_numbagg is None or has_numbagg < Version("0.4.0"): raise ImportError( f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {has_numbagg} is installed" ) @@ -242,7 +243,7 @@ def var(self) -> T_DataWithCoords: Dimensions without coordinates: x """ - if has_numbagg is False or has_numbagg < "0.4.0": + if has_numbagg is None or has_numbagg < Version("0.4.0"): raise ImportError( f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {has_numbagg} is installed" ) @@ -274,7 +275,7 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords: Dimensions without coordinates: x """ - if has_numbagg is False or has_numbagg < "0.4.0": + if has_numbagg is None or has_numbagg < Version("0.4.0"): raise ImportError( f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {has_numbagg} is installed" ) @@ -307,7 +308,7 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords: Dimensions without coordinates: x """ - if has_numbagg is False or has_numbagg < "0.4.0": + if has_numbagg is None or has_numbagg < Version("0.4.0"): raise ImportError( f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {has_numbagg} is installed" ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7e1b964ecba..07ba0be6a8c 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -75,7 +75,7 @@ def _importorskip( has_zarr, requires_zarr = _importorskip("zarr") has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") -has_numbagg, requires_numbagg = _importorskip("numbagg") +has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") has_seaborn, requires_seaborn = _importorskip("seaborn") has_sparse, requires_sparse = _importorskip("sparse") has_cupy, requires_cupy = _importorskip("cupy")