Add support for netCDF4.EnumType #8147

Merged
45 commits merged on Jan 17, 2024
Changes from 41 commits

Commits (45)
f1bc33b
ENH: make a light refactoring
Sep 12, 2023
4da8938
dirty commit
Sep 13, 2023
ab53970
Clean
Sep 13, 2023
75e00c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2023
e1d51e3
wip: fix tests
Sep 13, 2023
95e30b2
dirty
Sep 15, 2023
a3160c5
clean
Sep 15, 2023
59ef686
Remove dict from valid attrs type
Sep 15, 2023
d135be2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2023
8c12e50
Fix encoding
Sep 15, 2023
e8f4872
FIX: ordering of flags
Sep 15, 2023
f481e1f
FIX: encoding of the same enum twice (or more).
Sep 18, 2023
951ea32
DOC: Add note for Enum on to_netcdf
Sep 22, 2023
ec3c90a
ENH: Raise explicit error on invalid variable
Nov 8, 2023
dcc1254
Merge remote-tracking branch 'xarray-origin/main' into enh/add-enum-s…
Nov 8, 2023
55927f1
DOC: Update whats-new
Nov 8, 2023
9e9c62c
fix: move enum check
Nov 8, 2023
5f1bffc
FIX: unit test for min-all-deps requirements
Nov 9, 2023
5189c74
Merge remote-tracking branch 'origin/main' into enh/add-enum-support
Nov 10, 2023
2410c2e
ENH: Add enum discovery
Nov 10, 2023
4b966ba
ENH: Raise error instead of modifying dataset
Nov 10, 2023
9273a1d
Merge remote-tracking branch 'origin/main' into enh/add-enum-support
Dec 11, 2023
cbfadad
Merge remote-tracking branch 'origin/main' into enh/add-enum-support
Jan 5, 2024
ca043a7
FIX: pop unnecessary encoding
Jan 5, 2024
ee3dc00
Add Enum Coder
Jan 8, 2024
892b2b6
Merge branch 'main' into enh/add-enum-support
kmuehlbauer Jan 9, 2024
9ab1ad1
DOC: Update what's new
Jan 9, 2024
7219b99
FIX: Use EnumMeta instead of EnumType fo py<3.11
Jan 9, 2024
2aa119f
ENH: Improve error message
Jan 9, 2024
da43a10
Remove unnecessary test
Jan 9, 2024
096f021
Update enum Coder
Jan 9, 2024
26bb8ce
ENH: Update error handling of decoding
Jan 9, 2024
d21d73a
ENH: Avoid encoding enum to CF
Jan 9, 2024
81a4bec
ENH: encode netcdf4 enum within dtype
Jan 10, 2024
b114ccc
MAINT: Remove CF flag_* encoding
Jan 10, 2024
6376a13
Add assertion after roundtrip in enum tests
Jan 10, 2024
89a8751
add NativeEnumCoder, adapt tests
kmuehlbauer Jan 11, 2024
ac20a40
remove test-file
kmuehlbauer Jan 11, 2024
d515e0d
restructure datatype extraction
kmuehlbauer Jan 11, 2024
5c66563
use invalid_netcdf for h5netcdf tests
kmuehlbauer Jan 11, 2024
d62ac29
FIX: encoding typing
Jan 11, 2024
f834ede
Update xarray/backends/netCDF4_.py
kmuehlbauer Jan 14, 2024
9a3980a
Merge branch 'main' into enh/add-enum-support
kmuehlbauer Jan 15, 2024
2a3103f
Merge branch 'main' into enh/add-enum-support
kmuehlbauer Jan 16, 2024
f22046d
Merge branch 'main' into enh/add-enum-support
kmuehlbauer Jan 17, 2024
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -221,6 +221,10 @@ New Features

- Use `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ for :py:func:`xarray.dot` by default if installed.
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`7764`, :pull:`8373`).
- Decode/Encode netCDF4 enums and store the enum definition in dataarrays' dtype metadata.
If multiple variables share the same enum in netCDF4, each dataarray will have its own
enum definition in its respective dtype metadata.
By `Abel Aoun <https://github.com/bzah>`_ (:issue:`8144`, :pull:`8147`).
- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`).
By `Ben Mares <https://github.com/maresb>`_.
- Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the
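As a quick illustration of the whats-new entry above (the file and variable names here are hypothetical, not part of this PR), the decoded enum definition can be read back from the dtype metadata:

import xarray as xr

# Hypothetical file containing a variable "clouds" stored with a netCDF4 enum type.
ds = xr.open_dataset("clouds.nc", engine="netcdf4")

meta = ds["clouds"].encoding["dtype"].metadata
print(meta["enum"])       # e.g. {"clear": 0, "cloudy": 1}
print(meta["enum_name"])  # e.g. "cloud_type"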
74 changes: 58 additions & 16 deletions xarray/backends/netCDF4_.py
@@ -49,7 +49,6 @@
# string used by netCDF4.
_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"}


NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK])


@@ -141,7 +140,9 @@ def _check_encoding_dtype_is_vlen_string(dtype):
)


def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False):
def _get_datatype(
var, nc_format="NETCDF4", raise_on_invalid_encoding=False
) -> np.dtype:
if nc_format == "NETCDF4":
return _nc4_dtype(var)
if "dtype" in var.encoding:
@@ -234,13 +235,13 @@ def _force_native_endianness(var):


def _extract_nc4_variable_encoding(
variable,
variable: Variable,
raise_on_invalid=False,
lsd_okay=True,
h5py_okay=False,
backend="netCDF4",
unlimited_dims=None,
):
) -> dict[str, Any]:
if unlimited_dims is None:
unlimited_dims = ()

@@ -308,7 +309,7 @@ def _extract_nc4_variable_encoding(
return encoding


def _is_list_of_strings(value):
def _is_list_of_strings(value) -> bool:
arr = np.asarray(value)
return arr.dtype.kind in ["U", "S"] and arr.size > 1

@@ -414,13 +415,25 @@ def _acquire(self, needs_lock=True):
def ds(self):
return self._acquire()

def open_store_variable(self, name, var):
def open_store_variable(self, name: str, var):
import netCDF4

dimensions = var.dimensions
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
attributes = {k: var.getncattr(k) for k in var.ncattrs()}
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
encoding: dict[str, Any] = {}
Review comment — @bzah (Contributor, Author), Jan 11, 2024:
A TypedDict for encoding and its possible values would be cleaner.

Review comment — Contributor:
AFAICS this is taken care of in #8520. So if #8520 goes in first, we should change this here.
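A rough sketch of that TypedDict idea, purely illustrative and not part of this PR; keys other than "dtype", "source" and "original_shape" (which appear in this file) are assumptions:

from typing import TypedDict

import numpy as np


class NetCDF4VariableEncoding(TypedDict, total=False):
    # np.dtype whose .metadata may carry {"enum": ..., "enum_name": ...}
    dtype: np.dtype
    source: str
    original_shape: tuple[int, ...]
    # compression settings merged in from var.filters() (assumed key names)
    zlib: bool
    complevel: int
    shuffle: bool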

if isinstance(var.datatype, netCDF4.EnumType):
encoding["dtype"] = np.dtype(
data.dtype,
metadata={
"enum": var.datatype.enum_dict,
"enum_name": var.datatype.name,
},
)
else:
encoding["dtype"] = var.dtype
_ensure_fill_value_valid(data, attributes)
# netCDF4 specific encoding; save _FillValue for later
encoding = {}
filters = var.filters()
if filters is not None:
encoding.update(filters)
@@ -440,7 +453,6 @@ def open_store_variable(self, name, var):
# save source so __repr__ can detect if it's local or not
encoding["source"] = self._filename
encoding["original_shape"] = var.shape
encoding["dtype"] = var.dtype

return Variable(dimensions, data, attributes, encoding)

@@ -485,21 +497,24 @@ def encode_variable(self, variable):
return variable

def prepare_variable(
self, name, variable, check_encoding=False, unlimited_dims=None
self, name, variable: Variable, check_encoding=False, unlimited_dims=None
):
_ensure_no_forward_slash_in_name(name)

attrs = variable.attrs.copy()
fill_value = attrs.pop("_FillValue", None)
datatype = _get_datatype(
variable, self.format, raise_on_invalid_encoding=check_encoding
)
attrs = variable.attrs.copy()

fill_value = attrs.pop("_FillValue", None)

# check enum metadata and use netCDF4.EnumType
if (
(meta := np.dtype(datatype).metadata)
and (e_name := meta.get("enum_name"))
and (e_dict := meta.get("enum"))
):
datatype = self._build_and_get_enum(name, datatype, e_name, e_dict)
encoding = _extract_nc4_variable_encoding(
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
)

if name in self.ds.variables:
nc4_var = self.ds.variables[name]
else:
@@ -527,6 +542,33 @@ def prepare_variable(

return target, variable.data

def _build_and_get_enum(
self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int]
) -> Any:
"""
Add or get the netCDF4 Enum based on the dtype in encoding.
The return type should be ``netCDF4.EnumType``,
but we avoid importing netCDF4 globally for performance reasons.
"""
if enum_name not in self.ds.enumtypes:
return self.ds.createEnumType(
dtype,
enum_name,
enum_dict,
)
datatype = self.ds.enumtypes[enum_name]
if datatype.enum_dict != enum_dict:
error_msg = (
f"Cannot save variable `{var_name}` because an enum"
f" `{enum_name}` already exists in the Dataset but has"
" a different definition. To fix this error, make sure"
" each variable has a uniquely named enum in its"
" `encoding['dtype'].metadata` or, if they should share"
" the same enum type, make sure the enums are identical."
)
raise ValueError(error_msg)
return datatype

def sync(self):
self.ds.sync()

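A minimal sketch of the write path handled by prepare_variable/_build_and_get_enum above; the data, names and file path are made up, but the encoding pattern mirrors the tests further down:

import numpy as np
import xarray as xr

cloud_type_dict = {"clear": 0, "cloudy": 1}
da = xr.DataArray([0, 1, 1], dims="time", name="clouds")
# Attach the enum definition through dtype metadata in `encoding`; on write,
# the backend creates (or reuses) a matching netCDF4.EnumType.
da.encoding["dtype"] = np.dtype(
    "u1", metadata={"enum": cloud_type_dict, "enum_name": "cloud_type"}
)
da.to_netcdf("clouds.nc", engine="netcdf4")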
21 changes: 20 additions & 1 deletion xarray/coding/variables.py
@@ -566,11 +566,30 @@ def decode(self):

class ObjectVLenStringCoder(VariableCoder):
def encode(self):
return NotImplementedError
raise NotImplementedError

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
if variable.dtype == object and variable.encoding.get("dtype", False) == str:
variable = variable.astype(variable.encoding["dtype"])
return variable
else:
return variable


class NativeEnumCoder(VariableCoder):
"""Encode Enum into variable dtype metadata."""

def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if (
"dtype" in variable.encoding
and np.dtype(variable.encoding["dtype"]).metadata
and "enum" in variable.encoding["dtype"].metadata
):
dims, data, attrs, encoding = unpack_for_encoding(variable)
data = data.astype(dtype=variable.encoding.pop("dtype"))
return Variable(dims, data, attrs, encoding, fastpath=True)
else:
return variable

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
raise NotImplementedError()
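
For orientation only, a fabricated example of exercising this coder directly (in normal use it runs inside encode_cf_variable; see the conventions.py hunk below):

import numpy as np

from xarray.coding import variables as coding_variables
from xarray.core.variable import Variable

enum_dtype = np.dtype(
    "u1", metadata={"enum": {"clear": 0, "cloudy": 1}, "enum_name": "cloud_type"}
)
var = Variable(
    ("time",), np.array([0, 1, 1], dtype="u1"), encoding={"dtype": enum_dtype}
)
encoded = coding_variables.NativeEnumCoder().encode(var, name="clouds")
# The enum definition now travels on the data's dtype metadata.
print(encoded.dtype.metadata)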
17 changes: 9 additions & 8 deletions xarray/conventions.py
@@ -48,10 +48,6 @@
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]


def _var_as_tuple(var: Variable) -> T_VarTuple:
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()


def _infer_dtype(array, name=None):
"""Given an object array with no missing values, infer its dtype from all elements."""
if array.dtype.kind != "O":
@@ -111,7 +107,7 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
# TODO: move this from conventions to backends? (it's not CF related)
if var.dtype.kind == "O":
dims, data, attrs, encoding = _var_as_tuple(var)
dims, data, attrs, encoding = variables.unpack_for_encoding(var)

# leave vlen dtypes unchanged
if strings.check_vlen_dtype(data.dtype) is not None:
@@ -162,7 +158,7 @@ def encode_cf_variable(
var: Variable, needs_copy: bool = True, name: T_Name = None
) -> Variable:
"""
Converts an Variable into an Variable which follows some
Converts a Variable into a Variable which follows some
of the CF conventions:

- Nans are masked using _FillValue (or the deprecated missing_value)
@@ -188,6 +184,7 @@
variables.CFScaleOffsetCoder(),
variables.CFMaskCoder(),
variables.UnsignedIntegerCoder(),
variables.NativeEnumCoder(),
variables.NonStringCoder(),
variables.DefaultFillvalueCoder(),
variables.BooleanCoder(),
@@ -447,7 +444,7 @@ def stackable(dim: Hashable) -> bool:
decode_timedelta=decode_timedelta,
)
except Exception as e:
raise type(e)(f"Failed to decode variable {k!r}: {e}")
raise type(e)(f"Failed to decode variable {k!r}: {e}") from e
if decode_coords in [True, "coordinates", "all"]:
var_attrs = new_vars[k].attrs
if "coordinates" in var_attrs:
@@ -633,7 +630,11 @@ def cf_decoder(
decode_cf_variable
"""
variables, attributes, _ = decode_cf_variables(
variables, attributes, concat_characters, mask_and_scale, decode_times
variables,
attributes,
concat_characters,
mask_and_scale,
decode_times,
)
return variables, attributes

3 changes: 3 additions & 0 deletions xarray/core/dataarray.py
@@ -4069,6 +4069,9 @@ def to_netcdf(
name is the same as a coordinate name, then it is given the name
``"__xarray_dataarray_variable__"``.

[netCDF4 backend only] netCDF4 enums are decoded into the
dataarray dtype metadata.

See Also
--------
Dataset.to_netcdf
120 changes: 120 additions & 0 deletions xarray/tests/test_backends.py
@@ -1704,6 +1704,126 @@ def test_raise_on_forward_slashes_in_names(self) -> None:
with self.roundtrip(ds):
pass

@requires_netCDF4
def test_encoding_enum__no_fill_value(self):
with create_tmp_file() as tmp_file:
cloud_type_dict = {"clear": 0, "cloudy": 1}
with nc4.Dataset(tmp_file, mode="w") as nc:
nc.createDimension("time", size=2)
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
v = nc.createVariable(
"clouds",
cloud_type,
"time",
fill_value=None,
)
v[:] = 1
with open_dataset(tmp_file) as original:
save_kwargs = {}
if self.engine == "h5netcdf":
save_kwargs["invalid_netcdf"] = True
with self.roundtrip(original, save_kwargs=save_kwargs) as actual:
assert_equal(original, actual)
assert (
actual.clouds.encoding["dtype"].metadata["enum"]
== cloud_type_dict
)
if self.engine != "h5netcdf":
# not implemented in h5netcdf yet
assert (
actual.clouds.encoding["dtype"].metadata["enum_name"]
== "cloud_type"
)

@requires_netCDF4
def test_encoding_enum__multiple_variable_with_enum(self):
with create_tmp_file() as tmp_file:
cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
with nc4.Dataset(tmp_file, mode="w") as nc:
nc.createDimension("time", size=2)
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
nc.createVariable(
"clouds",
cloud_type,
"time",
fill_value=255,
)
nc.createVariable(
"tifa",
cloud_type,
"time",
fill_value=255,
)
with open_dataset(tmp_file) as original:
save_kwargs = {}
if self.engine == "h5netcdf":
save_kwargs["invalid_netcdf"] = True
with self.roundtrip(original, save_kwargs=save_kwargs) as actual:
assert_equal(original, actual)
assert (
actual.clouds.encoding["dtype"] == actual.tifa.encoding["dtype"]
)
assert (
actual.clouds.encoding["dtype"].metadata
== actual.tifa.encoding["dtype"].metadata
)
assert (
actual.clouds.encoding["dtype"].metadata["enum"]
== cloud_type_dict
)
if self.engine != "h5netcdf":
# not implemented in h5netcdf yet
assert (
actual.clouds.encoding["dtype"].metadata["enum_name"]
== "cloud_type"
)

@requires_netCDF4
def test_encoding_enum__error_multiple_variable_with_changing_enum(self):
"""
Given two variables that share the same enum type,
their enum definitions should be identical.
"""
with create_tmp_file() as tmp_file:
cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
with nc4.Dataset(tmp_file, mode="w") as nc:
nc.createDimension("time", size=2)
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
nc.createVariable(
"clouds",
cloud_type,
"time",
fill_value=255,
)
nc.createVariable(
"tifa",
cloud_type,
"time",
fill_value=255,
)
with open_dataset(tmp_file) as original:
assert (
original.clouds.encoding["dtype"].metadata
== original.tifa.encoding["dtype"].metadata
)
modified_enum = original.clouds.encoding["dtype"].metadata["enum"]
modified_enum.update({"neblig": 2})
original.clouds.encoding["dtype"] = np.dtype(
"u1",
metadata={"enum": modified_enum, "enum_name": "cloud_type"},
)
if self.engine != "h5netcdf":
# not implemented yet in h5netcdf
with pytest.raises(
ValueError,
match=(
"Cannot save variable .*"
" because an enum `cloud_type` already exists in the Dataset .*"
),
):
with self.roundtrip(original):
pass


@requires_netCDF4
class TestNetCDF4Data(NetCDF4Base):