Skip to content

Commit

Permalink
No commit message
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty committed Oct 17, 2023
1 parent 892a8c8 commit 1fc24f8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
19 changes: 10 additions & 9 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Generic

import numpy as np
from packaging.version import Version

from xarray.core.computation import apply_ufunc
from xarray.core.options import _get_keep_attrs
Expand All @@ -14,9 +15,9 @@
import numbagg
from numbagg import move_exp_nanmean, move_exp_nansum

has_numbagg = numbagg.__version__
has_numbagg: Version | None = Version(numbagg.__version__)
except ImportError:
has_numbagg = False
has_numbagg = None


def _get_alpha(
Expand Down Expand Up @@ -99,15 +100,15 @@ def __init__(
window_type: str = "span",
min_weight: float = 0.0,
):
if has_numbagg is False:
if has_numbagg is None:
raise ImportError(
"numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
)
elif has_numbagg < "0.2.1":
elif has_numbagg < Version("0.2.1"):
raise ImportError(
f"numbagg >= 0.2.1 is required for `rolling_exp` but currently version {has_numbagg} is installed"
)
elif has_numbagg < "0.3.1" and min_weight > 0:
elif has_numbagg < Version("0.3.1") and min_weight > 0:
raise ImportError(
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {has_numbagg} is installed"
)
Expand Down Expand Up @@ -210,7 +211,7 @@ def std(self) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if has_numbagg is False or has_numbagg < "0.4.0":
if has_numbagg is None or has_numbagg < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {has_numbagg} is installed"
)
Expand Down Expand Up @@ -242,7 +243,7 @@ def var(self) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if has_numbagg is False or has_numbagg < "0.4.0":
if has_numbagg is None or has_numbagg < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {has_numbagg} is installed"
)
Expand Down Expand Up @@ -274,7 +275,7 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if has_numbagg is False or has_numbagg < "0.4.0":
if has_numbagg is None or has_numbagg < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {has_numbagg} is installed"
)
Expand Down Expand Up @@ -307,7 +308,7 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if has_numbagg is False or has_numbagg < "0.4.0":
if has_numbagg is None or has_numbagg < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {has_numbagg} is installed"
)
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _importorskip(
has_zarr, requires_zarr = _importorskip("zarr")
has_fsspec, requires_fsspec = _importorskip("fsspec")
has_iris, requires_iris = _importorskip("iris")
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0")
has_seaborn, requires_seaborn = _importorskip("seaborn")
has_sparse, requires_sparse = _importorskip("sparse")
has_cupy, requires_cupy = _importorskip("cupy")
Expand Down

0 comments on commit 1fc24f8

Please sign in to comment.