From 583fba7456328f9265b3a206e4c1b81acf781de3 Mon Sep 17 00:00:00 2001 From: Aaron Spring Date: Wed, 15 Apr 2020 20:30:50 +0200 Subject: [PATCH] Use xr.quantile with new v0.15.1 of xarray (#348) * Reimplements `xr.quantile` to replace `my_quantile` using the new `skipna` feature. * Updates requirements to require xarray >= 0.15.1 for developers and users. --- CHANGELOG.rst | 3 ++ ci/environment-dev-3.6.yml | 2 +- climpred/bootstrap.py | 40 +++++--------------------- climpred/reference.py | 3 +- climpred/tests/conftest.py | 8 +++--- climpred/tests/test_bootstrap.py | 49 ++------------------------------ climpred/utils.py | 12 ++++++++ requirements.txt | 2 +- 8 files changed, 33 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b9c22f51c..fd09a805a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -59,6 +59,9 @@ Internals/Minor Fixes (:pr:`330`) `Aaron Spring`_. - Refactoring :py:func:`~climpred.prediction.compute_hindcast` and :py:func:`~climpred.prediction.compute_perfect_model`. (:pr:`330`) `Aaron Spring`_. +- Changed lead0 coordinate modifications to be compliant with xarray=0.15.1 in + :py:func:`~climpred.reference.compute_persistence`. (:pr:`348`) `Aaron Spring`_. +- Exchanged my_quantile with xr.quantile(skipna=False). (:pr:`348`) `Aaron Spring`_. Documentation diff --git a/ci/environment-dev-3.6.yml b/ci/environment-dev-3.6.yml index 5fdedbe7b..e501225ff 100644 --- a/ci/environment-dev-3.6.yml +++ b/ci/environment-dev-3.6.yml @@ -22,7 +22,7 @@ dependencies: - numpy - pandas=0.25.3 # breaking features in pandas 1.0 - scipy - - xarray>=0.14.1 # Allows .drop_vars() + - xarray>=0.15.1 # Package Management - asv - black diff --git a/climpred/bootstrap.py b/climpred/bootstrap.py index d7d52955a..cfa2aa5ae 100644 --- a/climpred/bootstrap.py +++ b/climpred/bootstrap.py @@ -26,6 +26,7 @@ get_lead_cftime_shift_args, get_metric_class, lead_units_equal_control_time_stride, + rechunk_to_single_chunk_if_more_than_one_chunk_along_dim, shift_cftime_singular, ) @@ -50,35 +51,6 @@ def _resample(hind, resample_dim): return smp_hind -def my_quantile(ds, q=0.95, dim='bootstrap'): - """Compute quantile `q` faster than `xr.quantile` and allows lazy computation.""" - # dim='bootstrap' doesnt lead anywhere, but want to keep xr.quantile API - # concat_dim is always first, therefore axis=0 implementation works in compute - - def _dask_percentile(arr, axis=0, q=95): - """Daskified np.percentile.""" - if len(arr.chunks[axis]) > 1: - arr = arr.rechunk({axis: -1}) - return dask.array.map_blocks(np.percentile, arr, axis=axis, q=q, drop_axis=axis) - - def _percentile(arr, axis=0, q=95): - """percentile function for chunked and non-chunked `arr`.""" - if dask.is_dask_collection(arr): - return _dask_percentile(arr, axis=axis, q=q) - else: - return np.percentile(arr, axis=axis, q=q) - - axis = 0 - if not isinstance(q, list): - q = [q] - quantile = [] - for qi in q: - quantile.append(ds.reduce(_percentile, q=qi * 100, axis=axis)) - quantile = xr.concat(quantile, 'quantile') - quantile['quantile'] = q - return quantile.squeeze() - - def _distribution_to_ci(ds, ci_low, ci_high, dim='bootstrap'): """Get confidence intervals from bootstrapped distribution. @@ -93,8 +65,8 @@ def _distribution_to_ci(ds, ci_low, ci_high, dim='bootstrap'): Returns: uninit_hind (xarray object): uninitialize hindcast with hind.coords. """ - # TODO: re-implement xr.quantile once fast - return my_quantile(ds, q=[ci_low, ci_high], dim='bootstrap') + ds = rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(ds, dim) + return ds.quantile(q=[ci_low, ci_high], dim=dim, skipna=False) def _pvalue_from_distributions(simple_fct, init, metric=None): @@ -325,8 +297,10 @@ def _bootstrap_func( smp_init_pm = _resample(ds, resample_dim) bootstraped_results.append(func(smp_init_pm, *func_args, **func_kwargs)) sig_level = xr.concat(bootstraped_results, dim='bootstrap', **CONCAT_KWARGS) - # TODO: reimplement xr.quantile once fast - sig_level = my_quantile(sig_level, dim='bootstrap', q=psig) + sig_level = rechunk_to_single_chunk_if_more_than_one_chunk_along_dim( + sig_level, dim='bootstrap' + ) + sig_level = sig_level.quantile(dim='bootstrap', q=psig, skipna=False) return sig_level diff --git a/climpred/reference.py b/climpred/reference.py index 750a57d8c..dfbc29290 100644 --- a/climpred/reference.py +++ b/climpred/reference.py @@ -89,7 +89,8 @@ def compute_persistence( # at lead 0 is == 1. if [0] in hind.lead.values: hind = hind.copy() - hind['lead'] += 1 + with xr.set_options(keep_attrs=True): # keeps lead.attrs['units'] + hind['lead'] = hind['lead'] + 1 n, freq = get_lead_cftime_shift_args(hind.lead.attrs['units'], 1) # Shift backwards shift for lead zero. hind['init'] = shift_cftime_index(hind, 'init', -1 * n, freq) diff --git a/climpred/tests/conftest.py b/climpred/tests/conftest.py index f269e8252..b748664b5 100644 --- a/climpred/tests/conftest.py +++ b/climpred/tests/conftest.py @@ -30,8 +30,8 @@ def PM_da_initialized_1d_lead0(PM_da_initialized_1d): framework.""" da = PM_da_initialized_1d # Convert to lead zero for testing - da['lead'] -= 1 - da['init'] += 1 + da['lead'] = da['lead'] - 1 + da['init'] = da['init'] + 1 return da @@ -132,8 +132,8 @@ def hind_ds_initialized_1d_lead0(hind_ds_initialized_1d): framework.""" da = hind_ds_initialized_1d # Change to a lead-0 framework - da['init'] += 1 - da['lead'] -= 1 + da['init'] = da['init'] + 1 + da['lead'] = da['lead'] - 1 return da diff --git a/climpred/tests/test_bootstrap.py b/climpred/tests/test_bootstrap.py index 483658725..658565a6a 100644 --- a/climpred/tests/test_bootstrap.py +++ b/climpred/tests/test_bootstrap.py @@ -1,17 +1,13 @@ -import time - import dask import numpy as np import pytest import xarray as xr -from xarray.testing import assert_allclose from climpred.bootstrap import ( _bootstrap_by_stacking, bootstrap_hindcast, bootstrap_perfect_model, bootstrap_uninit_pm_ensemble_from_control_cftime, - my_quantile, ) from climpred.comparisons import HINDCAST_COMPARISONS, PM_COMPARISONS from climpred.constants import CONCAT_KWARGS, VALID_ALIGNMENTS @@ -21,45 +17,6 @@ BOOTSTRAP = 2 -@pytest.mark.skip(reason='xr.quantile now faster') -@pytest.mark.parametrize('chunk', [True, False]) -def test_dask_percentile_implemented_faster_xr_quantile(PM_da_control_3d, chunk): - """Test my_quantile faster than xr.quantile. - - TODO: Remove after xr=0.15.1 and add skipna=False. - """ - chunk_dim, dim = 'x', 'time' - if chunk: - chunks = {chunk_dim: 24} - PM_da_control_3d = PM_da_control_3d.chunk(chunks).persist() - start_time = time.time() - actual = my_quantile(PM_da_control_3d, q=0.95, dim=dim) - elapsed_time_my_quantile = time.time() - start_time - - start_time = time.time() - expected = PM_da_control_3d.compute().quantile(q=0.95, dim=dim) - elapsed_time_xr_quantile = time.time() - start_time - if chunk: - assert dask.is_dask_collection(actual) - assert not dask.is_dask_collection(expected) - start_time = time.time() - actual = actual.compute() - elapsed_time_my_quantile = elapsed_time_my_quantile + time.time() - start_time - else: - assert not dask.is_dask_collection(actual) - assert not dask.is_dask_collection(expected) - assert actual.shape == expected.shape - assert_allclose(actual, expected) - print( - elapsed_time_my_quantile, - elapsed_time_xr_quantile, - 'my_quantile is', - elapsed_time_xr_quantile / elapsed_time_my_quantile, - 'times faster than xr.quantile', - ) - assert elapsed_time_xr_quantile > elapsed_time_my_quantile - - @pytest.mark.slow @pytest.mark.parametrize('comparison', PM_COMPARISONS) @pytest.mark.parametrize('chunk', [True, False]) @@ -118,7 +75,7 @@ def test_bootstrap_hindcast_lazy( @pytest.mark.slow @pytest.mark.parametrize('resample_dim', ['member', 'init']) def test_bootstrap_hindcast_resample_dim( - hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, resample_dim + hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, resample_dim, ): """Test bootstrap_hindcast when resampling member or init and alignment same_inits.""" @@ -291,7 +248,7 @@ def test_bootstrap_by_stacking_chunked( PM_ds_initialized_1d_ym_cftime, PM_ds_control_1d_ym_cftime ): res_chunked = _bootstrap_by_stacking( - PM_ds_initialized_1d_ym_cftime.chunk(), PM_ds_control_1d_ym_cftime.chunk() + PM_ds_initialized_1d_ym_cftime.chunk(), PM_ds_control_1d_ym_cftime.chunk(), ) assert dask.is_dask_collection(res_chunked) res_chunked = res_chunked.compute() @@ -327,7 +284,7 @@ def test_bootstrap_by_stacking_two_var_dataset( @pytest.mark.slow @pytest.mark.parametrize('alignment', VALID_ALIGNMENTS) def test_bootstrap_hindcast_alignment( - hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, alignment + hind_da_initialized_1d, hist_da_uninitialized_1d, observations_da_1d, alignment, ): """Test bootstrap_hindcast for all alginments when resampling member.""" bootstrap_hindcast( diff --git a/climpred/utils.py b/climpred/utils.py index 48bcd7577..2c3ef8bcf 100644 --- a/climpred/utils.py +++ b/climpred/utils.py @@ -370,6 +370,18 @@ def _load_into_memory(res): return res +def rechunk_to_single_chunk_if_more_than_one_chunk_along_dim(ds, dim): + """Rechunk an xarray object more than one chunk along dim.""" + if dask.is_dask_collection(ds): + if isinstance(ds, xr.Dataset): + nchunks = len(ds.chunks[dim]) + elif isinstance(ds, xr.DataArray): + nchunks = len(ds.chunks[ds.get_axis_num(dim)]) + if nchunks > 1: + ds = ds.chunk({dim: -1}) + return ds + + def shift_cftime_index(xobj, time_string, n, freq): """Shifts a ``CFTimeIndex`` over a specified number of time steps at a given temporal frequency. diff --git a/requirements.txt b/requirements.txt index d8e1aaf74..ff650670b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,6 @@ numpy pandas scipy tqdm -xarray>=0.14.1 +xarray>=0.15.1 xrft xskillscore>=0.0.12