Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resample idx #354

Merged
merged 11 commits into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ Internals/Minor Fixes
- Remove ``sig`` from
:py:func:`~climpred.graphics.plot_bootstrapped_skill_over_leadyear`.
(:pr:`351`) `Aaron Spring`_.

- Faster bootstrapping without replacement, as used in the threshold functions of
``climpred.stats``. (:pr:`354`) `Aaron Spring`_.


Documentation
Expand Down
51 changes: 43 additions & 8 deletions climpred/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,44 @@ def _resample(hind, resample_dim):
return smp_hind


def _resample_idx(init, bootstrap, dim='member', replace=True):
    """Resample ``init`` ``bootstrap`` times along ``dim`` in one vectorized call.

    Args:
        init: xarray object to resample; must contain the dimension ``dim``.
        bootstrap (int): Number of resampling iterations.
        dim (str): Dimension to draw indices along. Defaults to ``'member'``.
        replace (bool): If True, draw indices with replacement; otherwise every
            iteration uses an independently shuffled permutation of the indices
            along ``dim``.

    Returns:
        Output of ``xr.apply_ufunc`` applied to ``init`` and the index array,
        which carries an additional ``bootstrap`` dimension.
    """

    def xr_broadcast(x, idx):
        # Index the leading axis of ``x`` with the (bootstrap, dim) index table,
        # then move the resulting leading bootstrap axis to the end.
        # NOTE(review): ``squeeze`` drops every length-1 axis; presumably this is
        # why a single iteration (bootstrap=1) fails -- confirm.
        return np.moveaxis(x.squeeze()[idx.squeeze().transpose()], 0, -1)

    size = init[dim].size

    # A 0 among the labels of ``dim`` led to errors, so relabel them 1..size.
    # ``assign_coords`` returns a new object, so the caller's input is no longer
    # modified in place.
    if 0 in init[dim].values:
        init = init.assign_coords({dim: np.arange(1, 1 + size)})

    if replace:
        idx = np.random.randint(0, size, (bootstrap, size))
    else:
        # One row of 0..size-1 per iteration, each row shuffled independently.
        idx = np.tile(np.arange(size), (bootstrap, 1))
        for row in idx:
            np.random.shuffle(row)

    idx_da = xr.DataArray(
        idx,
        dims=('bootstrap', dim),
        coords={'bootstrap': range(bootstrap), dim: init[dim]},
    )

    return xr.apply_ufunc(
        xr_broadcast,
        init.transpose(dim, ...),
        idx_da,
        dask='parallelized',
        output_dtypes=[float],
    )


def _distribution_to_ci(ds, ci_low, ci_high, dim='bootstrap'):
"""Get confidence intervals from bootstrapped distribution.

Expand Down Expand Up @@ -292,15 +330,12 @@ def _bootstrap_func(
else:
psig = sig / 100

bootstraped_results = []
for _ in range(bootstrap):
smp_init_pm = _resample(ds, resample_dim)
bootstraped_results.append(func(smp_init_pm, *func_args, **func_kwargs))
sig_level = xr.concat(bootstraped_results, dim='bootstrap', **CONCAT_KWARGS)
sig_level = rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(
sig_level, dim='bootstrap'
bootstraped_ds = _resample_idx(ds, bootstrap, dim=resample_dim, replace=False)
bootstraped_results = func(bootstraped_ds, *func_args, **func_kwargs)
bootstraped_results = rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(
bootstraped_results, dim='bootstrap'
)
sig_level = sig_level.quantile(dim='bootstrap', q=psig, skipna=False)
sig_level = bootstraped_results.quantile(dim='bootstrap', q=psig, skipna=False)
return sig_level


Expand Down
24 changes: 24 additions & 0 deletions climpred/tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from climpred.bootstrap import (
_bootstrap_by_stacking,
_resample,
_resample_idx,
bootstrap_hindcast,
bootstrap_perfect_model,
bootstrap_uninit_pm_ensemble_from_control_cftime,
Expand Down Expand Up @@ -331,3 +333,25 @@ def test_bootstrap_hindcast_raises_error(
resample_dim='init',
alignment='same_verifs',
)


def test_resample_1_size(PM_da_initialized_1d):
    """A single ``_resample_idx`` draw has the same size as ``_resample`` output."""
    dim = 'member'
    reference = _resample(PM_da_initialized_1d, resample_dim=dim)
    # Two iterations are requested and only the first one is compared, because
    # requesting a single iteration (bootstrap=1) somehow fails.
    resampled = _resample_idx(PM_da_initialized_1d, 2, dim=dim)
    first_draw = resampled.isel(bootstrap=0)
    assert reference.size == first_draw.size
    assert reference[dim].size == first_draw[dim].size


def test_resample_size(PM_da_initialized_1d):
    """Check ``_resample_idx`` output size against stacked ``_resample`` calls."""
    dim = 'member'
    # Reference: BOOTSTRAP independent ``_resample`` draws concatenated along a
    # new ``bootstrap`` dimension.
    expected = xr.concat(
        [_resample(PM_da_initialized_1d, resample_dim=dim) for _ in range(BOOTSTRAP)],
        'bootstrap',
    )
    actual = _resample_idx(PM_da_initialized_1d, BOOTSTRAP, dim=dim)
    assert expected.size == actual.size
    assert expected[dim].size == actual[dim].size
28 changes: 12 additions & 16 deletions climpred/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
varweighted_mean_period,
)

# Number of bootstrap iterations passed as ``bootstrap=`` by the threshold tests.
BOOTSTRAP = 5


def test_rm_trend_missing_dim():
with pytest.raises(DimensionError) as excinfo:
Expand Down Expand Up @@ -86,8 +88,7 @@ def test_rm_trend_3d_dataset_dim_order(multi_dim_ds):
@pytest.mark.parametrize('chunk', (True, False))
def test_dpp(PM_da_control_3d, chunk):
    """Check for positive diagnostic potential predictability in NA SST."""
    res = dpp(PM_da_control_3d, chunk=chunk)
    assert res.mean() > 0


Expand All @@ -96,9 +97,8 @@ def test_dpp(PM_da_control_3d, chunk):
)
def test_potential_predictability_likely(PM_da_control_3d, func):
    """Check for positive diagnostic potential predictability in NA SST."""
    res = func(PM_da_control_3d)
    assert res.mean() > 0


Expand All @@ -120,30 +120,26 @@ def test_corr(PM_da_control_3d):


def test_bootstrap_dpp_sig50_similar_dpp(PM_da_control_3d):
    """Check that the ``sig=50`` ``dpp_threshold`` is close to plain ``dpp``."""
    sig = 50
    actual = dpp_threshold(PM_da_control_3d, bootstrap=BOOTSTRAP, sig=sig).drop_vars(
        'quantile'
    )
    expected = dpp(PM_da_control_3d)
    xr.testing.assert_allclose(actual, expected, atol=0.5, rtol=0.5)


def test_bootstrap_vwmp_sig50_similar_vwmp(PM_da_control_3d):
    """Check that the ``sig=50`` threshold is close to ``varweighted_mean_period``."""
    sig = 50
    actual = varweighted_mean_period_threshold(
        PM_da_control_3d, bootstrap=BOOTSTRAP, sig=sig
    ).drop_vars('quantile')
    expected = varweighted_mean_period(PM_da_control_3d)
    xr.testing.assert_allclose(actual, expected, atol=2, rtol=0.5)


def test_bootstrap_func_multiple_sig_levels(PM_da_control_3d):
    """Check that a list of ``sig`` levels yields one ordered quantile per level."""
    sig = [5, 95]
    actual = dpp_threshold(PM_da_control_3d, bootstrap=BOOTSTRAP, sig=sig)
    # One entry along ``quantile`` per requested level ...
    assert actual['quantile'].size == len(sig)
    # ... and the lower level never exceeds the upper one.
    assert (actual.isel(quantile=0).values <= actual.isel(quantile=1)).all()

Expand Down