Skip to content

Commit

Permalink
Use xr.quantile with new v0.15.1 of xarray (#348)
Browse files Browse the repository at this point in the history
* Reimplements `xr.quantile` to replace `my_quantile` using the new `skipna` feature.
* Updates requirements to require xarray >= 0.15.1 for developers and users.
  • Loading branch information
aaronspring authored Apr 15, 2020
1 parent 8be0a89 commit 583fba7
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 86 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ Internals/Minor Fixes
(:pr:`330`) `Aaron Spring`_.
- Refactoring :py:func:`~climpred.prediction.compute_hindcast` and
:py:func:`~climpred.prediction.compute_perfect_model`. (:pr:`330`) `Aaron Spring`_.
- Changed lead0 coordinate modifications to be compliant with xarray=0.15.1 in
:py:func:`~climpred.reference.compute_persistence`. (:pr:`348`) `Aaron Spring`_.
- Exchanged my_quantile with xr.quantile(skipna=False). (:pr:`348`) `Aaron Spring`_.


Documentation
Expand Down
2 changes: 1 addition & 1 deletion ci/environment-dev-3.6.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- numpy
- pandas=0.25.3 # breaking features in pandas 1.0
- scipy
- xarray>=0.14.1 # Allows .drop_vars()
- xarray>=0.15.1
# Package Management
- asv
- black
Expand Down
40 changes: 7 additions & 33 deletions climpred/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_lead_cftime_shift_args,
get_metric_class,
lead_units_equal_control_time_stride,
rechunk_to_single_chunk_if_more_than_one_chunk_along_dim,
shift_cftime_singular,
)

Expand All @@ -50,35 +51,6 @@ def _resample(hind, resample_dim):
return smp_hind


def my_quantile(ds, q=0.95, dim='bootstrap'):
"""Compute quantile `q` faster than `xr.quantile` and allows lazy computation."""
# dim='bootstrap' doesnt lead anywhere, but want to keep xr.quantile API
# concat_dim is always first, therefore axis=0 implementation works in compute

def _dask_percentile(arr, axis=0, q=95):
"""Daskified np.percentile."""
if len(arr.chunks[axis]) > 1:
arr = arr.rechunk({axis: -1})
return dask.array.map_blocks(np.percentile, arr, axis=axis, q=q, drop_axis=axis)

def _percentile(arr, axis=0, q=95):
"""percentile function for chunked and non-chunked `arr`."""
if dask.is_dask_collection(arr):
return _dask_percentile(arr, axis=axis, q=q)
else:
return np.percentile(arr, axis=axis, q=q)

axis = 0
if not isinstance(q, list):
q = [q]
quantile = []
for qi in q:
quantile.append(ds.reduce(_percentile, q=qi * 100, axis=axis))
quantile = xr.concat(quantile, 'quantile')
quantile['quantile'] = q
return quantile.squeeze()


def _distribution_to_ci(ds, ci_low, ci_high, dim='bootstrap'):
"""Get confidence intervals from bootstrapped distribution.
Expand All @@ -93,8 +65,8 @@ def _distribution_to_ci(ds, ci_low, ci_high, dim='bootstrap'):
Returns:
uninit_hind (xarray object): uninitialize hindcast with hind.coords.
"""
# TODO: re-implement xr.quantile once fast
return my_quantile(ds, q=[ci_low, ci_high], dim='bootstrap')
ds = rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(ds, dim)
return ds.quantile(q=[ci_low, ci_high], dim=dim, skipna=False)


def _pvalue_from_distributions(simple_fct, init, metric=None):
Expand Down Expand Up @@ -325,8 +297,10 @@ def _bootstrap_func(
smp_init_pm = _resample(ds, resample_dim)
bootstraped_results.append(func(smp_init_pm, *func_args, **func_kwargs))
sig_level = xr.concat(bootstraped_results, dim='bootstrap', **CONCAT_KWARGS)
# TODO: reimplement xr.quantile once fast
sig_level = my_quantile(sig_level, dim='bootstrap', q=psig)
sig_level = rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(
sig_level, dim='bootstrap'
)
sig_level = sig_level.quantile(dim='bootstrap', q=psig, skipna=False)
return sig_level


Expand Down
3 changes: 2 additions & 1 deletion climpred/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def compute_persistence(
# at lead 0 is == 1.
if [0] in hind.lead.values:
hind = hind.copy()
hind['lead'] += 1
with xr.set_options(keep_attrs=True): # keeps lead.attrs['units']
hind['lead'] = hind['lead'] + 1
n, freq = get_lead_cftime_shift_args(hind.lead.attrs['units'], 1)
# Shift backwards shift for lead zero.
hind['init'] = shift_cftime_index(hind, 'init', -1 * n, freq)
Expand Down
8 changes: 4 additions & 4 deletions climpred/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def PM_da_initialized_1d_lead0(PM_da_initialized_1d):
framework."""
da = PM_da_initialized_1d
# Convert to lead zero for testing
da['lead'] -= 1
da['init'] += 1
da['lead'] = da['lead'] - 1
da['init'] = da['init'] + 1
return da


Expand Down Expand Up @@ -132,8 +132,8 @@ def hind_ds_initialized_1d_lead0(hind_ds_initialized_1d):
framework."""
da = hind_ds_initialized_1d
# Change to a lead-0 framework
da['init'] += 1
da['lead'] -= 1
da['init'] = da['init'] + 1
da['lead'] = da['lead'] - 1
return da


Expand Down
49 changes: 3 additions & 46 deletions climpred/tests/test_bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import time

import dask
import numpy as np
import pytest
import xarray as xr
from xarray.testing import assert_allclose

from climpred.bootstrap import (
_bootstrap_by_stacking,
bootstrap_hindcast,
bootstrap_perfect_model,
bootstrap_uninit_pm_ensemble_from_control_cftime,
my_quantile,
)
from climpred.comparisons import HINDCAST_COMPARISONS, PM_COMPARISONS
from climpred.constants import CONCAT_KWARGS, VALID_ALIGNMENTS
Expand All @@ -21,45 +17,6 @@
BOOTSTRAP = 2


@pytest.mark.skip(reason='xr.quantile now faster')
@pytest.mark.parametrize('chunk', [True, False])
def test_dask_percentile_implemented_faster_xr_quantile(PM_da_control_3d, chunk):
"""Test my_quantile faster than xr.quantile.
TODO: Remove after xr=0.15.1 and add skipna=False.
"""
chunk_dim, dim = 'x', 'time'
if chunk:
chunks = {chunk_dim: 24}
PM_da_control_3d = PM_da_control_3d.chunk(chunks).persist()
start_time = time.time()
actual = my_quantile(PM_da_control_3d, q=0.95, dim=dim)
elapsed_time_my_quantile = time.time() - start_time

start_time = time.time()
expected = PM_da_control_3d.compute().quantile(q=0.95, dim=dim)
elapsed_time_xr_quantile = time.time() - start_time
if chunk:
assert dask.is_dask_collection(actual)
assert not dask.is_dask_collection(expected)
start_time = time.time()
actual = actual.compute()
elapsed_time_my_quantile = elapsed_time_my_quantile + time.time() - start_time
else:
assert not dask.is_dask_collection(actual)
assert not dask.is_dask_collection(expected)
assert actual.shape == expected.shape
assert_allclose(actual, expected)
print(
elapsed_time_my_quantile,
elapsed_time_xr_quantile,
'my_quantile is',
elapsed_time_xr_quantile / elapsed_time_my_quantile,
'times faster than xr.quantile',
)
assert elapsed_time_xr_quantile > elapsed_time_my_quantile


@pytest.mark.slow
@pytest.mark.parametrize('comparison', PM_COMPARISONS)
@pytest.mark.parametrize('chunk', [True, False])
Expand Down Expand Up @@ -118,7 +75,7 @@ def test_bootstrap_hindcast_lazy(
@pytest.mark.slow
@pytest.mark.parametrize('resample_dim', ['member', 'init'])
def test_bootstrap_hindcast_resample_dim(
hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, resample_dim
hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, resample_dim,
):
"""Test bootstrap_hindcast when resampling member or init and alignment
same_inits."""
Expand Down Expand Up @@ -291,7 +248,7 @@ def test_bootstrap_by_stacking_chunked(
PM_ds_initialized_1d_ym_cftime, PM_ds_control_1d_ym_cftime
):
res_chunked = _bootstrap_by_stacking(
PM_ds_initialized_1d_ym_cftime.chunk(), PM_ds_control_1d_ym_cftime.chunk()
PM_ds_initialized_1d_ym_cftime.chunk(), PM_ds_control_1d_ym_cftime.chunk(),
)
assert dask.is_dask_collection(res_chunked)
res_chunked = res_chunked.compute()
Expand Down Expand Up @@ -327,7 +284,7 @@ def test_bootstrap_by_stacking_two_var_dataset(
@pytest.mark.slow
@pytest.mark.parametrize('alignment', VALID_ALIGNMENTS)
def test_bootstrap_hindcast_alignment(
hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, alignment
hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, alignment,
):
"""Test bootstrap_hindcast for all alginments when resampling member."""
bootstrap_hindcast(
Expand Down
12 changes: 12 additions & 0 deletions climpred/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,18 @@ def _load_into_memory(res):
return res


def rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(ds, dim):
"""Rechunk an xarray object more than one chunk along dim."""
if dask.is_dask_collection(ds):
if isinstance(ds, xr.Dataset):
nchunks = len(ds.chunks[dim])
elif isinstance(ds, xr.DataArray):
nchunks = len(ds.chunks[ds.get_axis_num(dim)])
if nchunks > 1:
ds = ds.chunk({dim: -1})
return ds


def shift_cftime_index(xobj, time_string, n, freq):
"""Shifts a ``CFTimeIndex`` over a specified number of time steps at a given
temporal frequency.
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ numpy
pandas
scipy
tqdm
xarray>=0.14.1
xarray>=0.15.1
xrft
xskillscore>=0.0.12

0 comments on commit 583fba7

Please sign in to comment.