Skip to content

Commit

Permalink
More geometry functionality
Browse files Browse the repository at this point in the history
1. Associate geometry vars as coordinates when we can
2. Add `cf.geometries`
3. Geometries in repr
4. Allow indexing by `"geometry"` or any geometry type.
  • Loading branch information
dcherian committed Jun 25, 2024
1 parent 59c3347 commit cb7f5d0
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 97 deletions.
140 changes: 111 additions & 29 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
from . import sgrid
from .criteria import (
_DSG_ROLES,
_GEOMETRY_TYPES,
cf_role_criteria,
coordinate_criteria,
geometry_var_criteria,
grid_mapping_var_criteria,
regex,
)
Expand All @@ -39,6 +41,7 @@
_format_data_vars,
_format_dsg_roles,
_format_flags,
_format_geometries,
_format_sgrid,
_maybe_panel,
)
Expand Down Expand Up @@ -198,7 +201,9 @@ def _get_groupby_time_accessor(


def _get_custom_criteria(
obj: DataArray | Dataset, key: Hashable, criteria: Mapping | None = None
obj: DataArray | Dataset,
key: Hashable,
criteria: Iterable[Mapping] | Mapping | None = None,
) -> list[Hashable]:
"""
Translate from axis, coord, or custom name to variable name.
Expand Down Expand Up @@ -227,18 +232,16 @@ def _get_custom_criteria(
except ImportError:
from re import match as regex_match # type: ignore[no-redef]

if isinstance(obj, DataArray):
obj = obj._to_temp_dataset()
variables = obj._variables

if criteria is None:
if not OPTIONS["custom_criteria"]:
return []
criteria = OPTIONS["custom_criteria"]

if criteria is not None:
criteria_iter = always_iterable(criteria, allowed=(tuple, list, set))
if isinstance(obj, DataArray):
obj = obj._to_temp_dataset()
variables = obj._variables

criteria_iter = always_iterable(criteria, allowed=(tuple, list, set))
criteria_map = ChainMap(*criteria_iter)
results: set = set()
if key in criteria_map:
Expand Down Expand Up @@ -367,6 +370,21 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
return list(results)


def _parse_related_geometry_vars(attrs: Mapping) -> tuple[Hashable]:
names = itertools.chain(
*[
attrs.get(attr, "").split(" ")
for attr in [
"interior_ring",
"node_coordinates",
"node_count",
"part_node_count",
]
]
)
return tuple(n for n in names if n)


def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
"""
Translate from key (either CF key or variable name) to its bounds' variable names.
Expand Down Expand Up @@ -470,8 +488,14 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
"""
all_mappers: tuple[Mapper] = (
_get_custom_criteria,
functools.partial(_get_custom_criteria, criteria=cf_role_criteria), # type: ignore[assignment]
functools.partial(_get_custom_criteria, criteria=grid_mapping_var_criteria),
functools.partial(
_get_custom_criteria,
criteria=(
cf_role_criteria,
grid_mapping_var_criteria,
geometry_var_criteria,
),
), # type: ignore[assignment]
_get_axis_coord,
_get_measure,
_get_grid_mapping_name,
Expand Down Expand Up @@ -821,6 +845,23 @@ def check_results(names, key):
successful[k] = bool(grid_mapping)
if grid_mapping:
varnames.extend(grid_mapping)
elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES):
geometries = _get_all(obj, k)
if geometries and k in _GEOMETRY_TYPES:
new = itertools.chain(
_parse_related_geometry_vars(
ChainMap(obj[g].attrs, obj[g].encoding)
)
for g in geometries
)
geometries.extend(*new)
if len(geometries) > 1 and scalar_key:
raise ValueError(
f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead."
)
successful[k] = bool(geometries)
if geometries:
varnames.extend(geometries)
elif k in custom_criteria or k in cf_role_criteria:
names = _get_all(obj, k)
check_results(names, k)
Expand Down Expand Up @@ -1559,8 +1600,7 @@ def _generate_repr(self, rich=False):
_format_flags(self, rich), title="Flag Variable", rich=rich
)

roles = self.cf_roles
if roles:
if roles := self.cf_roles:
if any(role in roles for role in _DSG_ROLES):
yield _maybe_panel(
_format_dsg_roles(self, dims, rich),
Expand All @@ -1576,6 +1616,13 @@ def _generate_repr(self, rich=False):
rich=rich,
)

if self.geometries:
yield _maybe_panel(
_format_geometries(self, dims, rich),
title="Geometries",
rich=rich,
)

yield _maybe_panel(
_format_coordinates(self, dims, coords, rich),
title="Coordinates",
Expand Down Expand Up @@ -1755,12 +1802,42 @@ def cf_roles(self) -> dict[str, list[Hashable]]:

vardict: dict[str, list[Hashable]] = {}
for k, v in variables.items():
if "cf_role" in v.attrs:
role = v.attrs["cf_role"]
attrs_or_encoding = ChainMap(v.attrs, v.encoding)
if role := attrs_or_encoding.get("cf_role", None):
vardict[role] = vardict.setdefault(role, []) + [k]

return {role_: sort_maybe_hashable(v) for role_, v in vardict.items()}

@property
def geometries(self) -> dict[str, list[Hashable]]:
    """
    Mapping of geometry type names to geometry container variable names.

    Every variable whose attributes or encoding carry a ``geometry`` entry
    is inspected; the container it names is looked up on the wrapped object
    and grouped by that container's ``geometry_type`` attribute.

    Returns
    -------
    dict
        Dictionary mapping geometry type names to the names of the geometry
        container variables of that type.

    References
    ----------
    Please refer to the Geometries section of the CF conventions document :
    http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#geometries
    """
    if isinstance(self._obj, Dataset):
        variables = self._obj._variables
    elif isinstance(self._obj, DataArray):
        variables = {"_": self._obj._variable}

    by_type: dict[str, list[Hashable]] = {}
    for var in variables.values():
        container = ChainMap(var.attrs, var.encoding).get("geometry", None)
        if not container:
            continue
        # NOTE(review): the container must exist on the object and define
        # ``geometry_type`` in its attrs; otherwise this lookup raises.
        gtype = self._obj[container].attrs["geometry_type"]
        names = by_type.setdefault(gtype, [])
        # Several data variables may point at the same container.
        if container not in names:
            names.append(container)

    return {gtype: sort_maybe_hashable(names) for gtype, names in by_type.items()}

def get_associated_variable_names(
self, name: Hashable, skip_bounds: bool = False, error: bool = True
) -> dict[str, list[Hashable]]:
Expand Down Expand Up @@ -1795,15 +1872,15 @@ def get_associated_variable_names(
"bounds",
"grid_mapping",
"grid",
"geometry",
]

coords: dict[str, list[Hashable]] = {k: [] for k in keys}
attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)

coordinates = attrs_or_encoding.get("coordinates", None)
# Handles case where the coordinates attribute is None
# This is used to tell xarray to not write a coordinates attribute
if coordinates:
if coordinates := attrs_or_encoding.get("coordinates", None):
coords["coordinates"] = coordinates.split(" ")

if "cell_measures" in attrs_or_encoding:
Expand All @@ -1822,27 +1899,32 @@ def get_associated_variable_names(
)
coords["cell_measures"] = []

if (
isinstance(self._obj, Dataset)
and "ancillary_variables" in attrs_or_encoding
if isinstance(self._obj, Dataset) and (
anc := attrs_or_encoding.get("ancillary_variables", None)
):
coords["ancillary_variables"] = attrs_or_encoding[
"ancillary_variables"
].split(" ")
coords["ancillary_variables"] = anc.split(" ")

if not skip_bounds:
if "bounds" in attrs_or_encoding:
coords["bounds"] = [attrs_or_encoding["bounds"]]
if bounds := attrs_or_encoding.get("bounds", None):
coords["bounds"] = [bounds]
for dim in self._obj[name].dims:
dbounds = self._obj[dim].attrs.get("bounds", None)
if dbounds:
if dbounds := self._obj[dim].attrs.get("bounds", None):
coords["bounds"].append(dbounds)

if "grid" in attrs_or_encoding:
coords["grid"] = [attrs_or_encoding["grid"]]
for attrname in ["grid", "grid_mapping"]:
if maybe := attrs_or_encoding.get(attrname, None):
coords[attrname] = [maybe]

if "grid_mapping" in attrs_or_encoding:
coords["grid_mapping"] = [attrs_or_encoding["grid_mapping"]]
more: Sequence[Hashable] = ()
if geometry_var := attrs_or_encoding.get("geometry", None):
coords["geometry"] = [geometry_var]
_attrs = ChainMap(
self._obj[geometry_var].attrs, self._obj[geometry_var].encoding
)
more = _parse_related_geometry_vars(_attrs)
elif "geometry_type" in attrs_or_encoding:
more = _parse_related_geometry_vars(attrs_or_encoding)
coords["geometry"].extend(more)

allvars = itertools.chain(*coords.values())
missing = set(allvars) - set(self._maybe_to_dataset()._variables)
Expand Down
9 changes: 9 additions & 0 deletions cf_xarray/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from collections.abc import Mapping, MutableMapping
from typing import Any

#: CF Roles understood by cf-xarray. These are the recognized values of the
#: ``cf_role`` attribute for Discrete Sampling Geometry (DSG) features.
_DSG_ROLES = ["timeseries_id", "profile_id", "trajectory_id"]
#: Geometry types understood by cf-xarray, i.e. the recognized values of a
#: geometry container's ``geometry_type`` attribute.
_GEOMETRY_TYPES = ("line", "point", "polygon")

cf_role_criteria: Mapping[str, Mapping[str, str]] = {
k: {"cf_role": k}
Expand All @@ -31,6 +34,12 @@
"grid_mapping": {"grid_mapping_name": re.compile(".")}
}

# A geometry container is anything with a geometry_type attribute.
# The generic "geometry" key matches any non-empty geometry_type, while each
# specific geometry type name matches only containers of exactly that type.
geometry_var_criteria: Mapping[str, Mapping[str, Any]] = dict(
    geometry={"geometry_type": re.compile(".")},
    **{gtype: {"geometry_type": gtype} for gtype in _GEOMETRY_TYPES},
)

coordinate_criteria: MutableMapping[str, MutableMapping[str, tuple]] = {
"latitude": {
"standard_name": ("latitude",),
Expand Down
29 changes: 29 additions & 0 deletions cf_xarray/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,3 +748,32 @@ def _create_inexact_bounds():
node_coordinates="node_lon node_lat node_elevation",
),
)


def point_dataset():
    """
    Build a Dataset holding shapely point geometries.

    Returns
    -------
    Dataset
        Dataset with a single variable ``geometry`` along dimension ``index``
        containing one ``MultiPoint`` followed by three ``Point`` objects.
    """
    from shapely.geometry import MultiPoint, Point

    shapes = [
        MultiPoint([(1.0, 2.0), (2.0, 3.0)]),
        Point(3.0, 4.0),
        Point(4.0, 5.0),
        Point(3.0, 4.0),
    ]
    return xr.DataArray(shapes, dims=("index",), name="geometry").to_dataset()


def encoded_point_dataset():
    """
    Build the encoded counterpart of :func:`point_dataset`.

    The geometries from ``point_dataset()`` are passed through
    ``encode_geometries`` and a ``data`` variable along ``index`` is added
    whose ``geometry`` attribute is ``"geometry_container"``.

    Returns
    -------
    Dataset
        The encoded dataset with the extra ``data`` variable.
    """
    from .geometry import encode_geometries

    encoded = encode_geometries(point_dataset())
    values = np.arange(encoded.sizes["index"])
    # presumably "geometry_container" is the container variable created by
    # encode_geometries — verify against its implementation.
    encoded["data"] = ("index", values, {"geometry": "geometry_container"})
    return encoded
11 changes: 11 additions & 0 deletions cf_xarray/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,17 @@ def _format_dsg_roles(accessor, dims, rich):
)


def _format_geometries(accessor, dims, rich):
    """
    Yield the "CF Geometries" text section for the accessor repr.

    The section is built by ``make_text_section`` from the accessor's
    ``geometries`` attribute; ``dims`` and ``rich`` are forwarded unchanged.
    """
    section = make_text_section(
        accessor, "CF Geometries", "geometries", dims=dims, rich=rich
    )
    yield section


def _format_coordinates(accessor, dims, coords, rich):
from .accessor import _AXIS_NAMES, _CELL_MEASURES, _COORD_NAMES

Expand Down
4 changes: 3 additions & 1 deletion cf_xarray/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def encode_geometries(ds: xr.Dataset):
for varname, var in ds._variables.items():
if varname == name:
continue
# TODO: this is incomplete. It works for vector data cubes where one of the geometry vars
# is a dimension coordinate.
if name in var.dims:
var = var.copy()
var._attrs = copy.deepcopy(var._attrs)
Expand Down Expand Up @@ -244,7 +246,7 @@ def reshape_unique_geometries(
out[geom_var] = ds[geom_var].isel({old_name: unique_indexes})
if old_name not in ds.coords:
# If there was no coord before, drop the dummy one we made.
out = out.drop_vars(old_name)
out = out.drop_vars(old_name) # type: ignore[arg-type,unused-ignore] # Hashable/str stuff
return out


Expand Down
55 changes: 55 additions & 0 deletions cf_xarray/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
import pytest
import xarray as xr


@pytest.fixture(scope="session")
def geometry_ds():
    """
    Session-scoped pair of equivalent point-geometry datasets.

    Skipped when shapely is not importable.

    Returns
    -------
    tuple of Dataset
        ``(cf_ds, shp_ds)``: ``cf_ds`` describes the geometries through node
        coordinate variables and a ``geometry_container`` variable, while
        ``shp_ds`` stores the shapely objects in a ``geometry`` variable.
    """
    pytest.importorskip("shapely")

    from shapely.geometry import MultiPoint, Point

    shapes = [
        MultiPoint([(1.0, 2.0), (2.0, 3.0)]),
        Point(3.0, 4.0),
        Point(4.0, 5.0),
        Point(3.0, 4.0),
    ]
    # Pre-allocate an object array and fill it afterwards; this works around a
    # numpy deprecation warning caused by the array interface of shapely
    # geometries.
    geoms = np.empty(len(shapes), dtype=object)
    geoms[:] = shapes

    data = xr.DataArray(
        range(len(geoms)),
        dims=("index",),
        attrs={"coordinates": "crd_x crd_y"},
    )
    time = xr.DataArray([0, 0, 0, 1], dims=("index",))
    base = xr.Dataset({"data": data, "time": time})

    shp_ds = base.assign(geometry=xr.DataArray(geoms, dims=("index",)))
    # Set only after shp_ds has been derived, since the attribute should not
    # be present in shp_ds.
    base.data.attrs["geometry"] = "geometry_container"

    container = xr.DataArray(
        attrs={
            "geometry_type": "point",
            "node_count": "node_count",
            "node_coordinates": "x y",
            "coordinates": "crd_x crd_y",
        }
    )
    cf_vars = {
        "x": xr.DataArray(
            [1.0, 2.0, 3.0, 4.0, 3.0], dims=("node",), attrs={"axis": "X"}
        ),
        "y": xr.DataArray(
            [2.0, 3.0, 4.0, 5.0, 4.0], dims=("node",), attrs={"axis": "Y"}
        ),
        "node_count": xr.DataArray([2, 1, 1, 1], dims=("index",)),
        "crd_x": xr.DataArray(
            [1.0, 3.0, 4.0, 3.0], dims=("index",), attrs={"nodes": "x"}
        ),
        "crd_y": xr.DataArray(
            [2.0, 4.0, 5.0, 4.0], dims=("index",), attrs={"nodes": "y"}
        ),
        "geometry_container": container,
    }
    cf_ds = base.assign(cf_vars).set_coords(["x", "y", "crd_x", "crd_y"])

    return cf_ds, shp_ds
Loading

0 comments on commit cb7f5d0

Please sign in to comment.