diff --git a/src/xarray_regrid/__init__.py b/src/xarray_regrid/__init__.py index a8d51e9..e73580d 100644 --- a/src/xarray_regrid/__init__.py +++ b/src/xarray_regrid/__init__.py @@ -1,3 +1,4 @@ +from xarray_regrid import methods from xarray_regrid.regrid import Regridder from xarray_regrid.utils import Grid, create_regridding_dataset @@ -5,4 +6,7 @@ "Grid", "Regridder", "create_regridding_dataset", + "methods", ] + +__version__ = "0.2.0" diff --git a/src/xarray_regrid/methods.py b/src/xarray_regrid/methods/conservative.py similarity index 84% rename from src/xarray_regrid/methods.py rename to src/xarray_regrid/methods/conservative.py index a4ce8f8..5ecc307 100644 --- a/src/xarray_regrid/methods.py +++ b/src/xarray_regrid/methods/conservative.py @@ -1,5 +1,6 @@ +"""Conservative regridding implementation.""" from collections.abc import Hashable -from typing import Literal, overload +from typing import overload import dask.array import numpy as np @@ -8,48 +9,6 @@ from xarray_regrid import utils -@overload -def interp_regrid( - data: xr.DataArray, - target_ds: xr.Dataset, - method: Literal["linear", "nearest", "cubic"], -) -> xr.DataArray: - ... - - -@overload -def interp_regrid( - data: xr.Dataset, - target_ds: xr.Dataset, - method: Literal["linear", "nearest", "cubic"], -) -> xr.Dataset: - ... - - -def interp_regrid( - data: xr.DataArray | xr.Dataset, - target_ds: xr.Dataset, - method: Literal["linear", "nearest", "cubic"], -) -> xr.DataArray | xr.Dataset: - """Refine a dataset using xarray's interp method. - - Args: - data: Input dataset. - target_ds: Dataset which coordinates the input dataset should be regrid to. - method: Which interpolation method to use (e.g. 'linear', 'nearest'). - - Returns: - Regridded input dataset - """ - coord_names = set(target_ds.coords).intersection(set(data.coords)) - coords = {name: target_ds[name] for name in coord_names} - - return data.interp( - coords=coords, - method=method, - ) - - @overload def conservative_regrid( data: xr.DataArray, @@ -173,14 +132,13 @@ def apply_weights( da: xr.DataArray, weights: np.ndarray, coord_name: Hashable, new_coords: np.ndarray ) -> xr.DataArray: """Apply the weights to convert data to the new coordinates.""" + new_data: np.ndarray | dask.array.Array if da.chunks is not None: # Dask routine - new_data = dask.array.einsum( - "i...,ij->j...", da.data, weights, optimize="greedy" - ) + new_data = compute_einsum_dask(da, weights) else: # numpy routine - new_data = np.einsum("i...,ij->j...", da.data, weights) + new_data = compute_einsum_numpy(da, weights) coord_mapping = {coord_name: new_coords} coords = list(da.dims) @@ -195,6 +153,36 @@ def apply_weights( ) +def compute_einsum_dask(da: xr.DataArray, weights: np.ndarray) -> dask.array.Array: + """Compute the einsum between dask data and weights, and mask NaNs if needed.""" + new_data: dask.array.Array + if np.any(np.isnan(da.data)): + new_data = dask.array.einsum( + "i...,ij->j...", da.fillna(0).data, weights, optimize="greedy" + ) + isnan = dask.array.einsum( + "i...,ij->j...", np.isnan(da.data), weights, optimize="greedy" + ) + new_data[isnan > 0] = np.nan + else: + new_data = dask.array.einsum( + "i...,ij->j...", da.data, weights, optimize="greedy" + ) + return new_data + + +def compute_einsum_numpy(da: xr.DataArray, weights: np.ndarray) -> np.ndarray: + """Compute the einsum between numpy data and weights, and mask NaNs if needed.""" + new_data: np.ndarray + if np.any(np.isnan(da.data)): + new_data = np.einsum("i...,ij->j...", da.fillna(0).data, weights) + isnan = np.einsum("i...,ij->j...", np.isnan(da.data), weights) + new_data[isnan > 0] = np.nan + else: + new_data = np.einsum("i...,ij->j...", da.data, weights) + return new_data + + def get_weights(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray: """Determine the weights to map from the old coordinates to the new coordinates. diff --git a/src/xarray_regrid/methods/interp.py b/src/xarray_regrid/methods/interp.py new file mode 100644 index 0000000..7295bfc --- /dev/null +++ b/src/xarray_regrid/methods/interp.py @@ -0,0 +1,46 @@ +"""Methods based on xr.interp.""" +from typing import Literal, overload + +import xarray as xr + + +@overload +def interp_regrid( + data: xr.DataArray, + target_ds: xr.Dataset, + method: Literal["linear", "nearest", "cubic"], +) -> xr.DataArray: + ... + + +@overload +def interp_regrid( + data: xr.Dataset, + target_ds: xr.Dataset, + method: Literal["linear", "nearest", "cubic"], +) -> xr.Dataset: + ... + + +def interp_regrid( + data: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + method: Literal["linear", "nearest", "cubic"], +) -> xr.DataArray | xr.Dataset: + """Refine a dataset using xarray's interp method. + + Args: + data: Input dataset. + target_ds: Dataset which coordinates the input dataset should be regrid to. + method: Which interpolation method to use (e.g. 'linear', 'nearest'). + + Returns: + Regridded input dataset + """ + coord_names = set(target_ds.coords).intersection(set(data.coords)) + coords = {name: target_ds[name] for name in coord_names} + + return data.interp( + coords=coords, + method=method, + ) diff --git a/src/xarray_regrid/most_common.py b/src/xarray_regrid/methods/most_common.py similarity index 100% rename from src/xarray_regrid/most_common.py rename to src/xarray_regrid/methods/most_common.py diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index e34ff70..365332d 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -1,6 +1,6 @@ import xarray as xr -from xarray_regrid import methods, most_common +from xarray_regrid.methods import conservative, interp, most_common @xr.register_dataarray_accessor("regrid") @@ -34,7 +34,7 @@ def linear( Data regridded to the target dataset coordinates. """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - return methods.interp_regrid(self._obj, ds_target_grid, "linear") + return interp.interp_regrid(self._obj, ds_target_grid, "linear") def nearest( self, @@ -51,7 +51,7 @@ def nearest( Data regridded to the target dataset coordinates. """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - return methods.interp_regrid(self._obj, ds_target_grid, "nearest") + return interp.interp_regrid(self._obj, ds_target_grid, "nearest") def cubic( self, @@ -68,7 +68,7 @@ def cubic( Returns: Data regridded to the target dataset coordinates. """ - return methods.interp_regrid(self._obj, ds_target_grid, "cubic") + return interp.interp_regrid(self._obj, ds_target_grid, "cubic") def conservative( self, @@ -89,7 +89,9 @@ def conservative( """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - return methods.conservative_regrid(self._obj, ds_target_grid, latitude_coord) + return conservative.conservative_regrid( + self._obj, ds_target_grid, latitude_coord + ) def most_common( self, diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 3d9105e..b7511ce 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -100,3 +100,26 @@ def test_conservative_regridder(conservative_input_data, conservative_sample_gri rtol=0.002, atol=2e-6, ) + + +def test_conservative_nans(conservative_input_data, conservative_sample_grid): + ds = conservative_input_data + ds["tp"] = ds["tp"].where(ds.latitude >= 0).where(ds.longitude < 180) + ds_regrid = ds.regrid.conservative( + conservative_sample_grid, latitude_coord="latitude" + ) + ds_cdo = xr.open_dataset(CDO_DATA["conservative"]) + + # Cut of the edges: edge performance to be improved later (hopefully) + no_edges = {"latitude": slice(-85, 85), "longitude": slice(5, 355)} + no_nans = {"latitude": slice(1, 90), "longitude": slice(None, 179)} + xr.testing.assert_allclose( + ds_regrid["tp"] + .sel(no_edges) + .sel(no_nans) + .compute() + .transpose("time", "latitude", "longitude"), + ds_cdo["tp"].sel(no_edges).sel(no_nans).compute(), + rtol=0.002, + atol=2e-6, + )