Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle non-uniform chunk sizes in formatting #57

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/xarray_regrid/methods/conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def conservative_regrid_dataset(
# Create weights array and coverage mask for each regridding dim
weights = {}
covered = {}
for coord in coords:
for coord in coords: # noqa: PLC0206
covered[coord] = (coords[coord] <= data[coord].max()) & (
coords[coord] >= data[coord].min()
)
Expand All @@ -137,7 +137,7 @@ def conservative_regrid_dataset(
weights[coord] = da_weights

# Apply the weights, using a unique set that matches chunking of each array
for array in data_vars.keys():
for array in data_vars.keys(): # noqa: PLC0206
var_weights = {}
for coord, weight_array in weights.items():
var_input_chunks = data_vars[array].chunksizes.get(coord)
Expand Down
19 changes: 12 additions & 7 deletions src/xarray_regrid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,6 @@ def format_for_regrid(
Currently handles padding of spherical geometry if lat/lon coordinates can
be inferred and the domain size requires boundary padding.
"""
orig_chunksizes = obj.chunksizes

# Special-cased coordinates with accepted names and formatting function
coord_handlers: dict[str, CoordHandler] = {
"lat": {"names": ["lat", "latitude"], "func": format_lat},
Expand All @@ -270,15 +268,22 @@ def format_for_regrid(
formatted_coords[coord_type] = str(coord)

# Apply formatting
result = obj.copy()
for coord_type, coord in formatted_coords.items():
# Make sure formatted coords are sorted
obj = ensure_monotonic(obj, coord)
result = ensure_monotonic(result, coord)
target = ensure_monotonic(target, coord)
obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords)
result = coord_handlers[coord_type]["func"](result, target, formatted_coords)

# Coerce back to a single chunk if that's what was passed
if len(orig_chunksizes.get(coord, [])) == 1:
obj = obj.chunk({coord: -1})
return obj
if isinstance(obj, xr.DataArray) and len(obj.chunksizes.get(coord, ())) == 1:
result = result.chunk({coord: -1})
elif isinstance(obj, xr.Dataset):
for var in result.data_vars:
if len(obj[var].chunksizes.get(coord, ())) == 1:
result[var] = result[var].chunk({coord: -1})

return result


def format_lat(
Expand Down
37 changes: 37 additions & 0 deletions tests/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,40 @@ def test_stats():
# And preserve integer dtypes
assert formatted.data.dtype == source.data.dtype
assert (formatted.longitude.diff("longitude") == 1).all()


def test_maintain_single_chunk():
dx_source = 2
source = xarray_regrid.Grid(
north=90 - dx_source / 2,
east=360 - dx_source / 2,
south=-90 + dx_source / 2,
west=0 + dx_source / 2,
resolution_lat=dx_source,
resolution_lon=dx_source,
).create_regridding_dataset()
source["a"] = xr.DataArray(
np.ones((source.latitude.size, source.longitude.size)),
dims=["latitude", "longitude"],
coords={"latitude": source.latitude, "longitude": source.longitude},
).chunk({"latitude": -1, "longitude": -1})
source["b"] = source.a.copy().chunk({"latitude": 45, "longitude": 90})

dx_target = 1
target = xarray_regrid.Grid(
north=90,
east=360,
south=-90,
west=0,
resolution_lat=dx_target,
resolution_lon=dx_target,
).create_regridding_dataset()

# dataset
formatted = format_for_regrid(source, target)
assert formatted.a.chunks == ((92,), (182,))
assert formatted.b.chunks == ((1, 45, 45, 1), (1, 90, 90, 1))

# dataarray
formatted = format_for_regrid(source.a, target)
assert formatted.chunks == ((92,), (182,))
Loading