Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pyramid create for creating pyramids with custom funcs #120

Merged
merged 6 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/generate-pyramids.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,25 @@ pyramid = pyramid_reproject(ds, levels=2)
# write the pyramid to zarr
pyramid.to_zarr('./path/to/write')
```

There's also `pyramid_create`--a more versatile alternative to pyramid_coarsen.
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved

This function accepts a custom function with the signature: `ds`, `factor`, `dims`.

Here, the `sel_coarsen` function uses `ds.sel` to perform coarsening:

```python
def sel_coarsen(ds, factor, dims, **kwargs):
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
return ds.sel(**{dim: slice(None, None, factor) for dim in dims})

factors = [4, 2, 1]
pyramid = pyramid_create(
temperature,
dims=('lat', 'lon'),
factors=factors,
boundary='trim',
func=sel_coarsen,
method_label=method_label,
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
type_label='pick',
)
```
1 change: 1 addition & 0 deletions ndpyramid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa

from .create import pyramid_create
from .coarsen import pyramid_coarsen
from .reproject import pyramid_reproject
from .regrid import pyramid_regrid
Expand Down
37 changes: 13 additions & 24 deletions ndpyramid/coarsen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datatree as dt
import xarray as xr

from .utils import get_version, multiscales_template
from .create import pyramid_create


def pyramid_coarsen(
Expand All @@ -23,28 +23,17 @@ def pyramid_coarsen(
Additional keyword arguments to pass to xarray.Dataset.coarsen.
"""

# multiscales spec
save_kwargs = locals()
del save_kwargs['ds']

attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in range(len(factors))],
type='reduce',
method='pyramid_coarsen',
version=get_version(),
kwargs=save_kwargs,
)
}

# set up pyramid
plevels = {}

# pyramid data
for key, factor in enumerate(factors):
def coarsen(ds: xr.Dataset, factor: int, **kwargs):
# merge dictionary via union operator
kwargs |= {d: factor for d in dims}
plevels[str(key)] = ds.coarsen(**kwargs).mean() # type: ignore

plevels['/'] = xr.Dataset(attrs=attrs)
return dt.DataTree.from_dict(plevels)
return ds.coarsen(**kwargs).mean() # type: ignore

return pyramid_create(
ds,
factors=factors,
dims=dims,
func=coarsen,
method_label='pyramid_coarsen',
type_label='reduce',
**kwargs,
)
67 changes: 67 additions & 0 deletions ndpyramid/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations # noqa: F401

from typing import Callable

import datatree as dt
import xarray as xr

from .utils import get_version, multiscales_template


def pyramid_create(
ds: xr.Dataset,
*,
factors: list[int],
dims: list[str],
func: Callable,
type_label: str = 'reduce',
method_label: str | None = None,
**kwargs,
):
"""Create a multiscale pyramid via a given function applied to a dataset.
The generalized version of pyramid_coarsen.

Parameters
----------
ds : xarray.Dataset
The dataset to apply the function to.
factors : list[int]
The factors to coarsen by.
dims : list[str]
The dimensions to coarsen.
func : Callable
The function to apply to the dataset; must accept the
`ds`, `factor`, and `dims` as positional arguments.
type_label : str, optional
The type label to use as metadata for the multiscales spec.
The default is 'reduce'.
method_label : str, optional
The method label to use as metadata for the multiscales spec.
The default is the name of the function.
kwargs : dict
Additional keyword arguments to pass to the func.

"""
# multiscales spec
save_kwargs = locals()
del save_kwargs['ds']

attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in range(len(factors))],
type=type_label,
method=method_label or func.__name__,
version=get_version(),
kwargs=save_kwargs,
)
}

# set up pyramid
plevels = {}

# pyramid data
for key, factor in enumerate(factors):
plevels[str(key)] = func(ds, factor, dims, **kwargs)

plevels['/'] = xr.Dataset(attrs=attrs)
return dt.DataTree.from_dict(plevels)
28 changes: 27 additions & 1 deletion tests/test_pyramids.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import xarray as xr
from zarr.storage import MemoryStore

from ndpyramid import pyramid_coarsen, pyramid_regrid, pyramid_reproject
from ndpyramid import pyramid_coarsen, pyramid_create, pyramid_regrid, pyramid_reproject
from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds


Expand All @@ -21,6 +21,32 @@ def test_xarray_coarsened_pyramid(temperature, benchmark):
)
assert pyramid.ds.attrs['multiscales']
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == len(factors)
assert pyramid.ds.attrs['multiscales'][0]['method'] == 'pyramid_coarsen'
assert pyramid.ds.attrs['multiscales'][0]['type'] == 'reduce'
pyramid.to_zarr(MemoryStore())


@pytest.mark.parametrize('method_label', [None, 'sel_coarsen'])
def test_xarray_custom_coarsened_pyramid(temperature, benchmark, method_label):
def sel_coarsen(ds, factor, dims, **kwargs):
return ds.sel(**{dim: slice(None, None, factor) for dim in dims})

factors = [4, 2, 1]
pyramid = benchmark(
lambda: pyramid_create(
temperature,
dims=('lat', 'lon'),
factors=factors,
boundary='trim',
func=sel_coarsen,
method_label=method_label,
type_label='pick',
)
)
assert pyramid.ds.attrs['multiscales']
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == len(factors)
assert pyramid.ds.attrs['multiscales'][0]['method'] == 'sel_coarsen'
assert pyramid.ds.attrs['multiscales'][0]['type'] == 'pick'
pyramid.to_zarr(MemoryStore())


Expand Down