Skip to content

Commit

Permalink
Refactor methods to module. Fix conservative nan bug. (#19)
Browse files Browse the repository at this point in the history
* Refactor methods to module. Fix conservative nan bug.

* Fix linting and typing issues
  • Loading branch information
BSchilperoort authored Oct 5, 2023
1 parent 56b60fa commit 7937260
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 52 deletions.
4 changes: 4 additions & 0 deletions src/xarray_regrid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from xarray_regrid import methods
from xarray_regrid.regrid import Regridder
from xarray_regrid.utils import Grid, create_regridding_dataset

# Names exported as the package's public API (honoured by
# `from xarray_regrid import *`).
__all__ = [
"Grid",
"Regridder",
"create_regridding_dataset",
"methods",
]

# Version string of the package.
__version__ = "0.2.0"
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Conservative regridding implementation."""
from collections.abc import Hashable
from typing import Literal, overload
from typing import overload

import dask.array
import numpy as np
Expand All @@ -8,48 +9,6 @@
from xarray_regrid import utils


@overload
def interp_regrid(
    data: xr.DataArray,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.DataArray:
    ...


@overload
def interp_regrid(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.Dataset:
    ...


def interp_regrid(
    data: xr.DataArray | xr.Dataset,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.DataArray | xr.Dataset:
    """Regrid data onto the target grid through xarray's ``interp``.

    Only coordinates that exist on both ``data`` and ``target_ds`` are handed
    to ``interp``; the target values for each of them are taken from
    ``target_ds``.

    Args:
        data: Input data to regrid.
        target_ds: Dataset whose coordinates define the grid to regrid to.
        method: Interpolation method to use ('linear', 'nearest' or 'cubic').

    Returns:
        The input data interpolated to the shared coordinates of the target.
    """
    shared_names = set(target_ds.coords) & set(data.coords)

    target_coords = {}
    for name in shared_names:
        target_coords[name] = target_ds[name]

    return data.interp(coords=target_coords, method=method)


@overload
def conservative_regrid(
data: xr.DataArray,
Expand Down Expand Up @@ -173,14 +132,13 @@ def apply_weights(
da: xr.DataArray, weights: np.ndarray, coord_name: Hashable, new_coords: np.ndarray
) -> xr.DataArray:
"""Apply the weights to convert data to the new coordinates."""
new_data: np.ndarray | dask.array.Array
if da.chunks is not None:
# Dask routine
new_data = dask.array.einsum(
"i...,ij->j...", da.data, weights, optimize="greedy"
)
new_data = compute_einsum_dask(da, weights)
else:
# numpy routine
new_data = np.einsum("i...,ij->j...", da.data, weights)
new_data = compute_einsum_numpy(da, weights)

coord_mapping = {coord_name: new_coords}
coords = list(da.dims)
Expand All @@ -195,6 +153,36 @@ def apply_weights(
)


def compute_einsum_dask(da: xr.DataArray, weights: np.ndarray) -> dask.array.Array:
    """Compute the einsum between dask data and weights, and mask NaNs if needed.

    The whole computation stays lazy. Checking up front whether the data holds
    any NaN (``if np.any(np.isnan(...))``) would truth-test a dask array and
    thereby force a compute of the full input while the graph is still being
    built. Instead, the NaN weight per output cell is always evaluated as part
    of the graph and applied with ``dask.array.where``.

    Args:
        da: Dask-backed data; its first dimension is the one contracted
            against the first axis of ``weights``.
        weights: Weights array of shape (old coordinate, new coordinate).

    Returns:
        Lazy array with the first dimension mapped to the new coordinate.
        Output cells that receive a positive weight from any NaN input cell
        are set to NaN.
    """
    einsum_expr = "i...,ij->j..."

    # NaNs are replaced by 0 so that they do not poison every output cell they
    # touch through the weighted sum; the affected cells are re-masked below.
    new_data: dask.array.Array = dask.array.einsum(
        einsum_expr, da.fillna(0).data, weights, optimize="greedy"
    )
    # Total weight that each output cell receives from NaN input cells.
    nan_weight = dask.array.einsum(
        einsum_expr, np.isnan(da.data), weights, optimize="greedy"
    )
    # For NaN-free data ``nan_weight`` is zero everywhere, so this is a no-op.
    return dask.array.where(nan_weight > 0, np.nan, new_data)


def compute_einsum_numpy(da: xr.DataArray, weights: np.ndarray) -> np.ndarray:
    """Contract numpy-backed data with the weights, re-masking NaN cells.

    Args:
        da: Numpy-backed data; its first dimension is the one contracted
            against the first axis of ``weights``.
        weights: Weights array of shape (old coordinate, new coordinate).

    Returns:
        Array with the first dimension mapped to the new coordinate. If the
        input contains NaNs, every output cell that receives a positive weight
        from a NaN input cell is set to NaN.
    """
    einsum_expr = "i...,ij->j..."
    nan_mask = np.isnan(da.data)

    # Fast path: without NaNs a single plain contraction is sufficient.
    if not np.any(nan_mask):
        return np.einsum(einsum_expr, da.data, weights)

    # Sum with NaNs treated as 0, then mask the cells a NaN contributed to.
    regridded: np.ndarray = np.einsum(einsum_expr, da.fillna(0).data, weights)
    nan_weight = np.einsum(einsum_expr, nan_mask, weights)
    regridded[nan_weight > 0] = np.nan
    return regridded


def get_weights(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray:
"""Determine the weights to map from the old coordinates to the new coordinates.
Expand Down
46 changes: 46 additions & 0 deletions src/xarray_regrid/methods/interp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Methods based on xr.interp."""
from typing import Literal, overload

import xarray as xr


@overload
def interp_regrid(
    data: xr.DataArray,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.DataArray:
    ...


@overload
def interp_regrid(
    data: xr.Dataset,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.Dataset:
    ...


def interp_regrid(
    data: xr.DataArray | xr.Dataset,
    target_ds: xr.Dataset,
    method: Literal["linear", "nearest", "cubic"],
) -> xr.DataArray | xr.Dataset:
    """Regrid data onto the target grid through xarray's ``interp``.

    Only coordinates that exist on both ``data`` and ``target_ds`` are handed
    to ``interp``; the target values for each of them are taken from
    ``target_ds``.

    Args:
        data: Input data to regrid.
        target_ds: Dataset whose coordinates define the grid to regrid to.
        method: Interpolation method to use ('linear', 'nearest' or 'cubic').

    Returns:
        The input data interpolated to the shared coordinates of the target.
    """
    shared_names = set(target_ds.coords) & set(data.coords)

    target_coords = {}
    for name in shared_names:
        target_coords[name] = target_ds[name]

    return data.interp(coords=target_coords, method=method)
File renamed without changes.
12 changes: 7 additions & 5 deletions src/xarray_regrid/regrid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import xarray as xr

from xarray_regrid import methods, most_common
from xarray_regrid.methods import conservative, interp, most_common


@xr.register_dataarray_accessor("regrid")
Expand Down Expand Up @@ -34,7 +34,7 @@ def linear(
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return methods.interp_regrid(self._obj, ds_target_grid, "linear")
return interp.interp_regrid(self._obj, ds_target_grid, "linear")

def nearest(
self,
Expand All @@ -51,7 +51,7 @@ def nearest(
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return methods.interp_regrid(self._obj, ds_target_grid, "nearest")
return interp.interp_regrid(self._obj, ds_target_grid, "nearest")

def cubic(
self,
Expand All @@ -68,7 +68,7 @@ def cubic(
Returns:
Data regridded to the target dataset coordinates.
"""
return methods.interp_regrid(self._obj, ds_target_grid, "cubic")
return interp.interp_regrid(self._obj, ds_target_grid, "cubic")

def conservative(
self,
Expand All @@ -89,7 +89,9 @@ def conservative(
"""

ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return methods.conservative_regrid(self._obj, ds_target_grid, latitude_coord)
return conservative.conservative_regrid(
self._obj, ds_target_grid, latitude_coord
)

def most_common(
self,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,26 @@ def test_conservative_regridder(conservative_input_data, conservative_sample_gri
rtol=0.002,
atol=2e-6,
)


def test_conservative_nans(conservative_input_data, conservative_sample_grid):
    """Conservative regridding of data containing NaNs matches the CDO reference.

    Part of the "tp" field is masked out with ``where`` before regridding; the
    comparison is then restricted to the ``no_nans`` selection below.
    """
    source = conservative_input_data
    masked_tp = source["tp"].where(source.latitude >= 0).where(source.longitude < 180)
    source["tp"] = masked_tp

    regridded = source.regrid.conservative(
        conservative_sample_grid, latitude_coord="latitude"
    )
    reference = xr.open_dataset(CDO_DATA["conservative"])

    # Cut off the edges: edge performance to be improved later (hopefully)
    no_edges = {"latitude": slice(-85, 85), "longitude": slice(5, 355)}
    # Selection of the part of the grid that was not masked out above.
    no_nans = {"latitude": slice(1, 90), "longitude": slice(None, 179)}

    actual = (
        regridded["tp"]
        .sel(no_edges)
        .sel(no_nans)
        .compute()
        .transpose("time", "latitude", "longitude")
    )
    expected = reference["tp"].sel(no_edges).sel(no_nans).compute()

    xr.testing.assert_allclose(actual, expected, rtol=0.002, atol=2e-6)

0 comments on commit 7937260

Please sign in to comment.