Skip to content

Commit

Permalink
WIP: naive implementation of enum
Browse files Browse the repository at this point in the history
TODO:
- Fix to_netcdf
  It doesn't work when there are missing values: xarray doesn't
  use masked arrays, so we try to assign a fill_value,
  but netCDF4 forbids assigning values outside the
  enum's valid range, so the write crashes.
- Add unit tests
- Validate with the xarray team whether it's OK to add attributes to Variable
  and DataArray
- Add implementations for other backends?
  • Loading branch information
Abel Aoun committed Sep 5, 2023
1 parent f13da94 commit bdfa8ce
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 69 deletions.
23 changes: 19 additions & 4 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,16 @@ def ds(self):
return self._acquire()

def open_store_variable(self, name, var):
import netCDF4

dimensions = var.dimensions
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
attributes = {k: var.getncattr(k) for k in var.ncattrs()}
enum_meaning = None
enum_name = None
if isinstance(var.datatype, netCDF4.EnumType):
enum_meaning = var.datatype.enum_dict
enum_name = var.datatype.name
_ensure_fill_value_valid(data, attributes)
# netCDF4 specific encoding; save _FillValue for later
encoding = {}
Expand All @@ -434,8 +441,9 @@ def open_store_variable(self, name, var):
encoding["source"] = self._filename
encoding["original_shape"] = var.shape
encoding["dtype"] = var.dtype

return Variable(dimensions, data, attributes, encoding)
return Variable(dimensions, data, attributes, encoding,
enum_meaning= enum_meaning,
enum_name=enum_name)

def get_variables(self):
return FrozenDict(
Expand Down Expand Up @@ -478,7 +486,7 @@ def encode_variable(self, variable):
return variable

def prepare_variable(
self, name, variable, check_encoding=False, unlimited_dims=None
self, name, variable: Variable, check_encoding=False, unlimited_dims=None
):
_ensure_no_forward_slash_in_name(name)

Expand All @@ -503,12 +511,19 @@ def prepare_variable(
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
)

enum = None
if variable.enum_meaning:
enum = self.ds.createEnumType(
variable.dtype,
variable.enum_name,
variable.enum_meaning)

if name in self.ds.variables:
nc4_var = self.ds.variables[name]
else:
nc4_var = self.ds.createVariable(
varname=name,
datatype=datatype,
datatype=enum if enum else datatype,
dimensions=variable.dims,
zlib=encoding.get("zlib", False),
complevel=encoding.get("complevel", 4),
Expand Down
2 changes: 1 addition & 1 deletion xarray/backends/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def guess_can_open(

def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
filename_or_obj: AbstractDataStore,
*,
mask_and_scale=True,
decode_times=True,
Expand Down
23 changes: 11 additions & 12 deletions xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
lazy_elemwise_func,
pop_to,
safe_setitem,
unpack_for_decoding,
unpack_for_encoding,
unpack,
)
from xarray.core import indexing
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
Expand Down Expand Up @@ -48,7 +47,7 @@ def __init__(self, allows_unicode=True):
self.allows_unicode = allows_unicode

def encode(self, variable, name=None):
dims, data, attrs, encoding = unpack_for_encoding(variable)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(variable)

contains_unicode = is_unicode_dtype(data.dtype)
encode_as_char = encoding.get("dtype") == "S1"
Expand All @@ -69,17 +68,17 @@ def encode(self, variable, name=None):
# TODO: figure out how to handle this in a lazy way with dask
data = encode_string_array(data, string_encoding)

return Variable(dims, data, attrs, encoding)
return Variable(dims, data, attrs, encoding, enum_meaning=enum_meaning, enum_name=enum_name)

def decode(self, variable, name=None):
dims, data, attrs, encoding = unpack_for_decoding(variable)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(variable)

if "_Encoding" in attrs:
string_encoding = pop_to(attrs, encoding, "_Encoding")
func = partial(decode_bytes_array, encoding=string_encoding)
data = lazy_elemwise_func(data, func, np.dtype(object))

return Variable(dims, data, attrs, encoding)
return Variable(dims, data, attrs, encoding, enum_meaning=enum_meaning, enum_name=enum_name)


def decode_bytes_array(bytes_array, encoding="utf-8"):
Expand All @@ -97,11 +96,11 @@ def encode_string_array(string_array, encoding="utf-8"):

def ensure_fixed_length_bytes(var):
"""Ensure that a variable with vlen bytes is converted to fixed width."""
dims, data, attrs, encoding = unpack_for_encoding(var)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(var)
if check_vlen_dtype(data.dtype) == bytes:
# TODO: figure out how to handle this with dask
data = np.asarray(data, dtype=np.string_)
return Variable(dims, data, attrs, encoding)
return Variable(dims, data, attrs, encoding, enum_meaning=enum_meaning, enum_name=enum_name)


class CharacterArrayCoder(VariableCoder):
Expand All @@ -110,24 +109,24 @@ class CharacterArrayCoder(VariableCoder):
def encode(self, variable, name=None):
variable = ensure_fixed_length_bytes(variable)

dims, data, attrs, encoding = unpack_for_encoding(variable)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(variable)
if data.dtype.kind == "S" and encoding.get("dtype") is not str:
data = bytes_to_char(data)
if "char_dim_name" in encoding.keys():
char_dim_name = encoding.pop("char_dim_name")
else:
char_dim_name = f"string{data.shape[-1]}"
dims = dims + (char_dim_name,)
return Variable(dims, data, attrs, encoding)
return Variable(dims, data, attrs, encoding, enum_meaning=enum_meaning, enum_name=enum_name)

def decode(self, variable, name=None):
dims, data, attrs, encoding = unpack_for_decoding(variable)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(variable)

if data.dtype == "S1" and dims:
encoding["char_dim_name"] = dims[-1]
dims = dims[:-1]
data = char_to_bytes(data)
return Variable(dims, data, attrs, encoding)
return Variable(dims, data, attrs, encoding, enum_meaning=enum_meaning, enum_name=enum_name)


def bytes_to_char(arr):
Expand Down
19 changes: 9 additions & 10 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
lazy_elemwise_func,
pop_to,
safe_setitem,
unpack_for_decoding,
unpack_for_encoding,
unpack,
)
from xarray.core import indexing
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
Expand Down Expand Up @@ -702,22 +701,22 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(
variable.data.dtype, np.datetime64
) or contains_cftime_datetimes(variable):
dims, data, attrs, encoding = unpack_for_encoding(variable)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(variable)

(data, units, calendar) = encode_cf_datetime(
data, encoding.pop("units", None), encoding.pop("calendar", None)
)
safe_setitem(attrs, "units", units, name=name)
safe_setitem(attrs, "calendar", calendar, name=name)

return Variable(dims, data, attrs, encoding, fastpath=True)
return Variable(dims, data, attrs, encoding, fastpath=True, enum_meaning=enum_meaning, enum_name=enum_name)
else:
return variable

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
units = variable.attrs.get("units", None)
if isinstance(units, str) and "since" in units:
dims, data, attrs, encoding = unpack_for_decoding(variable)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(variable)

units = pop_to(attrs, encoding, "units")
calendar = pop_to(attrs, encoding, "calendar")
Expand All @@ -730,33 +729,33 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
)
data = lazy_elemwise_func(data, transform, dtype)

return Variable(dims, data, attrs, encoding, fastpath=True)
return Variable(dims, data, attrs, encoding, fastpath=True, enum_meaning=enum_meaning, enum_name=enum_name)
else:
return variable


class CFTimedeltaCoder(VariableCoder):
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(variable.data.dtype, np.timedelta64):
dims, data, attrs, encoding = unpack_for_encoding(variable)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(variable)

data, units = encode_cf_timedelta(data, encoding.pop("units", None))
safe_setitem(attrs, "units", units, name=name)

return Variable(dims, data, attrs, encoding, fastpath=True)
return Variable(dims, data, attrs, encoding, fastpath=True, enum_meaning=enum_meaning, enum_name=enum_name)
else:
return variable

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
units = variable.attrs.get("units", None)
if isinstance(units, str) and units in TIME_UNITS:
dims, data, attrs, encoding = unpack_for_decoding(variable)
dims, data, attrs, encoding, enum_meaning, enum_name = unpack(variable)

units = pop_to(attrs, encoding, "units")
transform = partial(decode_cf_timedelta, units=units)
dtype = np.dtype("timedelta64[ns]")
data = lazy_elemwise_func(data, transform, dtype=dtype)

return Variable(dims, data, attrs, encoding, fastpath=True)
return Variable(dims, data, attrs, encoding, fastpath=True, enum_meaning=enum_meaning, enum_name=enum_name)
else:
return variable
Loading

0 comments on commit bdfa8ce

Please sign in to comment.