Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add encode_cf, decode_cf #69

Merged
merged 21 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ dmypy.json

# sphinx
doc/source/generated
doc/source/geo-encoded*

# ruff
.ruff_cache
doc/source/cube.joblib.compressed
doc/source/cube.pickle

cache/
cache/
122 changes: 71 additions & 51 deletions doc/source/io.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"xarray >= 2022.12.0",
"pyproj >= 3.0.0",
"shapely >= 2.0b1",
"cf_xarray >= 0.9.2",
]

[project.urls]
Expand Down
68 changes: 68 additions & 0 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,74 @@ def extract_points(
)
return result

def encode_cf(self) -> xr.Dataset:
    """Encode geometry variables and associated CRS with CF conventions.

    Operates on a copy of the wrapped object. Every distinct CRS found on the
    indexed geometry coordinates is written as a scalar grid-mapping
    coordinate: ``spatial_ref`` when there is exactly one CRS, otherwise
    ``spatial_ref_<i>``. Each variable that shares a dimension with an indexed
    geometry coordinate receives a ``grid_mapping`` attribute naming the
    matching coordinate, and the result is passed to
    ``cf_xarray.geometry.encode_geometries``.

    Geometry indexes that carry no CRS (``crs is None``) get no grid-mapping
    variable, and the variables along their dimensions are not tagged.

    Returns
    -------
    xr.Dataset
        The CF-encoded dataset.
    """
    import cf_xarray as cfxr

    ds = self._obj.copy()
    coords = self.geom_coords_indexed

    # TODO: this could use geoxarray, but is quite simple in any case
    # Adapted from rioxarray
    # preserve ordering for roundtripping
    unique_crs = []
    for _, xi in sorted(coords.xindexes.items()):
        # An index without a CRS has nothing to encode; calling ``to_cf`` on
        # ``None`` below would raise AttributeError.
        if xi.crs is not None and xi.crs not in unique_crs:
            unique_crs.append(xi.crs)
    if len(unique_crs) == 1:
        grid_mappings = {unique_crs[0]: "spatial_ref"}
    else:
        grid_mappings = {
            crs_: f"spatial_ref_{i}" for i, crs_ in enumerate(unique_crs)
        }

    for crs, grid_mapping in grid_mappings.items():
        grid_mapping_attrs = crs.to_cf()
        # TODO: not all CRS can be represented by CF grid_mappings
        # For now, we allow this.
        # if "grid_mapping_name" not in grid_mapping_attrs:
        #     raise ValueError
        wkt_str = crs.to_wkt()
        # The WKT is stored under both attribute names.
        grid_mapping_attrs["spatial_ref"] = wkt_str
        grid_mapping_attrs["crs_wkt"] = wkt_str
        ds.coords[grid_mapping] = xr.Variable(
            dims=(), data=0, attrs=grid_mapping_attrs
        )

    for coord_name, coord in coords.items():
        crs = coords.xindexes[coord_name].crs
        if crs is None:
            continue
        grid_mapping = grid_mappings[crs]
        geom_dims = set(coord.dims)
        # Tag every variable laid out along this geometry's dimension(s).
        for var in ds._variables.values():
            if geom_dims & set(var.dims):
                var.attrs["grid_mapping"] = grid_mapping

    return cfxr.geometry.encode_geometries(ds)

def decode_cf(self) -> xr.Dataset:
import cf_xarray as cfxr

decoded = cfxr.geometry.decode_geometries(self._obj.copy())
crs = {
name: CRS.from_user_input(var.attrs["crs_wkt"])
for name, var in decoded._variables.items()
if "crs_wkt" in var.attrs or "grid_mapping_name" in var.attrs
}
dims = decoded.xvec.geom_coords.dims
for dim in dims:
decoded = (
decoded.set_xindex(dim) if dim not in decoded._indexes else decoded
)
decoded = decoded.xvec.set_geom_indexes(
dim, crs=crs.get(decoded[dim].attrs.get("grid_mapping", None))
)
Comment on lines +1370 to +1372
Copy link
Contributor Author

@dcherian dcherian Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key buggy line. It always sets the index; we do not record which geometry dims were indexed at encode-time. What should we do here?

As an aside it'd be nice for set_geom_indexes to understand the grid_mapping convention. WDYT?

One approach: decode_cf does NOT set the new index, but the user does so manually. Instead set_geom_indexes learns how to interpret the grid_mapping convention so CRS is set properly by default.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it always sets the index, we do not record which geometry dims were indexed at encode-time. What should we do here?

Is that an issue if we just index all geom dims encoded in the file?

As an aside it'd be nice for set_geom_indexes to understand the grid_mapping convention. WDYT?

Not against it, but I don't really know what it would mean implementation-wise. Maybe just a simple call to pyproj.CRS.from_cf?

set_geom_indexes learns how to interpret the grid_mapping convention so CRS is set properly by default.

That would be preferable. I'm not a fan of asking users to set indexes after reading what was already indexed before writing.

for name in crs:
# remove spatial_ref so the coordinate system is only stored on the index
del decoded[name]
for var in decoded._variables.values():
if set(dims) & set(var.dims):
var.attrs.pop("grid_mapping", None)
return decoded


def _resolve_input(
positional: Mapping[Any, Any] | None,
Expand Down
43 changes: 41 additions & 2 deletions xvec/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def multi_dataset(geom_array, geom_array_z):

@pytest.fixture(scope="session")
def multi_geom_dataset(geom_array, geom_array_z):
return (
ds = (
xr.Dataset(
coords={
"geom": geom_array,
Expand All @@ -80,11 +80,32 @@ def multi_geom_dataset(geom_array, geom_array_z):
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs=26915)
)
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
Comment on lines +83 to +84
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you can't set these in GeometryIndex.create_variables?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apart from "no one thought about that until now", I am not aware of any.

return ds


@pytest.fixture(scope="session")
def multi_geom_multi_crs_dataset(geom_array, geom_array_z):
    """Dataset with two geometry coordinates, each indexed with its own CRS.

    The default indexes on ``geom`` and ``geom_z`` are dropped and replaced by
    ``GeometryIndex`` instances, and each coordinate's ``crs`` attribute is
    set from its index.
    """
    geometry_coords = {"geom": geom_array, "geom_z": geom_array_z}
    ds = xr.Dataset(coords=geometry_coords).drop_indexes(["geom", "geom_z"])
    ds = ds.set_xindex("geom", GeometryIndex, crs=26915)
    ds = ds.set_xindex("geom_z", GeometryIndex, crs="EPSG:4362")
    for coord_name in ("geom", "geom_z"):
        ds[coord_name].attrs["crs"] = ds.xindexes[coord_name].crs
    return ds


@pytest.fixture(scope="session")
def multi_geom_no_index_dataset(geom_array, geom_array_z):
return (
ds = (
xr.Dataset(
coords={
"geom": geom_array,
Expand All @@ -96,6 +117,9 @@ def multi_geom_no_index_dataset(geom_array, geom_array_z):
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs=26915)
)
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
return ds


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -157,3 +181,18 @@ def traffic_dataset(geom_array):
"day": pd.date_range("2023-01-01", periods=10),
},
).xvec.set_geom_indexes(["origin", "destination"], crs=26915)


# Names of the dataset fixtures that ``all_datasets`` is parametrized over.
_ALL_DATASET_FIXTURES = [
    "first_geom_dataset",
    "multi_dataset",
    "multi_geom_dataset",
    "multi_geom_no_index_dataset",
    "multi_geom_multi_crs_dataset",
    "traffic_dataset",
]


@pytest.fixture(scope="session", params=_ALL_DATASET_FIXTURES)
def all_datasets(request):
    """Return, in turn, the value of each dataset fixture named in ``params``."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
18 changes: 18 additions & 0 deletions xvec/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,3 +674,21 @@ def test_extract_points_array():
geometry=4326
),
)


def test_cf_roundtrip(all_datasets):
    """Encode then decode each dataset and check it comes back identical."""
    ds = all_datasets
    pristine = ds.copy(deep=True)
    encoded = ds.xvec.encode_cf()

    geom_indexes = ds.xvec.geom_coords_indexed.xindexes.values()
    unique_crs = {index.crs for index in geom_indexes if index.crs}
    if unique_crs:
        # One variable carrying ``crs_wkt`` is expected per distinct CRS.
        wkt_vars = [
            var for var in encoded._variables.values() if "crs_wkt" in var.attrs
        ]
        assert len(unique_crs) == len(wkt_vars)

    roundtripped = encoded.xvec.decode_cf()

    xr.testing.assert_identical(ds, roundtripped)
    # make sure we didn't modify the original dataset.
    xr.testing.assert_identical(ds, pristine)
Loading