Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make fit for xr=0.15.1 #348

Merged
merged 6 commits into from
Apr 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ Internals/Minor Fixes
(:pr:`330`) `Aaron Spring`_.
- Refactoring :py:func:`~climpred.prediction.compute_hindcast` and
:py:func:`~climpred.prediction.compute_perfect_model`. (:pr:`330`) `Aaron Spring`_.
- Changed lead0 coordinate modifications to be compliant with xarray=0.15.1 in
:py:func:`~climpred.reference.compute_persistence`. (:pr:`348`) `Aaron Spring`_.
- Exchanged my_quantile with xr.quantile(skipna=False). (:pr:`348`) `Aaron Spring`_.


Documentation
Expand Down
2 changes: 1 addition & 1 deletion ci/environment-dev-3.6.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- numpy
- pandas=0.25.3 # breaking features in pandas 1.0
- scipy
- xarray>=0.14.1 # Allows .drop_vars()
- xarray>=0.15.1
# Package Management
- asv
- black
Expand Down
40 changes: 7 additions & 33 deletions climpred/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_lead_cftime_shift_args,
get_metric_class,
lead_units_equal_control_time_stride,
rechunk_to_single_chunk_if_more_than_one_chunk_along_dim,
shift_cftime_singular,
)

Expand All @@ -50,35 +51,6 @@ def _resample(hind, resample_dim):
return smp_hind


def my_quantile(ds, q=0.95, dim='bootstrap'):
"""Compute quantile `q` faster than `xr.quantile` and allows lazy computation."""
# dim='bootstrap' doesnt lead anywhere, but want to keep xr.quantile API
# concat_dim is always first, therefore axis=0 implementation works in compute

def _dask_percentile(arr, axis=0, q=95):
"""Daskified np.percentile."""
if len(arr.chunks[axis]) > 1:
arr = arr.rechunk({axis: -1})
return dask.array.map_blocks(np.percentile, arr, axis=axis, q=q, drop_axis=axis)

def _percentile(arr, axis=0, q=95):
"""percentile function for chunked and non-chunked `arr`."""
if dask.is_dask_collection(arr):
return _dask_percentile(arr, axis=axis, q=q)
else:
return np.percentile(arr, axis=axis, q=q)

axis = 0
if not isinstance(q, list):
q = [q]
quantile = []
for qi in q:
quantile.append(ds.reduce(_percentile, q=qi * 100, axis=axis))
quantile = xr.concat(quantile, 'quantile')
quantile['quantile'] = q
return quantile.squeeze()


def _distribution_to_ci(ds, ci_low, ci_high, dim='bootstrap'):
"""Get confidence intervals from bootstrapped distribution.

Expand All @@ -93,8 +65,8 @@ def _distribution_to_ci(ds, ci_low, ci_high, dim='bootstrap'):
Returns:
uninit_hind (xarray object): uninitialize hindcast with hind.coords.
"""
# TODO: re-implement xr.quantile once fast
return my_quantile(ds, q=[ci_low, ci_high], dim='bootstrap')
ds = rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(ds, dim)
return ds.quantile(q=[ci_low, ci_high], dim=dim, skipna=False)


def _pvalue_from_distributions(simple_fct, init, metric=None):
Expand Down Expand Up @@ -325,8 +297,10 @@ def _bootstrap_func(
smp_init_pm = _resample(ds, resample_dim)
bootstraped_results.append(func(smp_init_pm, *func_args, **func_kwargs))
sig_level = xr.concat(bootstraped_results, dim='bootstrap', **CONCAT_KWARGS)
# TODO: reimplement xr.quantile once fast
sig_level = my_quantile(sig_level, dim='bootstrap', q=psig)
sig_level = rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(
sig_level, dim='bootstrap'
)
sig_level = sig_level.quantile(dim='bootstrap', q=psig, skipna=False)
return sig_level


Expand Down
3 changes: 2 additions & 1 deletion climpred/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def compute_persistence(
# at lead 0 is == 1.
if [0] in hind.lead.values:
hind = hind.copy()
hind['lead'] += 1
with xr.set_options(keep_attrs=True): # keeps lead.attrs['units']
hind['lead'] = hind['lead'] + 1
n, freq = get_lead_cftime_shift_args(hind.lead.attrs['units'], 1)
# Shift backwards shift for lead zero.
hind['init'] = shift_cftime_index(hind, 'init', -1 * n, freq)
Expand Down
8 changes: 4 additions & 4 deletions climpred/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def PM_da_initialized_1d_lead0(PM_da_initialized_1d):
framework."""
da = PM_da_initialized_1d
# Convert to lead zero for testing
da['lead'] -= 1
da['init'] += 1
da['lead'] = da['lead'] - 1
da['init'] = da['init'] + 1
return da


Expand Down Expand Up @@ -132,8 +132,8 @@ def hind_ds_initialized_1d_lead0(hind_ds_initialized_1d):
framework."""
da = hind_ds_initialized_1d
# Change to a lead-0 framework
da['init'] += 1
da['lead'] -= 1
da['init'] = da['init'] + 1
da['lead'] = da['lead'] - 1
return da


Expand Down
49 changes: 3 additions & 46 deletions climpred/tests/test_bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import time

import dask
import numpy as np
import pytest
import xarray as xr
from xarray.testing import assert_allclose

from climpred.bootstrap import (
_bootstrap_by_stacking,
bootstrap_hindcast,
bootstrap_perfect_model,
bootstrap_uninit_pm_ensemble_from_control_cftime,
my_quantile,
)
from climpred.comparisons import HINDCAST_COMPARISONS, PM_COMPARISONS
from climpred.constants import CONCAT_KWARGS, VALID_ALIGNMENTS
Expand All @@ -21,45 +17,6 @@
BOOTSTRAP = 2


@pytest.mark.skip(reason='xr.quantile now faster')
@pytest.mark.parametrize('chunk', [True, False])
def test_dask_percentile_implemented_faster_xr_quantile(PM_da_control_3d, chunk):
"""Test my_quantile faster than xr.quantile.

TODO: Remove after xr=0.15.1 and add skipna=False.
"""
chunk_dim, dim = 'x', 'time'
if chunk:
chunks = {chunk_dim: 24}
PM_da_control_3d = PM_da_control_3d.chunk(chunks).persist()
start_time = time.time()
actual = my_quantile(PM_da_control_3d, q=0.95, dim=dim)
elapsed_time_my_quantile = time.time() - start_time

start_time = time.time()
expected = PM_da_control_3d.compute().quantile(q=0.95, dim=dim)
elapsed_time_xr_quantile = time.time() - start_time
if chunk:
assert dask.is_dask_collection(actual)
assert not dask.is_dask_collection(expected)
start_time = time.time()
actual = actual.compute()
elapsed_time_my_quantile = elapsed_time_my_quantile + time.time() - start_time
else:
assert not dask.is_dask_collection(actual)
assert not dask.is_dask_collection(expected)
assert actual.shape == expected.shape
assert_allclose(actual, expected)
print(
elapsed_time_my_quantile,
elapsed_time_xr_quantile,
'my_quantile is',
elapsed_time_xr_quantile / elapsed_time_my_quantile,
'times faster than xr.quantile',
)
assert elapsed_time_xr_quantile > elapsed_time_my_quantile


@pytest.mark.slow
@pytest.mark.parametrize('comparison', PM_COMPARISONS)
@pytest.mark.parametrize('chunk', [True, False])
Expand Down Expand Up @@ -118,7 +75,7 @@ def test_bootstrap_hindcast_lazy(
@pytest.mark.slow
@pytest.mark.parametrize('resample_dim', ['member', 'init'])
def test_bootstrap_hindcast_resample_dim(
hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, resample_dim
hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, resample_dim,
):
"""Test bootstrap_hindcast when resampling member or init and alignment
same_inits."""
Expand Down Expand Up @@ -291,7 +248,7 @@ def test_bootstrap_by_stacking_chunked(
PM_ds_initialized_1d_ym_cftime, PM_ds_control_1d_ym_cftime
):
res_chunked = _bootstrap_by_stacking(
PM_ds_initialized_1d_ym_cftime.chunk(), PM_ds_control_1d_ym_cftime.chunk()
PM_ds_initialized_1d_ym_cftime.chunk(), PM_ds_control_1d_ym_cftime.chunk(),
)
assert dask.is_dask_collection(res_chunked)
res_chunked = res_chunked.compute()
Expand Down Expand Up @@ -327,7 +284,7 @@ def test_bootstrap_by_stacking_two_var_dataset(
@pytest.mark.slow
@pytest.mark.parametrize('alignment', VALID_ALIGNMENTS)
def test_bootstrap_hindcast_alignment(
hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, alignment
hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, alignment,
):
"""Test bootstrap_hindcast for all alginments when resampling member."""
bootstrap_hindcast(
Expand Down
12 changes: 12 additions & 0 deletions climpred/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,18 @@ def _load_into_memory(res):
return res


def rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(ds, dim):
"""Rechunk an xarray object more than one chunk along dim."""
if dask.is_dask_collection(ds):
if isinstance(ds, xr.Dataset):
nchunks = len(ds.chunks[dim])
elif isinstance(ds, xr.DataArray):
nchunks = len(ds.chunks[ds.get_axis_num(dim)])
if nchunks > 1:
ds = ds.chunk({dim: -1})
return ds


def shift_cftime_index(xobj, time_string, n, freq):
"""Shifts a ``CFTimeIndex`` over a specified number of time steps at a given
temporal frequency.
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ numpy
pandas
scipy
tqdm
xarray>=0.14.1
xarray>=0.15.1
xrft
xskillscore>=0.0.12