Skip to content

Commit

Permalink
🧹 reduce dependencies (#572)
Browse files Browse the repository at this point in the history
* Update requirements.txt

* make eofs not requirement

* Update climpred-dev.yml

* faster rm_poly

* rm xrft from requirements

* blacken notebooks

* rm tqdm from ci/requirements/docs_notebooks.yml
  • Loading branch information
aaronspring authored Mar 10, 2021
1 parent 73566ec commit 98ea010
Show file tree
Hide file tree
Showing 22 changed files with 5,603 additions and 2,606 deletions.
4 changes: 2 additions & 2 deletions ci/requirements/climpred-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies:
# Miscellaneous
- lxml
- tqdm
- cftime>=1.1.2 # Optimization changes to make cftime fast.
# Numerics
- numpy
- pandas
Expand Down Expand Up @@ -49,13 +50,12 @@ dependencies:
# Statistics
- eofs
- esmtools>=1.1.3
- xrft=0.2.0 # New version not compatible with xarray 0.16.1
- xskillscore>=0.0.18
# Visualization
- matplotlib
- nc-time-axis
- pip
- pip:
- cftime>=1.1.2 # Optimization changes to make cftime fast.
- pytest-tldr
- pytest-lazy-fixture
- nb_black # notebook linting
1 change: 0 additions & 1 deletion ci/requirements/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ dependencies:
- python=3.8
- bottleneck
- cartopy>=0.18.0
- esmtools>=1.1.3
- importlib_metadata
- jupyterlab
- matplotlib
Expand Down
6 changes: 3 additions & 3 deletions ci/requirements/docs_notebooks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ dependencies:
- xesmf
- esmpy
- bottleneck
- cartopy>=0.18.0
- esmtools>=1.1.3
- importlib_metadata
- jupyterlab
- matplotlib
Expand All @@ -22,9 +20,11 @@ dependencies:
- sphinxcontrib-napoleon
- sphinx_rtd_theme
- toolz
- tqdm
- xarray>=0.16.1
- xrft # used for varweighted_mean_period
- esmtools>=1.1.3
- pip
- pip:
# Install latest version of climpred.
- -e ../..
- nb_black
1 change: 0 additions & 1 deletion ci/requirements/minimum-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- dask
- eofs
- esmpy=*=mpi* # Ensures MPI works with version of esmpy.
- esmtools>=1.1.3
- ipython
- matplotlib
- nc-time-axis
Expand Down
12 changes: 11 additions & 1 deletion climpred/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from .metrics import ALL_METRICS, METRIC_ALIASES
from .prediction import compute_hindcast, compute_perfect_model
from .reference import compute_climatology, compute_persistence
from .stats import dpp, varweighted_mean_period
from .stats import dpp

try:
from .stats import varweighted_mean_period
except ImportError:
varweighted_mean_period = None
from .utils import (
_transpose_and_rechunk_to,
assign_attrs,
Expand Down Expand Up @@ -1261,6 +1266,11 @@ def varweighted_mean_period_threshold(control, sig=95, iterations=500, time_dim=
* climpred.bootstrap._bootstrap_func
* climpred.stats.varweighted_mean_period
"""
if varweighted_mean_period is None:
raise ImportError(
"xrft is not installed; see"
"https://xrft.readthedocs.io/en/latest/installation.html"
)
return _bootstrap_func(
varweighted_mean_period,
control,
Expand Down
11 changes: 10 additions & 1 deletion climpred/relative_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import numpy as np
import xarray as xr
from eofs.xarray import Eof

try:
from eofs.xarray import Eof
except ImportError:
Eof = None


def _relative_entropy_formula(sigma_b, sigma_x, mu_x, mu_b, neofs):
Expand Down Expand Up @@ -110,6 +114,11 @@ def compute_relative_entropy(
Returns:
rel_ent (xr.Dataset): relative entropy
"""
if Eof is None:
raise ImportError(
"eofs is not installed; see"
"https://ajdawson.github.io/eofs/latest/index.html"
)
# Defaults
if neofs is None:
neofs = initialized.member.size
Expand Down
2 changes: 1 addition & 1 deletion climpred/smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def spatial_smoothing_xesmf(
if xe is None:
raise ImportError(
"xesmf is not installed; see"
"https://xesmf.readthedocs.io/en/latest/installation.html"
"https://pangeo-xesmf.readthedocs.io/en/latest/installation.html"
)

def _regrid_it(da, d_lon, d_lat, **kwargs):
Expand Down
54 changes: 48 additions & 6 deletions climpred/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,49 @@

import numpy as np
import xarray as xr
from esmtools.stats import corr
from xrft import power_spectrum
from xskillscore import pearson_r

try:
from xrft import power_spectrum
except ImportError:
power_spectrum = None
from .checks import is_xarray


def rm_poly(ds, dim="time", deg=2, **kwargs):
    """Remove a polynomial fit of degree ``deg`` along ``dim`` from ``ds``.

    The polynomial is fitted with ``ds.polyfit`` and evaluated on the ``dim``
    coordinate of ``ds`` with ``xr.polyval``; the evaluated fit is then
    subtracted from ``ds``.

    Args:
        ds (xr.Dataset or xr.DataArray): input data.
        dim (str): dimension along which to fit. Defaults to ``"time"``.
        deg (int): degree of the polynomial to remove. Defaults to 2.
        **kwargs: forwarded to ``ds.polyfit``.

    Returns:
        xr.Dataset or xr.DataArray: ``ds`` minus the evaluated fit.

    Raises:
        TypeError: if ``ds`` provides ``polyfit`` but is neither an
            ``xr.Dataset`` nor an ``xr.DataArray``.
    """
    coeff_name = "polyfit_coefficients"
    coefficients = ds.polyfit(dim, deg=deg, **kwargs)
    coord = ds[dim]
    if isinstance(ds, xr.Dataset):
        # Dataset.polyfit names its outputs "<var>_polyfit_coefficients".
        # Only those are evaluated: extra outputs requested through kwargs
        # (e.g. full=True or cov=True) are not polynomial coefficients.
        suffix = "_" + coeff_name
        fits = xr.merge(
            [
                # Strip the suffix only, so the fit gets the source
                # variable's name back even if the marker also occurs
                # elsewhere in that name.
                xr.polyval(coord, coefficients[v]).rename(v[: -len(suffix)])
                for v in coefficients.data_vars
                if v.endswith(suffix)
            ]
        )
    elif isinstance(ds, xr.DataArray):
        # DataArray.polyfit stores the coefficients under a fixed name;
        # address it by name rather than by position among the outputs.
        fits = xr.polyval(coord, coefficients[coeff_name]).rename(ds.name)
    else:
        raise TypeError(
            "rm_poly requires xr.Dataset or xr.DataArray, found "
            + type(ds).__name__
        )
    return ds - fits


def rm_trend(ds, dim="time", **kwargs):
    """Remove the linear trend along dimension ``dim`` from ``ds``.

    Thin wrapper around ``rm_poly`` with the degree fixed to 1.

    Args:
        ds (xarray object): input data.
        dim (str): dimension along which to detrend. Defaults to ``"time"``.
        **kwargs: forwarded to ``rm_poly``.

    Returns:
        xarray object: result of ``rm_poly(ds, dim=dim, deg=1, **kwargs)``.
    """
    return rm_poly(ds, dim=dim, deg=1, **kwargs)


@is_xarray(0)
def decorrelation_time(da, r=20, dim="time"):
def decorrelation_time(da, iterations=20, dim="time"):
"""Calculate the decorrelation time of a time series.
.. math::
\\tau_{d} = 1 + 2 * \\sum_{k=1}^{r}(\\alpha_{k})^{k}
Args:
da (xarray object): Time series.
r (optional int): Number of iterations to run the above formula.
da (xarray object): input.
iterations (optional int): Number of iterations to run the above formula.
dim (optional str): Time dimension for xarray object.
Returns:
Expand All @@ -29,10 +56,20 @@ def decorrelation_time(da, r=20, dim="time"):
p.373
"""

def _lag_corr(x, y, dim, lead):
    """Correlate ``x`` with ``y`` shifted by ``lead`` steps along ``dim``.

    ``x`` is truncated to its first ``size - lead`` entries and ``y`` to its
    last ``size - lead`` entries, where ``size`` is the length of ``x`` along
    ``dim``; the two are then passed to ``pearson_r``.
    """
    length = x[dim].size
    leading = x.isel({dim: slice(0, length - lead)})
    lagged = y.isel({dim: slice(0 + lead, length)})
    # Give both pieces the same ``dim`` coordinate so the xarray operation
    # lines them up by position instead of by their original labels.
    lagged[dim] = leading[dim]
    return pearson_r(leading, lagged, dim)

one = xr.ones_like(da.isel({dim: 0}))
one = one.where(da.isel({dim: 0}).notnull())
return one + 2 * xr.concat(
[corr(da, da, dim=dim, lead=i) ** i for i in range(1, r)], "it"
[_lag_corr(da, da, dim=dim, lead=i) ** i for i in range(1, iterations)], "it"
).sum("it")


Expand Down Expand Up @@ -148,6 +185,11 @@ def varweighted_mean_period(da, dim="time", **kwargs):
See also:
https://xrft.readthedocs.io/en/latest/api.html#xrft.xrft.power_spectrum
"""
if power_spectrum is None:
raise ImportError(
"xrft is not installed; see"
"https://xrft.readthedocs.io/en/latest/installation.html"
)
if isinstance(da, xr.Dataset):
raise ValueError("require xr.DataArray, try xr.Dataset.map(func)")
da = da.fillna(0.0)
Expand Down
7 changes: 0 additions & 7 deletions climpred/tests/test_hindcast_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,11 @@
from climpred.bootstrap import bootstrap_hindcast
from climpred.comparisons import HINDCAST_COMPARISONS
from climpred.constants import CLIMPRED_DIMS
from climpred.metrics import DETERMINISTIC_HINDCAST_METRICS
from climpred.prediction import compute_hindcast
from climpred.reference import compute_persistence

# uacc is sqrt(MSSS), fails when MSSS negative
DETERMINISTIC_HINDCAST_METRICS = DETERMINISTIC_HINDCAST_METRICS.copy()
DETERMINISTIC_HINDCAST_METRICS.remove("uacc")

ITERATIONS = 2

category_edges = np.array([0, 0.5, 1])


def test_compute_hindcast_lead0_lead1(
hind_ds_initialized_1d, hind_ds_initialized_1d_lead0, reconstruction_ds_1d
Expand Down
12 changes: 6 additions & 6 deletions climpred/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import xarray as xr
from esmtools.stats import rm_poly

from climpred.stats import rm_poly
from climpred.testing import assert_PredictionEnsemble

xr.set_options(display_style="text")
Expand All @@ -16,7 +16,7 @@ def test_PredictionEnsemble_raises_error(hindcast_hist_obs_1d):
def test_PredictionEnsemble_raises_warning(hindcast_hist_obs_1d):
"""Tests that PredictionEnsemble raises warning."""
with pytest.warns(UserWarning):
hindcast_hist_obs_1d.map(rm_poly, dim="init", order=2)
hindcast_hist_obs_1d.map(rm_poly, dim="init", deg=2)


def test_PredictionEnsemble_xr_calls(hindcast_hist_obs_1d):
Expand All @@ -31,12 +31,12 @@ def test_PredictionEnsemble_xr_calls(hindcast_hist_obs_1d):
def test_PredictionEnsemble_map_dim_or(hindcast_hist_obs_1d):
"""Tests that PredictionEnsemble allows dim0_or_dim1 as kwargs without UserWarning."""
with pytest.warns(None): # no warnings
he_or = hindcast_hist_obs_1d.map(rm_poly, dim="init_or_time", order=2)
he_or = hindcast_hist_obs_1d.map(rm_poly, dim="init_or_time", deg=2)
assert he_or != hindcast_hist_obs_1d

with pytest.warns(UserWarning) as record: # triggers warnings
he_chained = hindcast_hist_obs_1d.map(rm_poly, dim="init", order=2).map(
rm_poly, dim="time", order=2
he_chained = hindcast_hist_obs_1d.map(rm_poly, dim="init", deg=2).map(
rm_poly, dim="time", deg=2
)
assert he_chained != hindcast_hist_obs_1d

Expand All @@ -52,4 +52,4 @@ def test_PredictionEnsemble_map_dim_or_fails_if_both_dims_in_dataset(
):
"""Tests that PredictionEnsemble with dim0_or_dim1 as kwargs fails if both dims in any dataset."""
with pytest.raises(ValueError, match="cannot be both in"):
hindcast_hist_obs_1d.map(rm_poly, dim="init_or_lead", order=2)
hindcast_hist_obs_1d.map(rm_poly, dim="init_or_lead", deg=2)
12 changes: 12 additions & 0 deletions climpred/tests/test_relative_entropy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import pytest

from climpred.graphics import plot_relative_entropy
from climpred.relative_entropy import (
bootstrap_relative_entropy,
compute_relative_entropy,
)

try:
from eofs.xarray import Eof

Eof_loaded = True
except ImportError:
Eof_loaded = False


@pytest.mark.skipif(not Eof_loaded, reason="eofs not installed")
def test_compute_relative_entropy(PM_da_initialized_3d, PM_da_control_3d):
"""
Checks that there are no NaNs.
Expand All @@ -17,6 +27,7 @@ def test_compute_relative_entropy(PM_da_initialized_3d, PM_da_control_3d):
assert not actual_any_nan[var]


@pytest.mark.skipif(not Eof_loaded, reason="eofs not installed")
def test_bootstrap_relative_entropy(PM_da_initialized_3d, PM_da_control_3d):
"""
Checks that there are no NaNs.
Expand All @@ -33,6 +44,7 @@ def test_bootstrap_relative_entropy(PM_da_initialized_3d, PM_da_control_3d):
assert not actual_any_nan[var]


@pytest.mark.skipif(not Eof_loaded, reason="eofs not installed")
def test_plot_relative_entropy(PM_da_initialized_3d, PM_da_control_3d):
res = compute_relative_entropy(
PM_da_initialized_3d, PM_da_control_3d, nmember_control=5, neofs=2
Expand Down
15 changes: 13 additions & 2 deletions climpred/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@
import xarray as xr
from xarray.testing import assert_allclose

from climpred.bootstrap import dpp_threshold, varweighted_mean_period_threshold
from climpred.stats import decorrelation_time, dpp, varweighted_mean_period
from climpred.bootstrap import dpp_threshold
from climpred.stats import decorrelation_time, dpp

try:
from climpred.bootstrap import varweighted_mean_period_threshold
from climpred.stats import varweighted_mean_period

xrft_loaded = True
except ImportError:
xrft_loaded = False

ITERATIONS = 2

Expand All @@ -16,6 +24,7 @@ def test_dpp(PM_da_control_3d, chunk):
assert res.mean() > 0


@pytest.mark.skipif(not xrft_loaded, reason="xrft not installed")
@pytest.mark.parametrize("func", (varweighted_mean_period, decorrelation_time))
def test_potential_predictability_likely(PM_da_control_3d, func):
"""Check for positive diagnostic potential predictability in NA SST."""
Expand All @@ -32,6 +41,7 @@ def test_bootstrap_dpp_sig50_similar_dpp(PM_da_control_3d):
xr.testing.assert_allclose(actual, expected, atol=0.5, rtol=0.5)


@pytest.mark.skipif(not xrft_loaded, reason="xrft not installed")
def test_bootstrap_vwmp_sig50_similar_vwmp(PM_da_control_3d):
sig = 50
actual = varweighted_mean_period_threshold(
Expand All @@ -48,6 +58,7 @@ def test_bootstrap_func_multiple_sig_levels(PM_da_control_3d):
assert (actual.isel(quantile=0).values <= actual.isel(quantile=1)).all()


@pytest.mark.skipif(not xrft_loaded, reason="xrft not installed")
@pytest.mark.parametrize("step", [1, 2, -1])
@pytest.mark.parametrize(
"func",
Expand Down
Loading

0 comments on commit 98ea010

Please sign in to comment.