Add support for netCDF4.EnumType #8147

Merged
merged 45 commits on Jan 17, 2024
Changes from 33 commits
Commits (45)
f1bc33b
ENH: make a light refactoring
Sep 12, 2023
4da8938
dirty commit
Sep 13, 2023
ab53970
Clean
Sep 13, 2023
75e00c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2023
e1d51e3
wip: fix tests
Sep 13, 2023
95e30b2
dirty
Sep 15, 2023
a3160c5
clean
Sep 15, 2023
59ef686
Remove dict from valid attrs type
Sep 15, 2023
d135be2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2023
8c12e50
Fix encoding
Sep 15, 2023
e8f4872
FIX: ordering of flags
Sep 15, 2023
f481e1f
FIX: encoding of the same enum twice (or more).
Sep 18, 2023
951ea32
DOC: Add note for Enum on to_netcdf
Sep 22, 2023
ec3c90a
ENH: Raise explicit error on invalid variable
Nov 8, 2023
dcc1254
Merge remote-tracking branch 'xarray-origin/main' into enh/add-enum-s…
Nov 8, 2023
55927f1
DOC: Update whats-new
Nov 8, 2023
9e9c62c
fix: move enum check
Nov 8, 2023
5f1bffc
FIX: unit test for min-all-deps requirements
Nov 9, 2023
5189c74
Merge remote-tracking branch 'origin/main' into enh/add-enum-support
Nov 10, 2023
2410c2e
ENH: Add enum discovery
Nov 10, 2023
4b966ba
ENH: Raise error instead of modifying dataset
Nov 10, 2023
9273a1d
Merge remote-tracking branch 'origin/main' into enh/add-enum-support
Dec 11, 2023
cbfadad
Merge remote-tracking branch 'origin/main' into enh/add-enum-support
Jan 5, 2024
ca043a7
FIX: pop unnecessary encoding
Jan 5, 2024
ee3dc00
Add Enum Coder
Jan 8, 2024
892b2b6
Merge branch 'main' into enh/add-enum-support
kmuehlbauer Jan 9, 2024
9ab1ad1
DOC: Update what's new
Jan 9, 2024
7219b99
FIX: Use EnumMeta instead of EnumType fo py<3.11
Jan 9, 2024
2aa119f
ENH: Improve error message
Jan 9, 2024
da43a10
Remove unnecessary test
Jan 9, 2024
096f021
Update enum Coder
Jan 9, 2024
26bb8ce
ENH: Update error handling of decoding
Jan 9, 2024
d21d73a
ENH: Avoid encoding enum to CF
Jan 9, 2024
81a4bec
ENH: encode netcdf4 enum within dtype
Jan 10, 2024
b114ccc
MAINT: Remove CF flag_* encoding
Jan 10, 2024
6376a13
Add assertion after roundtrip in enum tests
Jan 10, 2024
89a8751
add NativeEnumCoder, adapt tests
kmuehlbauer Jan 11, 2024
ac20a40
remove test-file
kmuehlbauer Jan 11, 2024
d515e0d
restructure datatype extraction
kmuehlbauer Jan 11, 2024
5c66563
use invalid_netcdf for h5netcdf tests
kmuehlbauer Jan 11, 2024
d62ac29
FIX: encoding typing
Jan 11, 2024
f834ede
Update xarray/backends/netCDF4_.py
kmuehlbauer Jan 14, 2024
9a3980a
Merge branch 'main' into enh/add-enum-support
kmuehlbauer Jan 15, 2024
2a3103f
Merge branch 'main' into enh/add-enum-support
kmuehlbauer Jan 16, 2024
f22046d
Merge branch 'main' into enh/add-enum-support
kmuehlbauer Jan 17, 2024
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -221,6 +221,10 @@ New Features

- Use `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ for :py:func:`xarray.dot` by default if installed.
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`7764`, :pull:`8373`).
- Open netCDF4 enums and turn them into CF ``flag_meanings``/``flag_values`` attributes.
This also gives a special meaning to the ``'enum'`` attribute on DataArrays: when it is set, the netCDF4 backend
turns ``flag_meanings`` and ``flag_values`` back into a netCDF4 Enum named after the content of ``attrs["enum"]``.
By `Abel Aoun <https://github.com/bzah>`_. (:issue:`8144`, :pull:`8147`)
- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`).
By `Ben Mares <https://github.com/maresb>`_.
- Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the
Binary file added toto.nc
Binary file not shown.
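As a usage illustration of the whats-new entry above, here is a hedged sketch of the write/read path as the code stands at this commit. The file name, variable name, and enum members are made up, and the exact convention for attrs["enum"] (a name string vs. a Python Enum) was still being discussed in review.

    # Hedged sketch only: names and values are illustrative, and the attrs["enum"]
    # convention was still in flux in this PR. A Python Enum in attrs asks the
    # netCDF4 backend to create a matching netCDF4.EnumType when writing.
    from enum import Enum

    import numpy as np
    import xarray as xr

    CloudType = Enum("cloud_type_t", {"clear": 0, "cumulus": 1, "stratus": 2})

    da = xr.DataArray(
        np.array([0, 1, 2, 1], dtype="uint8"),
        dims="time",
        name="cloud_type",
        attrs={"enum": CloudType},
    )
    da.to_dataset().to_netcdf("clouds.nc", engine="netcdf4")

    # Reading the file back with this PR exposes the enum again on the variable.
    ds = xr.open_dataset("clouds.nc", engine="netcdf4")
    print(ds["cloud_type"].attrs.get("enum"))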
7 changes: 6 additions & 1 deletion xarray/backends/api.py
@@ -2,6 +2,7 @@

import os
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from enum import EnumMeta
from functools import partial
from io import BytesIO
from numbers import Number
@@ -172,7 +173,7 @@ def _validate_attrs(dataset, invalid_netcdf=False):
`invalid_netcdf=True`.
"""

valid_types = (str, Number, np.ndarray, np.number, list, tuple)
valid_types = (str, Number, np.ndarray, np.number, list, tuple, EnumMeta)
if invalid_netcdf:
valid_types += (np.bool_,)

@@ -407,6 +408,7 @@ def open_dataset(
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
decode_enum: bool | None = None,
**kwargs,
) -> Dataset:
"""Open and decode a dataset from a file or file-like object.
@@ -512,6 +514,8 @@
backend_kwargs: dict
Additional keyword arguments passed on to the engine open function,
equivalent to `**kwargs`.
decode_enum: bool, optional
If True, decode CF flag_values and flag_meanings into a Python Enum.
**kwargs: dict
Additional keyword arguments passed on to the engine open function.
For example:
@@ -566,6 +570,7 @@
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
decode_enum=decode_enum,
)

overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
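A minimal sketch of the new decode_enum switch threaded through open_dataset above; the file name is illustrative and the default value and exact behaviour were still under review in this PR.

    import xarray as xr

    # If True, CF flag_values/flag_meanings are decoded into a Python Enum placed
    # in the variable's attrs["enum"]; omit or pass False to keep the raw flags.
    ds = xr.open_dataset("clouds.nc", engine="netcdf4", decode_enum=True)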
60 changes: 46 additions & 14 deletions xarray/backends/netCDF4_.py
@@ -5,6 +5,7 @@
import os
from collections.abc import Iterable
from contextlib import suppress
from enum import Enum
from typing import TYPE_CHECKING, Any

import numpy as np
@@ -49,7 +50,6 @@
# string used by netCDF4.
_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"}


NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK])


@@ -234,13 +234,13 @@ def _force_native_endianness(var):


def _extract_nc4_variable_encoding(
variable,
variable: Variable,
raise_on_invalid=False,
lsd_okay=True,
h5py_okay=False,
backend="netCDF4",
unlimited_dims=None,
):
) -> dict[str, Any]:
if unlimited_dims is None:
unlimited_dims = ()

@@ -257,6 +257,7 @@ def _extract_nc4_variable_encoding(
"_FillValue",
"dtype",
"compression",
"enum",
"significant_digits",
"quantize_mode",
"blosc_shuffle",
@@ -308,7 +309,7 @@
return encoding


def _is_list_of_strings(value):
def _is_list_of_strings(value) -> bool:
arr = np.asarray(value)
return arr.dtype.kind in ["U", "S"] and arr.size > 1

@@ -414,10 +415,14 @@ def _acquire(self, needs_lock=True):
def ds(self):
return self._acquire()

def open_store_variable(self, name, var):
def open_store_variable(self, name: str, var):
import netCDF4

dimensions = var.dimensions
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
attributes = {k: var.getncattr(k) for k in var.ncattrs()}
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
if isinstance(var.datatype, netCDF4.EnumType):
attributes["enum"] = Enum(var.datatype.name, var.datatype.enum_dict)
_ensure_fill_value_valid(data, attributes)
# netCDF4 specific encoding; save _FillValue for later
encoding = {}
@@ -485,21 +490,20 @@ def encode_variable(self, variable):
return variable

def prepare_variable(
self, name, variable, check_encoding=False, unlimited_dims=None
self, name, variable: Variable, check_encoding=False, unlimited_dims=None
):
_ensure_no_forward_slash_in_name(name)

datatype = _get_datatype(
variable, self.format, raise_on_invalid_encoding=check_encoding
)
attrs = variable.attrs.copy()

fill_value = attrs.pop("_FillValue", None)

if attrs.get("enum"):
datatype = self._build_and_get_enum(name, attrs, variable.dtype)
else:
datatype = _get_datatype(
variable, self.format, raise_on_invalid_encoding=check_encoding
)
encoding = _extract_nc4_variable_encoding(
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
)

if name in self.ds.variables:
nc4_var = self.ds.variables[name]
else:
@@ -527,6 +531,32 @@ def prepare_variable(

return target, variable.data

def _build_and_get_enum(
self, var_name: str, attributes: dict, dtype: np.dtype
) -> object:
enum = attributes.pop("enum")
enum_dict = {e.name: e.value for e in enum}
enum_name = enum.__name__
if enum_name in self.ds.enumtypes:
datatype = self.ds.enumtypes[enum_name]
if datatype.enum_dict != enum_dict:
error_msg = (
f"Cannot save variable `{var_name}` because an enum"
f" `{enum_name}` already exists in the Dataset but have"
" a different definition. To fix this error, make sure"
" each variable have a unique name in `attrs['enum']`"
" or, if they should share same enum type, make sure"
" the enums are identical."
)
raise ValueError(error_msg)
else:
datatype = self.ds.createEnumType(
dtype,
enum_name,
enum_dict,
)
return datatype

def sync(self):
self.ds.sync()

@@ -597,6 +627,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
persist=False,
lock=None,
autoclose=False,
decode_enum: bool | None = None,
) -> Dataset:
filename_or_obj = _normalize_path(filename_or_obj)
store = NetCDF4DataStore.open(
@@ -622,6 +653,7 @@
drop_variables=drop_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
decode_enum=decode_enum,
)
return ds

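For context on what open_store_variable and _build_and_get_enum interact with, here is a small sketch using the netCDF4-python API directly; the file, type, and member names are made up.

    import netCDF4
    import numpy as np

    with netCDF4.Dataset("enum_demo.nc", mode="w") as nc:
        # One EnumType per file/group, reused by any variable that needs it.
        cloud_t = nc.createEnumType(
            np.uint8, "cloud_type_t", {"clear": 0, "cumulus": 1, "stratus": 2}
        )
        nc.createDimension("time", 3)
        var = nc.createVariable("cloud_type", cloud_t, ("time",))
        var[:] = np.array([0, 2, 1], dtype=np.uint8)

    with netCDF4.Dataset("enum_demo.nc") as nc:
        dt = nc.variables["cloud_type"].datatype
        assert isinstance(dt, netCDF4.EnumType)
        print(dt.name, dt.enum_dict)  # what open_store_variable inspects
        print(nc.enumtypes)           # what _build_and_get_enum checks on write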
2 changes: 2 additions & 0 deletions xarray/backends/store.py
@@ -37,6 +37,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
drop_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
decode_enum: bool | None = None,
) -> Dataset:
assert isinstance(filename_or_obj, AbstractDataStore)

@@ -53,6 +54,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
drop_variables=drop_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
decode_enum=decode_enum,
)

ds = Dataset(vars, attrs=attrs)
28 changes: 27 additions & 1 deletion xarray/coding/variables.py
@@ -3,6 +3,7 @@

import warnings
from collections.abc import Hashable, MutableMapping
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Union

@@ -566,11 +567,36 @@ def decode(self):

class ObjectVLenStringCoder(VariableCoder):
def encode(self):
return NotImplementedError
raise NotImplementedError

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
if variable.dtype == object and variable.encoding.get("dtype", False) == str:
variable = variable.astype(variable.encoding["dtype"])
return variable
else:
return variable


class EnumCoder(VariableCoder):
"""Decode CF flag_* to python Enum"""

def encode(self, variable: Variable, name: T_Name = None) -> Variable:
raise NotImplementedError

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
"""From CF flag_* to python Enum"""
dims, data, attrs, encoding = unpack_for_decoding(variable)
if (
attrs.get("enum")
Review thread on this check:

Contributor: There is no enum attribute on-disk for plain CF-style enums. Testing for flag_meanings, flag_values should be sufficient. Problem: we would need to invent an enum name. If we can decode from flag_meanings, flag_values to a Python enum, it would be good to be able to roundtrip (implement encoding a Python enum back to flag_meanings, flag_values).

bzah (Contributor, author), Jan 9, 2024: But do we always want to decode CF flag_* into Python enums? Having this test also on attrs['enum'] would avoid decoding them when this is not wanted. As for the encoding, I previously had an EnumCoder.encode method that turns a Python Enum into CF flag_*, but encoders are called before the netCDF4_.open_store_variable method. So within open_store_variable I would need to decode the flag_* again and turn them into a netCDF4 enum, which defeats the goal of having a Python Enum, no?

Contributor:
> But do we always want to decode CF flag_* into Python enums? Having this test also on attrs['enum'] would avoid decoding them when this is not wanted.
No, only when decode_enum=True.

Contributor:
> As for the encoding, I previously had an EnumCoder.encode method that turns a Python Enum into CF flag_*, but encoders are called before the netCDF4_.open_store_variable method. So within open_store_variable I would need to decode the flag_* again and turn them into a netCDF4 enum, which defeats the goal of having a Python Enum, no?
From my point of view, there are two things here: encoding/decoding Python enum <-> CF flag_*, and being able to read/write the netCDF4 enum type. The current approach using attrs["enum"] to represent a netCDF4 enum in xarray is straightforward and works very well. We might just keep the CF stuff for another PR and get the netCDF4 enum backend feature in. @dcherian, do you have suggestions to move forward here?

Contributor: Agree with limiting to just the netCDF type, and skipping the flag_* stuff. Another thought I just had is whether we can use the dtype attribute instead. But we need two pieces of information: the Enum dict and the dtype of the values. Shall we stick the Enum dict in dtype and preserve the dtype of the array as the dtype on disk? Is there a use case for encoding the values as a different dtype at all?

bzah (Contributor, author): I believe we can store the dtype within the enum with the type argument: encoding["dtype"] = Enum(e_name, e_dict, type=data.dtype.type), but I don't know what the caveats of this approach are.
> We might just keep the CF stuff for another PR and get the netCDF4 enum backend feature in.
> Agree with limiting to just the netCDF type, and skipping the flag_* stuff.
Alright, sorry for all these back and forths. I will try to focus on the netCDF Enum encoding/decoding.

kmuehlbauer (Contributor), Jan 10, 2024: @bzah No need to apologize. :-) This stuff soon gets very complex, and finding the best approach takes multiple iterations. I've thought a while about @dcherian's comment above. We might be able to do something like this in open_store_variable:

    if isinstance(var.datatype, netCDF4.EnumType):
        dtype = np.dtype(dtype, metadata={'enum': var.datatype.enum_dict,
                                          'enum_name': var.datatype.name})

    encoding["dtype"] = dtype

This will use numpy's metadata to store the enum on the dtype. It is essentially the same as what h5py does (without the enum_name). We could even provide that dtype to the variable itself; it is not that problematic to create a Python enum from that dtype. In prepare_variable we would just have to do the same thing, analogous to your current datatype extraction, but now from variable.dtype. For that to work we would either need to implement EnumCoder.encode to add the metadata back to variable.dtype (it might have been stripped by processing), or fix NonStringCoder to do this (preferred, so we can reserve the EnumCoder for the flag_* stuff). By using that approach we get the dtype and the enum (in dtype.metadata). Update: fixed example code.

and attrs.get("flag_meanings")
and attrs.get("flag_values")
):
flag_meanings = attrs.pop("flag_meanings")
flag_meanings = flag_meanings.split(" ")
flag_values = attrs.pop("flag_values")
flag_values = [int(v) for v in flag_values.split(", ")]
enum_name = attrs.pop("enum")
enum_dict = {k: v for k, v in zip(flag_meanings, flag_values)}
attrs["enum"] = Enum(enum_name, enum_dict)
return Variable(dims, data, attrs, encoding, fastpath=True)
return variable
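Two hedged sketches to make the pieces above concrete. First, EnumCoder.decode as written here: a variable carrying an 'enum' name plus CF flag_meanings/flag_values (a space-separated and a comma-separated string respectively, which is how this version parses them) comes back with a Python Enum in attrs['enum']. Names and values are illustrative.

    from xarray import Variable
    from xarray.coding.variables import EnumCoder  # added by this PR

    var = Variable(
        ("time",),
        [0, 1, 2],
        attrs={
            "enum": "cloud_type_t",
            "flag_meanings": "clear cumulus stratus",
            "flag_values": "0, 1, 2",
        },
    )
    decoded = EnumCoder().decode(var, name="cloud_type")
    print(decoded.attrs["enum"])  # Python Enum rebuilt from the flag_* attributes

Second, the dtype.metadata idea floated in the review thread above, which stashes the mapping on the numpy dtype instead of in attrs (illustrative only; not necessarily what was merged):

    from enum import Enum

    import numpy as np

    enum_dict = {"clear": 0, "cumulus": 1, "stratus": 2}  # illustrative mapping
    dtype = np.dtype("u1", metadata={"enum": enum_dict, "enum_name": "cloud_type_t"})

    # The mapping travels with the dtype and can be turned back into a Python Enum.
    meta = dtype.metadata
    rebuilt = Enum(meta["enum_name"], dict(meta["enum"]))
    print(list(rebuilt))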
28 changes: 20 additions & 8 deletions xarray/conventions.py
@@ -48,10 +48,6 @@
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]


def _var_as_tuple(var: Variable) -> T_VarTuple:
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()


def _infer_dtype(array, name=None):
"""Given an object array with no missing values, infer its dtype from all elements."""
if array.dtype.kind != "O":
@@ -111,7 +107,7 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
# TODO: move this from conventions to backends? (it's not CF related)
if var.dtype.kind == "O":
dims, data, attrs, encoding = _var_as_tuple(var)
dims, data, attrs, encoding = variables.unpack_for_encoding(var)

# leave vlen dtypes unchanged
if strings.check_vlen_dtype(data.dtype) is not None:
@@ -162,7 +158,7 @@ def encode_cf_variable(
var: Variable, needs_copy: bool = True, name: T_Name = None
) -> Variable:
"""
Converts an Variable into an Variable which follows some
Converts a Variable into a Variable which follows some
of the CF conventions:

- Nans are masked using _FillValue (or the deprecated missing_value)
@@ -212,6 +208,7 @@ def decode_cf_variable(
stack_char_dim: bool = True,
use_cftime: bool | None = None,
decode_timedelta: bool | None = None,
decode_enum: bool | None = None,
) -> Variable:
"""
Decodes a variable which may hold CF encoded information.
Expand Down Expand Up @@ -252,6 +249,8 @@ def decode_cf_variable(
represented using ``np.datetime64[ns]`` objects. If False, always
decode times to ``np.datetime64[ns]`` objects; if this is not possible
raise an error.
decode_enum: bool, optional
If True, turn the CF flag_values and flag_meanings into a Python Enum in `attrs['enum']`.

Returns
-------
@@ -295,6 +294,9 @@

var = variables.BooleanCoder().decode(var)

if decode_enum:
var = variables.EnumCoder().decode(var)

dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)

encoding.setdefault("dtype", original_dtype)
@@ -393,6 +395,7 @@ def decode_cf_variables(
drop_variables: T_DropVariables = None,
use_cftime: bool | None = None,
decode_timedelta: bool | None = None,
decode_enum: bool | None = None,
) -> tuple[T_Variables, T_Attrs, set[Hashable]]:
"""
Decode several CF encoded variables.
@@ -445,9 +448,10 @@ def stackable(dim: Hashable) -> bool:
stack_char_dim=stack_char_dim,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
decode_enum=decode_enum,
)
except Exception as e:
raise type(e)(f"Failed to decode variable {k!r}: {e}")
raise type(e)(f"Failed to decode variable {k!r}: {e}") from e
if decode_coords in [True, "coordinates", "all"]:
var_attrs = new_vars[k].attrs
if "coordinates" in var_attrs:
@@ -509,6 +513,7 @@ def decode_cf(
drop_variables: T_DropVariables = None,
use_cftime: bool | None = None,
decode_timedelta: bool | None = None,
decode_enum: bool = True,
) -> Dataset:
"""Decode the given Dataset or Datastore according to CF conventions into
a new Dataset.
@@ -587,6 +592,7 @@
drop_variables=drop_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
decode_enum=decode_enum,
)
ds = Dataset(vars, attrs=attrs)
ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars))
@@ -602,6 +608,7 @@ def cf_decoder(
concat_characters: bool = True,
mask_and_scale: bool = True,
decode_times: bool = True,
decode_enum: bool = True,
) -> tuple[T_Variables, T_Attrs]:
"""
Decode a set of CF encoded variables and attributes.
@@ -633,7 +640,12 @@
decode_cf_variable
"""
variables, attributes, _ = decode_cf_variables(
variables, attributes, concat_characters, mask_and_scale, decode_times
variables,
attributes,
concat_characters,
mask_and_scale,
decode_times,
decode_enum=decode_enum,
)
return variables, attributes

5 changes: 5 additions & 0 deletions xarray/core/dataarray.py
@@ -4069,6 +4069,11 @@ def to_netcdf(
name is the same as a coordinate name, then it is given the name
``"__xarray_dataarray_variable__"``.

[netCDF4 backend only] When the CF flag_values/flag_meanings attributes are
set for this DataArray, you can choose to replace these attributes with
a netCDF4 EnumType by adding a key/value pair to the attributes,
like: `da.attrs["enum"] = "enum_name"`.

See Also
--------
Dataset.to_netcdf