Skip to content

Commit

Permalink
Merge pull request #1098 from xcube-dev/konstntokas-xxx-improve_resample_in_space
Browse files Browse the repository at this point in the history

Improve resample_in_space and rectify_dataset
  • Loading branch information
konstntokas authored Jan 2, 2025
2 parents 3473d04 + 27b4ae2 commit f3184e3
Show file tree
Hide file tree
Showing 10 changed files with 1,299 additions and 624 deletions.
11 changes: 11 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@

### Enhancements


* The method `xcube.core.gridmapping.GridMapping.transform` now supports lazy
execution. If computations based on actual data are required (for example, determining
whether the grid mapping is regular or estimating the resolution in the x or y
direction), only a single chunk is accessed whenever possible, which keeps them fast.

* The function `xcube.core.resampling.rectify_dataset` now supports `xarray.Dataset`
objects containing multi-dimensional data variables structured as
`var(..., y_dim, x_dim)`. The two spatial dimensions (`y_dim` and `x_dim`) must be
the last two dimensions of each such variable, in that order.

* Added a new _preload API_ to xcube data stores:
- Enhanced the `xcube.core.store.DataStore` class to optionally support
preloading of datasets via an API represented by the
Expand Down
1,704 changes: 1,135 additions & 569 deletions examples/notebooks/resampling/reproject_large_esa_cci_landcover.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions test/core/gridmapping/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def test_transform(self):
self.assertEqual(pyproj.CRS.from_string("EPSG:32633"), transformed_gm.crs)
self.assertEqual((400, 200), transformed_gm.size)
self.assertEqual((400, 200), transformed_gm.tile_size)
self.assertEqual(None, transformed_gm.is_j_axis_up)
self.assertEqual(False, transformed_gm.is_j_axis_up)
self.assertEqual(
("transformed_x", "transformed_y"), transformed_gm.xy_var_names
)
Expand All @@ -335,7 +335,7 @@ def test_transform_xy_res(self):
self.assertEqual((400, 200), transformed_gm.size)
self.assertEqual((200, 200), transformed_gm.tile_size)
self.assertEqual((1000, 1000), transformed_gm.xy_res)
self.assertEqual(None, transformed_gm.is_j_axis_up)
self.assertEqual(False, transformed_gm.is_j_axis_up)
self.assertEqual(
("transformed_x", "transformed_y"), transformed_gm.xy_var_names
)
Expand Down
15 changes: 8 additions & 7 deletions test/core/gridmapping/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_from_non_regular_cube(self):
self.assertEqual(GEO_CRS, gm.crs)
self.assertEqual(False, gm.is_regular)
self.assertEqual(False, gm.is_lon_360)
self.assertEqual(None, gm.is_j_axis_up)
self.assertEqual(False, gm.is_j_axis_up)
self.assertEqual((2, 3, 4), gm.xy_coords.shape)
self.assertEqual(("coord", "y", "x"), gm.xy_coords.dims)
self.assertEqual((0.8, 0.8), gm.xy_res)
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_from_real_olci(self):
# self.assertAlmostEqual(60.63871982169044, gm.y_max)
self.assertEqual(False, gm.is_regular)
self.assertEqual(False, gm.is_lon_360)
self.assertEqual(None, gm.is_j_axis_up)
self.assertEqual(False, gm.is_j_axis_up)
self.assertEqual((2, 1890, 1189), gm.xy_coords.shape)
self.assertEqual(("coord", "y", "x"), gm.xy_coords.dims)

Expand All @@ -123,29 +123,30 @@ def test_from_real_olci(self):

def test_from_sentinel_2(self):
dataset = create_s2plus_dataset()
tol = 1e-6

gm = GridMapping.from_dataset(dataset)
gm = GridMapping.from_dataset(dataset, tolerance=tol)
# Should pick the projected one which is regular
self.assertIn("Projected", gm.crs.type_name)
self.assertEqual(True, gm.is_regular)

gm = GridMapping.from_dataset(dataset, prefer_is_regular=True)
gm = GridMapping.from_dataset(dataset, prefer_is_regular=True, tolerance=tol)
# Should pick the projected one which is regular
self.assertIn("Projected", gm.crs.type_name)
self.assertEqual(True, gm.is_regular)

gm = GridMapping.from_dataset(dataset, prefer_is_regular=False)
gm = GridMapping.from_dataset(dataset, prefer_is_regular=False, tolerance=tol)
# Should pick the geographic one which is irregular
self.assertIn("Geographic", gm.crs.type_name)
self.assertEqual(False, gm.is_regular)

gm = GridMapping.from_dataset(dataset, prefer_crs=GEO_CRS)
gm = GridMapping.from_dataset(dataset, prefer_crs=GEO_CRS, tolerance=tol)
# Should pick the geographic one which is irregular
self.assertIn("Geographic", gm.crs.type_name)
self.assertEqual(False, gm.is_regular)

gm = GridMapping.from_dataset(
dataset, prefer_crs=GEO_CRS, prefer_is_regular=True
dataset, prefer_crs=GEO_CRS, prefer_is_regular=True, tolerance=tol
)
# Should pick the geographic one which is irregular
self.assertIn("Geographic", gm.crs.type_name)
Expand Down
7 changes: 5 additions & 2 deletions test/core/resampling/test_rectify.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,8 @@ def _assert_compute_and_extract_source_pixels(

target_rad = np.full((13, 13), np.nan, dtype=np.float64)

compute_var_image(source_ds.rad.values, dst_src_ij, target_rad, 0)
src_bbox = [0, 0, source_ds.rad.shape[-1], source_ds.rad.shape[-2]]
compute_var_image(source_ds.rad.values, dst_src_ij, target_rad, src_bbox, 0)

if not is_j_axis_up:
np.testing.assert_almost_equal(
Expand Down Expand Up @@ -975,7 +976,9 @@ def test_rectify_dataset(self):
]
)

source_gm = GridMapping.from_dataset(source_ds, prefer_crs=CRS_WGS84)
source_gm = GridMapping.from_dataset(
source_ds, prefer_crs=CRS_WGS84, tolerance=1e-6
)

target_ds = rectify_dataset(source_ds, source_gm=source_gm)
self.assertEqual(((5, 1), (5, 4)), target_ds.rrs_665.chunks)
Expand Down
57 changes: 33 additions & 24 deletions xcube/core/gridmapping/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import abc
import math
from typing import Tuple, Union, Dict
from typing import Union

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -85,6 +85,7 @@ def new_grid_mapping_from_coords(
crs: Union[str, pyproj.crs.CRS],
*,
xy_res: Union[Number, tuple[Number, Number]] = None,
xy_bbox: tuple[Number, Number, Number, Number] = None,
tile_size: Union[int, tuple[int, int]] = None,
tolerance: float = DEFAULT_TOLERANCE,
) -> GridMapping:
Expand Down Expand Up @@ -136,14 +137,14 @@ def new_grid_mapping_from_coords(
x_res = x_diff[0]
y_res = y_diff[0]
is_regular = da.allclose(x_diff, x_res, atol=tolerance) and da.allclose(
x_diff, y_res, atol=tolerance
y_diff, y_res, atol=tolerance
)
if is_regular:
x_res = round_to_fraction(x_res, 5, 0.25)
y_res = round_to_fraction(y_res, 5, 0.25)
x_res = round_to_fraction(float(x_res), 5, 0.25)
y_res = round_to_fraction(float(y_res), 5, 0.25)
else:
x_res = round_to_fraction(float(np.nanmedian(x_diff)), 2, 0.5)
y_res = round_to_fraction(float(np.nanmedian(y_diff)), 2, 0.5)
x_res = round_to_fraction(float(np.nanmedian(x_diff, axis=0)), 2, 0.5)
y_res = round_to_fraction(float(np.nanmedian(y_diff, axis=0)), 2, 0.5)

if (
tile_size is None
Expand Down Expand Up @@ -176,10 +177,10 @@ def new_grid_mapping_from_coords(
x = da.asarray(x_coords)
y = da.asarray(y_coords)

x_x_diff = _abs_no_nan(da.diff(x[0, :]))
x_y_diff = _abs_no_nan(da.diff(x[:, 0]))
y_x_diff = _abs_no_nan(da.diff(y[0, :]))
y_y_diff = _abs_no_nan(da.diff(y[:, 0]))
x_x_diff = _abs_no_nan(da.diff(x[0, : x.chunksize[1]]))
x_y_diff = _abs_no_nan(da.diff(x[: x.chunksize[0], 0]))
y_x_diff = _abs_no_nan(da.diff(y[0, : x.chunksize[0]]))
y_y_diff = _abs_no_nan(da.diff(y[: x.chunksize[1], 0]))

if not is_lon_360 and crs.is_geographic:
is_anti_meridian_crossed = da.any(da.max(x_x_diff) > 180) or da.any(
Expand All @@ -201,8 +202,8 @@ def new_grid_mapping_from_coords(
is_regular = (
da.allclose(x_x_diff, x_res, atol=tolerance)
and da.allclose(y_y_diff, y_res, atol=tolerance)
and da.all(x_y_diff == 0)
and da.all(y_x_diff == 0)
and da.allclose(x_y_diff, 0, atol=tolerance)
and da.allclose(y_x_diff, 0, atol=tolerance)
)

if not is_regular and xy_res is None:
Expand Down Expand Up @@ -276,7 +277,9 @@ def new_grid_mapping_from_coords(
)

# Guess j axis direction
is_j_axis_up = da.all(y_coords[0, :] < y_coords[-1, :]) or None
is_j_axis_up = da.all(
y_coords[0, : y.chunksize[1]] < y_coords[-1, : y.chunksize[1]]
)

assert_true(
x_res > 0 and y_res > 0,
Expand All @@ -285,11 +288,17 @@ def new_grid_mapping_from_coords(
)

x_res, y_res = _to_int_or_float(x_res), _to_int_or_float(y_res)
x_res_05, y_res_05 = x_res / 2, y_res / 2
x_min = _to_int_or_float(x_coords.min() - x_res_05)
y_min = _to_int_or_float(y_coords.min() - y_res_05)
x_max = _to_int_or_float(x_coords.max() + x_res_05)
y_max = _to_int_or_float(y_coords.max() + y_res_05)
if xy_bbox is None:
x_res_05, y_res_05 = x_res / 2, y_res / 2
x_min = _to_int_or_float(x_coords[..., 0].min() - x_res_05)
x_max = _to_int_or_float(x_coords[..., -1].max() + x_res_05)
if is_j_axis_up:
y_min = _to_int_or_float(float(y_coords[0, ...].min()) - y_res_05)
y_max = _to_int_or_float(float(y_coords[-1, ...].max()) + y_res_05)
else:
y_min = _to_int_or_float(float(y_coords[-1, ...].min()) - y_res_05)
y_max = _to_int_or_float(float(y_coords[0, ...].max()) + y_res_05)
xy_bbox = (x_min, y_min, x_max, y_max)

if cls is Coords1DGridMapping and is_regular:
from .regular import RegularGridMapping
Expand All @@ -302,7 +311,7 @@ def new_grid_mapping_from_coords(
crs=crs,
size=size,
tile_size=tile_size,
xy_bbox=(x_min, y_min, x_max, y_max),
xy_bbox=xy_bbox,
xy_res=(x_res, y_res),
xy_var_names=xy_var_names,
xy_dim_names=(str(x_dim), str(y_dim)),
Expand All @@ -313,13 +322,13 @@ def new_grid_mapping_from_coords(


def _abs_no_zero(array: Union[xr.DataArray, da.Array, np.ndarray]):
array = np.fabs(array)
return np.where(np.isclose(array, 0), np.nan, array)
array = da.fabs(array)
return da.where(da.isclose(array, 0), np.nan, array)


def _abs_no_nan(array: Union[xr.DataArray, da.Array, np.ndarray]):
array = np.fabs(array)
return np.where(np.logical_or(np.isnan(array), np.isclose(array, 0)), 0, array)
def _abs_no_nan(array: Union[da.Array, np.ndarray]):
array = da.fabs(array)
return da.where(da.logical_or(da.isnan(array), da.isclose(array, 0)), 0, array)


def grid_mapping_to_coords(
Expand Down
8 changes: 4 additions & 4 deletions xcube/core/gridmapping/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def new_grid_mapping_from_dataset(
prefer_crs: Union[str, pyproj.crs.CRS] = None,
prefer_is_regular: bool = None,
emit_warnings: bool = False,
tolerance: float = DEFAULT_TOLERANCE
tolerance: float = DEFAULT_TOLERANCE,
) -> Optional[GridMapping]:
# Note `crs` is used if CRS is known in advance,
# so the code forces its use. `prefer_crs` is used if
Expand Down Expand Up @@ -59,13 +59,13 @@ def new_grid_mapping_from_dataset(
if len(grid_mappings) > 1:
if prefer_crs is not None and prefer_is_regular is not None:
for gm in grid_mappings:
if gm.crs == prefer_crs and gm.is_regular == prefer_is_regular:
if gm.crs == prefer_crs and bool(gm.is_regular) == prefer_is_regular:
return gm
for gm in grid_mappings:
if (
gm.crs.is_geographic
and prefer_crs.is_geographic
and gm.is_regular == prefer_is_regular
and bool(gm.is_regular) == prefer_is_regular
):
return gm

Expand All @@ -79,7 +79,7 @@ def new_grid_mapping_from_dataset(

if prefer_is_regular is not None:
for gm in grid_mappings:
if gm.is_regular == prefer_is_regular:
if bool(gm.is_regular) == prefer_is_regular:
return gm

# Get arbitrary one (here: first)
Expand Down
4 changes: 3 additions & 1 deletion xcube/core/gridmapping/regular.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def _new_xy_coords(self) -> xr.DataArray:
xy_coords = da.concatenate(
[da.expand_dims(x_coords_2d, 0), da.expand_dims(y_coords_2d, 0)]
)
xy_coords = da.rechunk(xy_coords, chunks=(2, 512, 512))
xy_coords = da.rechunk(
xy_coords, chunks=(2, xy_coords.chunksize[1], xy_coords.chunksize[2])
)
xy_coords = xr.DataArray(
xy_coords,
dims=("coord", self.y_coords.dims[0], self.x_coords.dims[0]),
Expand Down
45 changes: 44 additions & 1 deletion xcube/core/gridmapping/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Permissions are hereby granted under the terms of the MIT License:
# https://opensource.org/licenses/MIT.

from typing import Union, Tuple
from typing import Union

import numpy as np
import pyproj
Expand All @@ -13,6 +13,7 @@
from .base import GridMapping
from .coords import new_grid_mapping_from_coords
from .helpers import _assert_valid_xy_names
from .helpers import _normalize_number_pair
from .helpers import Number
from .helpers import _normalize_crs

Expand Down Expand Up @@ -70,6 +71,47 @@ def _transform(block: np.ndarray) -> np.ndarray:
output_dtypes=[np.float64],
dask="parallelized",
)
if xy_res is not None:
if grid_mapping.is_j_axis_up:
gm_ymin = grid_mapping.y_coords[0].values
gm_ymax = grid_mapping.y_coords[-1].values
else:
gm_ymin = grid_mapping.y_coords[-1].values
gm_ymax = grid_mapping.y_coords[0].values
y_min = np.min(
transformer.transform(
grid_mapping.x_coords.values,
np.repeat(gm_ymin, grid_mapping.size[0]),
)[1]
)
y_max = np.max(
transformer.transform(
grid_mapping.x_coords.values,
np.repeat(gm_ymax, grid_mapping.size[0]),
)[1]
)
x_min = np.min(
transformer.transform(
np.repeat(grid_mapping.x_coords[0].values, grid_mapping.size[1]),
grid_mapping.y_coords.values,
)[0]
)
x_max = np.max(
transformer.transform(
np.repeat(grid_mapping.x_coords[-1].values, grid_mapping.size[1]),
grid_mapping.y_coords.values,
)[0]
)
x_res, y_res = _normalize_number_pair(xy_res)
x_res_05, y_res_05 = x_res / 2, y_res / 2
xy_bbox = (
x_min - x_res_05,
y_min - y_res_05,
x_max + x_res_05,
y_max + y_res_05,
)
else:
xy_bbox = None

xy_var_names = xy_var_names or ("transformed_x", "transformed_y")

Expand All @@ -90,6 +132,7 @@ def _transform(block: np.ndarray) -> np.ndarray:
y_coords=xy_coords[1].rename(xy_var_names[1]),
crs=target_crs,
xy_res=xy_res,
xy_bbox=xy_bbox,
tile_size=tile_size,
tolerance=tolerance,
)
Loading

0 comments on commit f3184e3

Please sign in to comment.