Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add geometry encoding and decoding functions. #517

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cf_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@
from .options import set_options # noqa
from .utils import _get_version

from . import geometry # noqa

__version__ = _get_version()
140 changes: 111 additions & 29 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
from . import sgrid
from .criteria import (
_DSG_ROLES,
_GEOMETRY_TYPES,
cf_role_criteria,
coordinate_criteria,
geometry_var_criteria,
grid_mapping_var_criteria,
regex,
)
Expand All @@ -39,6 +41,7 @@
_format_data_vars,
_format_dsg_roles,
_format_flags,
_format_geometries,
_format_sgrid,
_maybe_panel,
)
Expand Down Expand Up @@ -198,7 +201,9 @@ def _get_groupby_time_accessor(


def _get_custom_criteria(
obj: DataArray | Dataset, key: Hashable, criteria: Mapping | None = None
obj: DataArray | Dataset,
key: Hashable,
criteria: Iterable[Mapping] | Mapping | None = None,
) -> list[Hashable]:
"""
Translate from axis, coord, or custom name to variable name.
Expand Down Expand Up @@ -227,18 +232,16 @@ def _get_custom_criteria(
except ImportError:
from re import match as regex_match # type: ignore[no-redef]

if isinstance(obj, DataArray):
obj = obj._to_temp_dataset()
variables = obj._variables

if criteria is None:
if not OPTIONS["custom_criteria"]:
return []
criteria = OPTIONS["custom_criteria"]

if criteria is not None:
criteria_iter = always_iterable(criteria, allowed=(tuple, list, set))
if isinstance(obj, DataArray):
obj = obj._to_temp_dataset()
variables = obj._variables

criteria_iter = always_iterable(criteria, allowed=(tuple, list, set))
criteria_map = ChainMap(*criteria_iter)
results: set = set()
if key in criteria_map:
Expand Down Expand Up @@ -367,6 +370,21 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
return list(results)


def _parse_related_geometry_vars(attrs: Mapping) -> tuple[Hashable]:
names = itertools.chain(
*[
attrs.get(attr, "").split(" ")
for attr in [
"interior_ring",
"node_coordinates",
"node_count",
"part_node_count",
]
]
)
return tuple(n for n in names if n)


def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
"""
Translate from key (either CF key or variable name) to its bounds' variable names.
Expand Down Expand Up @@ -470,8 +488,14 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
"""
all_mappers: tuple[Mapper] = (
_get_custom_criteria,
functools.partial(_get_custom_criteria, criteria=cf_role_criteria), # type: ignore[assignment]
functools.partial(_get_custom_criteria, criteria=grid_mapping_var_criteria),
functools.partial(
_get_custom_criteria,
criteria=(
cf_role_criteria,
grid_mapping_var_criteria,
geometry_var_criteria,
),
), # type: ignore[assignment]
_get_axis_coord,
_get_measure,
_get_grid_mapping_name,
Expand Down Expand Up @@ -821,6 +845,23 @@ def check_results(names, key):
successful[k] = bool(grid_mapping)
if grid_mapping:
varnames.extend(grid_mapping)
elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES):
geometries = _get_all(obj, k)
if geometries and k in _GEOMETRY_TYPES:
new = itertools.chain(
_parse_related_geometry_vars(
ChainMap(obj[g].attrs, obj[g].encoding)
)
for g in geometries
)
geometries.extend(*new)
if len(geometries) > 1 and scalar_key:
raise ValueError(
f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead."
)
successful[k] = bool(geometries)
if geometries:
varnames.extend(geometries)
elif k in custom_criteria or k in cf_role_criteria:
names = _get_all(obj, k)
check_results(names, k)
Expand Down Expand Up @@ -1559,8 +1600,7 @@ def _generate_repr(self, rich=False):
_format_flags(self, rich), title="Flag Variable", rich=rich
)

roles = self.cf_roles
if roles:
if roles := self.cf_roles:
if any(role in roles for role in _DSG_ROLES):
yield _maybe_panel(
_format_dsg_roles(self, dims, rich),
Expand All @@ -1576,6 +1616,13 @@ def _generate_repr(self, rich=False):
rich=rich,
)

if self.geometries:
yield _maybe_panel(
_format_geometries(self, dims, rich),
title="Geometries",
rich=rich,
)

yield _maybe_panel(
_format_coordinates(self, dims, coords, rich),
title="Coordinates",
Expand Down Expand Up @@ -1755,12 +1802,42 @@ def cf_roles(self) -> dict[str, list[Hashable]]:

vardict: dict[str, list[Hashable]] = {}
for k, v in variables.items():
if "cf_role" in v.attrs:
role = v.attrs["cf_role"]
attrs_or_encoding = ChainMap(v.attrs, v.encoding)
if role := attrs_or_encoding.get("cf_role", None):
vardict[role] = vardict.setdefault(role, []) + [k]

return {role_: sort_maybe_hashable(v) for role_, v in vardict.items()}

@property
def geometries(self) -> dict[str, list[Hashable]]:
    """
    Mapping geometry type names to variable names.

    Returns
    -------
    dict
        Dictionary mapping each geometry type (the ``geometry_type`` of a
        geometry container) to the names of the container variables of that
        type that are referenced through a ``geometry`` attribute.

    Raises
    ------
    KeyError
        If a referenced geometry container defines ``geometry_type`` in
        neither its attrs nor its encoding.

    References
    ----------
    Please refer to the CF conventions document, section 7.5 "Geometries" : http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#geometries
    """
    if isinstance(self._obj, Dataset):
        variables = self._obj._variables
    elif isinstance(self._obj, DataArray):
        variables = {"_": self._obj._variable}

    vardict: dict[str, list[Hashable]] = {}
    for v in variables.values():
        attrs_or_encoding = ChainMap(v.attrs, v.encoding)
        if geometry := attrs_or_encoding.get("geometry", None):
            container = self._obj[geometry]
            # Look in attrs first, then encoding, matching how the
            # ``geometry`` attribute itself is resolved just above.
            gtype = ChainMap(container.attrs, container.encoding)["geometry_type"]
            containers = vardict.setdefault(gtype, [])
            # Several data variables may point at the same container.
            if geometry not in containers:
                containers.append(geometry)
    return {type_: sort_maybe_hashable(names) for type_, names in vardict.items()}

def get_associated_variable_names(
self, name: Hashable, skip_bounds: bool = False, error: bool = True
) -> dict[str, list[Hashable]]:
Expand Down Expand Up @@ -1795,15 +1872,15 @@ def get_associated_variable_names(
"bounds",
"grid_mapping",
"grid",
"geometry",
]

coords: dict[str, list[Hashable]] = {k: [] for k in keys}
attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)

coordinates = attrs_or_encoding.get("coordinates", None)
# Handles case where the coordinates attribute is None
# This is used to tell xarray to not write a coordinates attribute
if coordinates:
if coordinates := attrs_or_encoding.get("coordinates", None):
coords["coordinates"] = coordinates.split(" ")

if "cell_measures" in attrs_or_encoding:
Expand All @@ -1822,27 +1899,32 @@ def get_associated_variable_names(
)
coords["cell_measures"] = []

if (
isinstance(self._obj, Dataset)
and "ancillary_variables" in attrs_or_encoding
if isinstance(self._obj, Dataset) and (
anc := attrs_or_encoding.get("ancillary_variables", None)
):
coords["ancillary_variables"] = attrs_or_encoding[
"ancillary_variables"
].split(" ")
coords["ancillary_variables"] = anc.split(" ")

if not skip_bounds:
if "bounds" in attrs_or_encoding:
coords["bounds"] = [attrs_or_encoding["bounds"]]
if bounds := attrs_or_encoding.get("bounds", None):
coords["bounds"] = [bounds]
for dim in self._obj[name].dims:
dbounds = self._obj[dim].attrs.get("bounds", None)
if dbounds:
if dbounds := self._obj[dim].attrs.get("bounds", None):
coords["bounds"].append(dbounds)

if "grid" in attrs_or_encoding:
coords["grid"] = [attrs_or_encoding["grid"]]
for attrname in ["grid", "grid_mapping"]:
if maybe := attrs_or_encoding.get(attrname, None):
coords[attrname] = [maybe]

if "grid_mapping" in attrs_or_encoding:
coords["grid_mapping"] = [attrs_or_encoding["grid_mapping"]]
more: Sequence[Hashable] = ()
if geometry_var := attrs_or_encoding.get("geometry", None):
coords["geometry"] = [geometry_var]
_attrs = ChainMap(
self._obj[geometry_var].attrs, self._obj[geometry_var].encoding
)
more = _parse_related_geometry_vars(_attrs)
elif "geometry_type" in attrs_or_encoding:
more = _parse_related_geometry_vars(attrs_or_encoding)
coords["geometry"].extend(more)

allvars = itertools.chain(*coords.values())
missing = set(allvars) - set(self._maybe_to_dataset()._variables)
Expand Down
9 changes: 9 additions & 0 deletions cf_xarray/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from collections.abc import Mapping, MutableMapping
from typing import Any

#: ``cf_role`` attribute values understood by cf-xarray
#: (presumably the discrete sampling geometry roles, going by the name).
_DSG_ROLES = ["timeseries_id", "profile_id", "trajectory_id"]
#: ``geometry_type`` attribute values understood by cf-xarray
_GEOMETRY_TYPES = ("line", "point", "polygon")

cf_role_criteria: Mapping[str, Mapping[str, str]] = {
k: {"cf_role": k}
Expand All @@ -31,6 +34,12 @@
"grid_mapping": {"grid_mapping_name": re.compile(".")}
}

# A geometry container is anything with a geometry_type attribute: the generic
# "geometry" key matches on the regex ".", while each name in _GEOMETRY_TYPES
# matches only that exact geometry_type value.
geometry_var_criteria: Mapping[str, Mapping[str, Any]] = dict(
    geometry={"geometry_type": re.compile(".")},
    **{gtype: {"geometry_type": gtype} for gtype in _GEOMETRY_TYPES},
)

coordinate_criteria: MutableMapping[str, MutableMapping[str, tuple]] = {
"latitude": {
"standard_name": ("latitude",),
Expand Down
29 changes: 29 additions & 0 deletions cf_xarray/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,3 +748,32 @@ def _create_inexact_bounds():
node_coordinates="node_lon node_lat node_elevation",
),
)


def point_dataset():
    """
    Build a small Dataset holding shapely point geometries.

    Returns
    -------
    Dataset
        A Dataset with a single variable named ``geometry`` along the
        ``index`` dimension, containing one MultiPoint followed by three
        Points (the second and fourth entries are equal Points).
    """
    from shapely.geometry import MultiPoint, Point

    shapes = [
        MultiPoint([(1.0, 2.0), (2.0, 3.0)]),
        Point(3.0, 4.0),
        Point(4.0, 5.0),
        Point(3.0, 4.0),
    ]
    geometry_array = xr.DataArray(shapes, dims=("index",), name="geometry")
    return geometry_array.to_dataset()


def encoded_point_dataset():
    """
    Return ``point_dataset()`` passed through ``encode_geometries``, plus a data variable.

    The added ``data`` variable spans the ``index`` dimension, holds
    ``0 .. n-1`` and carries a ``geometry`` attribute naming
    ``geometry_container``.
    """
    from .geometry import encode_geometries

    encoded = encode_geometries(point_dataset())
    index_values = np.arange(encoded.sizes["index"])
    encoded["data"] = ("index", index_values, {"geometry": "geometry_container"})
    return encoded
11 changes: 11 additions & 0 deletions cf_xarray/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,17 @@ def _format_dsg_roles(accessor, dims, rich):
)


def _format_geometries(accessor, dims, rich):
    """Yield the "CF Geometries" section produced by ``make_text_section`` for the ``geometries`` key."""
    section = make_text_section(
        accessor, "CF Geometries", "geometries", dims=dims, rich=rich
    )
    yield section


def _format_coordinates(accessor, dims, coords, rich):
from .accessor import _AXIS_NAMES, _CELL_MEASURES, _COORD_NAMES

Expand Down
Loading
Loading