Skip to content

Commit

Permalink
remove variadic args from block bootstrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday committed Jan 30, 2025
1 parent cf9bb11 commit 9078c2d
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 37 deletions.
49 changes: 25 additions & 24 deletions src/scores/processing/block_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import math
from collections import OrderedDict
from itertools import chain, cycle, islice
from typing import Dict, List, Tuple, Union
from typing import Dict, Iterable, List, Tuple, Union

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -158,7 +158,7 @@ def _bootstrap(*arrays: np.ndarray, indices: List[np.ndarray]) -> Union[np.ndarr


def _block_bootstrap( # pylint: disable=too-many-locals
*arrays: XarrayLike,
array_list: List[XarrayLike],
blocks: Dict[str, int],
n_iteration: int,
exclude_dims: Union[List[List[str]], None] = None,
Expand All @@ -171,7 +171,7 @@ def _block_bootstrap( # pylint: disable=too-many-locals
along that dimension, the second provided dimension is bootstrapped, and so forth.
Args:
*arrays: Data to bootstrap. Multiple datasets can be passed to be bootstrapped
array_list: Data to bootstrap. Multiple arrays can be passed to be bootstrapped
in the same way. All input arrays must have nested dimensions.
blocks: Dictionary of dimension(s) to bootstrap and the block sizes to use
along each dimension: ``{dim: blocksize}``. Nesting is based on the order of
Expand Down Expand Up @@ -201,22 +201,19 @@ def _block_bootstrap( # pylint: disable=too-many-locals
Wilks, Daniel S. Statistical methods in the atmospheric sciences. Vol. 100.
Academic press, 2011.
"""

arrays_list = list(arrays)

# Rename exclude_dims so they are not bootstrapped
if exclude_dims is None:
exclude_dims = [[] for _ in range(len(arrays_list))]
exclude_dims = [[] for _ in range(len(array_list))]
if not isinstance(exclude_dims, list) or not all(isinstance(x, list) for x in exclude_dims):
raise ValueError("exclude_dims should be a list of lists")
if len(exclude_dims) != len(arrays_list):
if len(exclude_dims) != len(array_list):
raise ValueError(
"exclude_dims should be a list of the same length as the number of "
"arrays containing lists of dimensions to exclude for each array"
)
renames = []
for i, (obj, exclude) in enumerate(zip(arrays_list, exclude_dims)):
arrays_list[i] = obj.rename(
for i, (obj, exclude) in enumerate(zip(array_list, exclude_dims)):
array_list[i] = obj.rename(
{d: f"dim{ii}" for ii, d in enumerate(exclude)},
)
renames.append({f"dim{ii}": d for ii, d in enumerate(exclude)})
Expand All @@ -225,7 +222,7 @@ def _block_bootstrap( # pylint: disable=too-many-locals

# Ensure bootstrapped dimensions have consistent sizes across all arrays in array_list
for d in blocks.keys():
dim_sizes = [o.sizes[d] for o in arrays_list if d in o.dims]
dim_sizes = [o.sizes[d] for o in array_list if d in o.dims]
if not all(s == dim_sizes[0] for s in dim_sizes):
raise ValueError(f"Block dimension {d} is not the same size on all input arrays")

Expand All @@ -234,7 +231,7 @@ def _block_bootstrap( # pylint: disable=too-many-locals
try:
sizes = next(
OrderedDict([(d, (obj.sizes[d], b)) for d, b in blocks.items()])
for obj in arrays_list
for obj in array_list
if all(d in obj.sizes for d in blocks)
)
except StopIteration:
Expand All @@ -248,7 +245,7 @@ def _block_bootstrap( # pylint: disable=too-many-locals
# Expand indices for broadcasting for each array separately
indices = []
input_core_dims = []
for obj in arrays_list:
for obj in array_list:
available_dims = [d for d in dim if d in obj.dims]
indices_to_expand = [nested_indices[key] for key in available_dims]

Expand All @@ -257,7 +254,7 @@ def _block_bootstrap( # pylint: disable=too-many-locals

# Process each array in array_list separately to handle non-matching dimensions
result = []
for obj, ind, core_dims in zip(arrays_list, indices, input_core_dims):
for obj, ind, core_dims in zip(array_list, indices, input_core_dims):
if isinstance(obj, xr.Dataset):
# Assume all variables have the same dtype
output_dtype = obj[list(obj.data_vars)[0]].dtype
Expand All @@ -282,7 +279,8 @@ def _block_bootstrap( # pylint: disable=too-many-locals


def block_bootstrap(
*arrays: XarrayLike,
array_list: List[XarrayLike] | XarrayLike,
*, # Enforce keyword-only arguments
blocks: Dict[str, int],
n_iteration: int,
exclude_dims: Union[List[List[str]], None] = None,
Expand All @@ -295,7 +293,8 @@ def block_bootstrap(
handling Dask arrays for chunk size limitation.
Args:
arrays: The data to bootstrap, which can be multiple datasets. In the case where
array_list: The data to bootstrap, which can be a single xarray object or
a list of multiple xarray objects. In the case where
multiple datasets are passed, each dataset can have its own set of dimensions. However,
for successful bootstrapping, dimensions across all input arrays must be nested.
For instance, for ``blocks`` with keys ``['d1', 'd2', 'd3']``, an array with dimension 'd1' and
Expand All @@ -304,12 +303,12 @@ def block_bootstrap(
blocks: A dictionary specifying the dimension(s) to bootstrap and the block sizes to
use along each dimension: ``{dimension: block_size}``. The keys represent the dimensions
to be bootstrapped, and the values indicate the block sizes along each dimension.
The dimension provided here should exist in the data provided as ``arrays``.
The dimension provided here should exist in the data provided in ``array_list``.
n_iteration: The number of iterations to repeat the bootstrapping process. Determines
how many bootstrapped arrays will be generated and stacked along the iteration
dimension.
exclude_dims: An optional parameter indicating the dimensions to be excluded during
bootstrapping for each array provided in ``arrays``. This parameter expects a list
bootstrapping for each array provided in ``array_list``. This parameter expects a list
of lists, where each inner list corresponds to the dimensions to be excluded for
the respective array. By default, the assumption is that no dimensions are
excluded, and all arrays are bootstrapped across all specified dimensions in ``blocks``.
Expand All @@ -327,12 +326,12 @@ def block_bootstrap(
Wilks, Daniel S. Statistical methods in the atmospheric sciences. Vol. 100.
Academic press, 2011.
"""

# While the most efficient method involves expanding the iteration dimension within the
# universal function, this approach might generate excessively large chunks (resulting
# from multiplying chunk size by iterations), leading to issues with large numbers of
# iterations. Hence, this function loops over blocks of iterations to generate the total
# number of iterations.

def _max_chunk_size_mb(ds):
"""
Get the max chunk size in a dataset
Expand All @@ -348,11 +347,13 @@ def _max_chunk_size_mb(ds):
chunks.append(size_of_chunk)
return max(chunks)

if not isinstance(array_list, List):
array_list = [array_list]
# Choose iteration blocks to limit chunk size on dask arrays
if arrays[0].chunks: # Note: This is a way to check if the array is backed by a dask.array
if array_list[0].chunks: # Note: This is a way to check if the array is backed by a dask.array
# without loading data into memory.
# See https://docs.xarray.dev/en/stable/generated/xarray.DataArray.chunks.html
ds_max_chunk_size_mb = max(_max_chunk_size_mb(obj) for obj in arrays)
ds_max_chunk_size_mb = max(_max_chunk_size_mb(obj) for obj in array_list)
blocksize = int(MAX_CHUNK_SIZE_MB / ds_max_chunk_size_mb)
blocksize = min(blocksize, n_iteration)
blocksize = max(blocksize, 1)
Expand All @@ -363,7 +364,7 @@ def _max_chunk_size_mb(ds):
for _ in range(blocksize, n_iteration + 1, blocksize):
bootstraps.append(
_block_bootstrap(
*arrays,
array_list,
blocks=blocks,
n_iteration=blocksize,
exclude_dims=exclude_dims,
Expand All @@ -375,7 +376,7 @@ def _max_chunk_size_mb(ds):
if leftover:
bootstraps.append(
_block_bootstrap(
*arrays,
array_list,
blocks=blocks,
n_iteration=leftover,
exclude_dims=exclude_dims,
Expand All @@ -393,6 +394,6 @@ def _max_chunk_size_mb(ds):
for bootstrap in zip(*bootstraps)
)

if len(arrays) == 1:
if len(array_list) == 1:
return bootstraps_concat[0]
return bootstraps_concat
10 changes: 5 additions & 5 deletions tests/processing/test_block_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test__expand_n_nested_random_indices(indices, expected_shapes):
def test__block_bootstrap(objects, blocks, n_iteration, exclude_dims, circular, expected_shape):
"""Test _block_bootstrap works as expected"""
result = _block_bootstrap(
*objects, blocks=blocks, n_iteration=n_iteration, exclude_dims=exclude_dims, circular=circular
objects, blocks=blocks, n_iteration=n_iteration, exclude_dims=exclude_dims, circular=circular
)
for res in result:
if isinstance(res, xr.Dataset):
Expand Down Expand Up @@ -196,15 +196,15 @@ def test__bootstrap_tuple_return():
def test__block_bootstrap_exceptions(objects, blocks, n_iteration, exclude_dims, circular, expected_exception, match):
"""Test _block_bootstrap correctly raises errors"""
with pytest.raises(expected_exception=expected_exception, match=match):
_block_bootstrap(*objects, blocks=blocks, n_iteration=n_iteration, exclude_dims=exclude_dims, circular=circular)
_block_bootstrap(objects, blocks=blocks, n_iteration=n_iteration, exclude_dims=exclude_dims, circular=circular)


@pytest.mark.parametrize(
"objects, blocks, n_iteration, exclude_dims, circular, expected_shape, expected_type",
[
# Single array bootstrap
(
[xr.DataArray(np.random.rand(10, 5), dims=["dim1", "dim2"])],
xr.DataArray(np.random.rand(10, 5), dims=["dim1", "dim2"]),
{"dim1": 2, "dim2": 2},
3,
None,
Expand Down Expand Up @@ -257,7 +257,7 @@ def test__block_bootstrap_exceptions(objects, blocks, n_iteration, exclude_dims,
def test_block_bootstrap(objects, blocks, n_iteration, exclude_dims, circular, expected_shape, expected_type):
"""Test block_bootstrap works as expected"""
result = block_bootstrap(
*objects, blocks=blocks, n_iteration=n_iteration, exclude_dims=exclude_dims, circular=circular
objects, blocks=blocks, n_iteration=n_iteration, exclude_dims=exclude_dims, circular=circular
)
if expected_type == tuple:
assert isinstance(result, tuple)
Expand Down Expand Up @@ -330,7 +330,7 @@ def test_block_bootstrap_dask(objects, blocks, n_iteration, exclude_dims, circul
pytest.skip("Dask unavailable, could not run test") # pragma: no cover

result = block_bootstrap(
*objects, blocks=blocks, n_iteration=n_iteration, exclude_dims=exclude_dims, circular=circular
objects, blocks=blocks, n_iteration=n_iteration, exclude_dims=exclude_dims, circular=circular
)
if isinstance(result, xr.DataArray):
assert isinstance(result.data, dask.array.Array)
Expand Down
Loading

0 comments on commit 9078c2d

Please sign in to comment.