diff --git a/docs/api.rst b/docs/api.rst index 723ec9a..81d9f83 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -11,5 +11,6 @@ Top level API :toctree: generated/ pyramid_coarsen + pyramid_create pyramid_reproject pyramid_regrid diff --git a/docs/generate-pyramids.md b/docs/generate-pyramids.md index 2e1ca0e..04464b9 100644 --- a/docs/generate-pyramids.md +++ b/docs/generate-pyramids.md @@ -23,3 +23,25 @@ pyramid = pyramid_reproject(ds, levels=2) # write the pyramid to zarr pyramid.to_zarr('./path/to/write') ``` + +There's also `pyramid_create` -- a more versatile alternative to `pyramid_coarsen`. +This function accepts a custom function with the signature: `ds`, `factor`, `dims`. +Here, the `sel_coarsen` function uses `ds.sel` to perform coarsening: + +```python +from ndpyramid import pyramid_create + +def sel_coarsen(ds, factor, dims, **kwargs): + return ds.sel(**{dim: slice(None, None, factor) for dim in dims}) + +factors = [4, 2, 1] +pyramid = pyramid_create( + temperature, + dims=('lat', 'lon'), + factors=factors, + boundary='trim', + func=sel_coarsen, + method_label="slice_coarsen", + type_label='pick', +) +``` diff --git a/ndpyramid/__init__.py b/ndpyramid/__init__.py index 23a2d95..9f126bd 100644 --- a/ndpyramid/__init__.py +++ b/ndpyramid/__init__.py @@ -1,5 +1,6 @@ # flake8: noqa +from .create import pyramid_create from .coarsen import pyramid_coarsen from .reproject import pyramid_reproject from .regrid import pyramid_regrid diff --git a/ndpyramid/coarsen.py b/ndpyramid/coarsen.py index 5e941bf..5ef1076 100644 --- a/ndpyramid/coarsen.py +++ b/ndpyramid/coarsen.py @@ -3,7 +3,7 @@ import datatree as dt import xarray as xr -from .utils import get_version, multiscales_template +from .create import pyramid_create def pyramid_coarsen( @@ -23,28 +23,17 @@ def pyramid_coarsen( Additional keyword arguments to pass to xarray.Dataset.coarsen. """ - # multiscales spec - save_kwargs = locals() - del save_kwargs['ds'] - - attrs = { - 'multiscales': multiscales_template( - datasets=[{'path': str(i)} for i in range(len(factors))], - type='reduce', - method='pyramid_coarsen', - version=get_version(), - kwargs=save_kwargs, - ) - } - - # set up pyramid - plevels = {} - - # pyramid data - for key, factor in enumerate(factors): + def coarsen(ds: xr.Dataset, factor: int, dims: list[str], **kwargs): # merge dictionary via union operator kwargs |= {d: factor for d in dims} - plevels[str(key)] = ds.coarsen(**kwargs).mean() # type: ignore - - plevels['/'] = xr.Dataset(attrs=attrs) - return dt.DataTree.from_dict(plevels) + return ds.coarsen(**kwargs).mean() # type: ignore + + return pyramid_create( + ds, + factors=factors, + dims=dims, + func=coarsen, + method_label='pyramid_coarsen', + type_label='reduce', + **kwargs, + ) diff --git a/ndpyramid/create.py b/ndpyramid/create.py new file mode 100644 index 0000000..3bbeabb --- /dev/null +++ b/ndpyramid/create.py @@ -0,0 +1,70 @@ +from __future__ import annotations # noqa: F401 + +from typing import Callable + +import datatree as dt +import xarray as xr + +from .utils import get_version, multiscales_template + + +def pyramid_create( + ds: xr.Dataset, + *, + factors: list[int], + dims: list[str], + func: Callable, + type_label: str = 'reduce', + method_label: str | None = None, + **kwargs, +): + """Create a multiscale pyramid via a given function applied to a dataset. + The generalized version of pyramid_coarsen. + + Parameters + ---------- + ds : xarray.Dataset + The dataset to apply the function to. + factors : list[int] + The factors to coarsen by. + dims : list[str] + The dimensions to coarsen. + func : Callable + The function to apply to the dataset; must accept the + `ds`, `factor`, and `dims` as positional arguments. + type_label : str, optional + The type label to use as metadata for the multiscales spec. + The default is 'reduce'. + method_label : str, optional + The method label to use as metadata for the multiscales spec. + The default is the name of the function. + kwargs : dict + Additional keyword arguments to pass to the func. + + """ + # multiscales spec + save_kwargs = locals() + del save_kwargs['ds'] + del save_kwargs['func'] + del save_kwargs['type_label'] + del save_kwargs['method_label'] + + attrs = { + 'multiscales': multiscales_template( + datasets=[{'path': str(i)} for i in range(len(factors))], + type=type_label, + method=method_label or func.__name__, + version=get_version(), + kwargs=save_kwargs, + ) + } + + # set up pyramid + plevels = {} + + # pyramid data + for key, factor in enumerate(factors): + plevels[str(key)] = func(ds, factor, dims, **kwargs) + + plevels['/'] = xr.Dataset(attrs=attrs) + return dt.DataTree.from_dict(plevels) diff --git a/tests/test_pyramids.py b/tests/test_pyramids.py index e5f60be..f38bc5f 100644 --- a/tests/test_pyramids.py +++ b/tests/test_pyramids.py @@ -3,7 +3,7 @@ import xarray as xr from zarr.storage import MemoryStore -from ndpyramid import pyramid_coarsen, pyramid_regrid, pyramid_reproject +from ndpyramid import pyramid_coarsen, pyramid_create, pyramid_regrid, pyramid_reproject from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds @@ -21,6 +21,32 @@ def test_xarray_coarsened_pyramid(temperature, benchmark): ) assert pyramid.ds.attrs['multiscales'] assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == len(factors) + assert pyramid.ds.attrs['multiscales'][0]['metadata']['method'] == 'pyramid_coarsen' + assert pyramid.ds.attrs['multiscales'][0]['type'] == 'reduce' + pyramid.to_zarr(MemoryStore()) + + +@pytest.mark.parametrize('method_label', [None, 'sel_coarsen']) +def test_xarray_custom_coarsened_pyramid(temperature, benchmark, method_label): + def sel_coarsen(ds, factor, dims, **kwargs): + return ds.sel(**{dim: slice(None, None, factor) for dim in dims}) + + factors = [4, 2, 1] + pyramid = benchmark( + lambda: pyramid_create( + temperature, + dims=('lat', 'lon'), + factors=factors, + boundary='trim', + func=sel_coarsen, + method_label=method_label, + type_label='pick', + ) + ) + assert pyramid.ds.attrs['multiscales'] + assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == len(factors) + assert pyramid.ds.attrs['multiscales'][0]['metadata']['method'] == 'sel_coarsen' + assert pyramid.ds.attrs['multiscales'][0]['type'] == 'pick' pyramid.to_zarr(MemoryStore())