Skip to content

Commit

Permalink
Merge pull request pydata#1 from rabernat/ryan/fix/zarr-3
Browse files Browse the repository at this point in the history
Fill value fixes for V3
  • Loading branch information
TomAugspurger authored Oct 9, 2024
2 parents 9b3c288 + 118e50e commit e6e2066
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 48 deletions.
134 changes: 123 additions & 11 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import base64
import functools
import json
import os
import struct
import warnings
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -58,6 +60,49 @@ def _zarr_v3() -> bool:
ZarrFormat = Literal[2, 3]


class FillValueCoder:
"""Handle custom logic to safely encode and decode fill values in Zarr.
Possibly redundant with logic in xarray/coding/variables.py but needs to be
isolated from NetCDF-specific logic.
"""

@classmethod
def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any:
if dtype.kind in "S":
# byte string
return base64.standard_b64encode(value).decode()
elif dtype.kind in "b":
# boolean
return bool(value)
elif dtype.kind in "iu":
# todo: do we want to check for decimals?
return int(value)
elif dtype.kind in "f":
return base64.standard_b64encode(struct.pack("<d", float(value))).decode()
elif dtype.kind in "U":
return str(value)
else:
raise ValueError(f"Failed to encode fill_value. Unsupported dtype {dtype}")

@classmethod
def decode(cls, value: int | float | str | bytes, dtype: str | np.dtype[Any]):
if dtype == "string":
# zarr V3 string type
return str(value)
elif dtype == "bytes":
# zarr V3 bytes type
return base64.standard_b64decode(value)
np_dtype = np.dtype(dtype)
if np_dtype.kind in "f":
return struct.unpack("<d", base64.standard_b64decode(value))[0]
elif np_dtype.kind in "b":
return bool(value)
elif np_dtype.kind in "iu":
return int(value)
else:
raise ValueError(f"Failed to decode fill_value. Unsupported dtype {dtype}")


def encode_zarr_attr_value(value):
"""
Encode a attribute value as something that can be serialized as json
Expand All @@ -71,6 +116,10 @@ def encode_zarr_attr_value(value):
"""
if isinstance(value, np.ndarray):
encoded = value.tolist()
# elif isinstance(value, bytes):
# try to match how Zarr encodes bytes
# return [int.from_bytes(value)]
# return value.decode("utf-8")
# this checks if it's a scalar number
elif isinstance(value, np.generic):
encoded = value.item()
Expand Down Expand Up @@ -412,6 +461,7 @@ def _validate_datatypes_for_zarr_append(vname, existing_var, new_var):
or np.issubdtype(new_var.dtype, np.datetime64)
or np.issubdtype(new_var.dtype, np.bool_)
or new_var.dtype == object
or (new_var.dtype.kind in ("S", "U") and existing_var.dtype == object)
):
# We can skip dtype equality checks under two conditions: (1) if the var to append is
# new to the dataset, because in this case there is no existing var to compare it to;
Expand Down Expand Up @@ -493,6 +543,7 @@ class ZarrStore(AbstractWritableDataStore):
"_safe_chunks",
"_write_empty",
"_close_store_on_close",
"_use_zarr_fill_value_as_mask",
)

@classmethod
Expand All @@ -512,10 +563,16 @@ def open_store(
stacklevel=2,
zarr_version=None,
zarr_format=None,
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
):

zarr_group, consolidate_on_close, close_store_on_close = _get_open_params(
(
zarr_group,
consolidate_on_close,
close_store_on_close,
use_zarr_fill_value_as_mask,
) = _get_open_params(
store=store,
mode=mode,
synchronizer=synchronizer,
Expand All @@ -526,6 +583,7 @@ def open_store(
storage_options=storage_options,
stacklevel=stacklevel,
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask,
zarr_format=zarr_format,
)
group_paths = [node for node in _iter_zarr_groups(zarr_group, parent=group)]
Expand All @@ -539,6 +597,7 @@ def open_store(
safe_chunks,
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
)
for group in group_paths
}
Expand All @@ -560,10 +619,16 @@ def open_group(
stacklevel=2,
zarr_version=None,
zarr_format=None,
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
):

zarr_group, consolidate_on_close, close_store_on_close = _get_open_params(
(
zarr_group,
consolidate_on_close,
close_store_on_close,
use_zarr_fill_value_as_mask,
) = _get_open_params(
store=store,
mode=mode,
synchronizer=synchronizer,
Expand All @@ -574,6 +639,7 @@ def open_group(
storage_options=storage_options,
stacklevel=stacklevel,
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask,
zarr_format=zarr_format,
)

Expand All @@ -586,6 +652,7 @@ def open_group(
safe_chunks,
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
)

def __init__(
Expand All @@ -598,6 +665,7 @@ def __init__(
safe_chunks=True,
write_empty: bool | None = None,
close_store_on_close: bool = False,
use_zarr_fill_value_as_mask=None,
):
self.zarr_group = zarr_group
self._read_only = self.zarr_group.read_only
Expand All @@ -609,7 +677,8 @@ def __init__(
self._write_region = write_region
self._safe_chunks = safe_chunks
self._write_empty = write_empty
self._close_store_on_close = close_store_on_close
self._close_store_on_close = (close_store_on_close,)
self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask

@property
def ds(self):
Expand Down Expand Up @@ -652,10 +721,16 @@ def open_store_variable(self, name, zarr_array=None):
}
)

# _FillValue needs to be in attributes, not encoding, so it will get
# picked up by decode_cf
if zarr_array.fill_value is not None:
attributes["_FillValue"] = zarr_array.fill_value
if self._use_zarr_fill_value_as_mask:
# Setting this attribute triggers CF decoding for missing values
# TODO: it feels a bit hacky to hijack CF decoding for this purpose
if zarr_array.fill_value is not None:
attributes["_FillValue"] = zarr_array.fill_value
elif "_FillValue" in attributes:
original_zarr_dtype = zarr_array.metadata.data_type
attributes["_FillValue"] = FillValueCoder.decode(
attributes["_FillValue"], original_zarr_dtype.value
)

return Variable(dimensions, data, attributes, encoding)

Expand Down Expand Up @@ -859,9 +934,17 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
dtype = v.dtype
shape = v.shape

fill_value = attrs.pop("_FillValue", None)
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}
if self._use_zarr_fill_value_as_mask:
fill_value = attrs.pop("_FillValue", None)
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}
else:
fill_value = None
if "_FillValue" in attrs:
# replace with encoded fill value
attrs["_FillValue"] = FillValueCoder.encode(
attrs["_FillValue"], dtype
)

zarr_array = None
zarr_shape = None
Expand Down Expand Up @@ -1087,6 +1170,7 @@ def open_zarr(
use_cftime=None,
zarr_version=None,
zarr_format=None,
use_zarr_fill_value_as_mask=None,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
**kwargs,
Expand Down Expand Up @@ -1185,6 +1269,11 @@ def open_zarr(
of None will attempt to determine the zarr version from ``store`` when
possible, otherwise defaulting to the default version used by
the zarr-python library installed.
use_zarr_fill_value_as_mask : bool, optional
If True, use the zarr Array `fill_value` to mask the data, the same as done
for NetCDF data with `_FillValue` or `missing_value` attributes. If False,
the `fill_value` is ignored and the data are not masked. If None, this defaults
to True for `zarr_version=2` and False for `zarr_version=3`.
chunked_array_type: str, optional
Which chunked array type to coerce this datasets' arrays to.
Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system.
Expand Down Expand Up @@ -1257,6 +1346,7 @@ def open_zarr(
decode_timedelta=decode_timedelta,
use_cftime=use_cftime,
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask,
)
return ds

Expand Down Expand Up @@ -1308,6 +1398,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
zarr_format=None,
store=None,
engine=None,
use_zarr_fill_value_as_mask=None,
) -> Dataset:
filename_or_obj = _normalize_path(filename_or_obj)
if not store:
Expand All @@ -1322,6 +1413,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=None,
zarr_format=zarr_format,
)

Expand Down Expand Up @@ -1474,6 +1566,7 @@ def _get_open_params(
storage_options,
stacklevel,
zarr_version,
use_zarr_fill_value_as_mask,
zarr_format,
):
import zarr
Expand Down Expand Up @@ -1541,7 +1634,26 @@ def _get_open_params(
else:
zarr_group = zarr.open_group(store, **open_kwargs)
close_store_on_close = zarr_group.store is not store
return zarr_group, consolidate_on_close, close_store_on_close

# we use this to determine how to handle fill_value
is_zarr_v3_format: bool
if _zarr_v3():
is_zarr_v3_format = zarr_group.metadata.zarr_format == 3
else:
is_zarr_v3_format = False
if use_zarr_fill_value_as_mask is None:
if is_zarr_v3_format:
# for new data, we use a better default
use_zarr_fill_value_as_mask = False
else:
# this was the default for v2 and shold apply to most existing Zarr data
use_zarr_fill_value_as_mask = True
return (
zarr_group,
consolidate_on_close,
close_store_on_close,
use_zarr_fill_value_as_mask,
)


def _handle_zarr_version_or_format(
Expand Down
13 changes: 13 additions & 0 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,19 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
return variable


class Numpy2StringDTypeCoder(VariableCoder):
# Convert Numpy 2 StringDType arrays to object arrays for backwards compatibility
# TODO: remove this if / when we decide to allow StringDType arrays in Xarray
def encode(self):
raise NotImplementedError

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
if variable.dtype.kind == "T":
return variable.astype(object)
else:
return variable


class NativeEnumCoder(VariableCoder):
"""Encode Enum into variable dtype metadata."""

Expand Down
3 changes: 3 additions & 0 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ def decode_cf_variable(
var = variables.ObjectVLenStringCoder().decode(var)
original_dtype = var.dtype

if original_dtype.kind == "T":
var = variables.Numpy2StringDTypeCoder().decode(var)

if mask_and_scale:
for coder in [
variables.CFMaskCoder(),
Expand Down
7 changes: 6 additions & 1 deletion xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
if isdtype(dtype, "real floating"):
if np.issubdtype(dtype, np.dtypes.StringDType()):
# for now, we always promote string dtypes to object for consistency with existing behavior
# TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
dtype_ = object
fill_value = np.nan
elif isdtype(dtype, "real floating"):
dtype_ = dtype
fill_value = np.nan
elif np.issubdtype(dtype, np.timedelta64):
Expand Down
Loading

0 comments on commit e6e2066

Please sign in to comment.