diff --git a/CHANGELOG.md b/CHANGELOG.md index 465b646..1e97d87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/). Fixed: - Attributes are now properly preserved when updating coordinates during pre-formatting for regridding ([#54](https://github.com/xarray-contrib/xarray-regrid/pull/54)). + - Handle datasets with inconsistent chunksizes during pre-formatting ([#57](https://github.com/xarray-contrib/xarray-regrid/pull/57)). ## 0.4.0 (2024-09-26) diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 8d7c3e2..2ab67d9 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -119,7 +119,7 @@ def conservative_regrid_dataset( # Create weights array and coverage mask for each regridding dim weights = {} covered = {} - for coord in coords: + for coord in coords: # noqa: PLC0206 covered[coord] = (coords[coord] <= data[coord].max()) & ( coords[coord] >= data[coord].min() ) @@ -137,7 +137,7 @@ def conservative_regrid_dataset( weights[coord] = da_weights # Apply the weights, using a unique set that matches chunking of each array - for array in data_vars.keys(): + for array in data_vars.keys(): # noqa: PLC0206 var_weights = {} for coord, weight_array in weights.items(): var_input_chunks = data_vars[array].chunksizes.get(coord) diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index 9cbcc03..59a6e8f 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -249,8 +249,6 @@ def format_for_regrid( Currently handles padding of spherical geometry if lat/lon coordinates can be inferred and the domain size requires boundary padding. """ - orig_chunksizes = obj.chunksizes - # Special-cased coordinates with accepted names and formatting function coord_handlers: dict[str, CoordHandler] = { "lat": {"names": ["lat", "latitude"], "func": format_lat}, @@ -270,15 +268,22 @@ def format_for_regrid( formatted_coords[coord_type] = str(coord) # Apply formatting + result = obj.copy() for coord_type, coord in formatted_coords.items(): # Make sure formatted coords are sorted - obj = ensure_monotonic(obj, coord) + result = ensure_monotonic(result, coord) target = ensure_monotonic(target, coord) - obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords) + result = coord_handlers[coord_type]["func"](result, target, formatted_coords) + # Coerce back to a single chunk if that's what was passed - if len(orig_chunksizes.get(coord, [])) == 1: - obj = obj.chunk({coord: -1}) - return obj + if isinstance(obj, xr.DataArray) and len(obj.chunksizes.get(coord, ())) == 1: + result = result.chunk({coord: -1}) + elif isinstance(obj, xr.Dataset): + for var in result.data_vars: + if len(obj[var].chunksizes.get(coord, ())) == 1: + result[var] = result[var].chunk({coord: -1}) + + return result def format_lat( diff --git a/tests/test_format.py b/tests/test_format.py index 3c9a35b..207e4e5 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -216,3 +216,40 @@ def test_stats(): # And preserve integer dtypes assert formatted.data.dtype == source.data.dtype assert (formatted.longitude.diff("longitude") == 1).all() + + +def test_maintain_single_chunk(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90 - dx_source / 2, + east=360 - dx_source / 2, + south=-90 + dx_source / 2, + west=0 + dx_source / 2, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + source["a"] = xr.DataArray( + np.ones((source.latitude.size, source.longitude.size)), + dims=["latitude", "longitude"], + coords={"latitude": source.latitude, "longitude": source.longitude}, + ).chunk({"latitude": -1, "longitude": -1}) + source["b"] = source.a.copy().chunk({"latitude": 45, "longitude": 90}) + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + # dataset + formatted = format_for_regrid(source, target) + assert formatted.a.chunks == ((92,), (182,)) + assert formatted.b.chunks == ((1, 45, 45, 1), (1, 90, 90, 1)) + + # dataarray + formatted = format_for_regrid(source.a, target) + assert formatted.chunks == ((92,), (182,))