Skip to content

Commit

Permalink
add custom Zarr _FillValue encoding / decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
rabernat committed Oct 9, 2024
1 parent bd978b0 commit 34c4c24
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import base64
import functools
import json
import os
import struct
import warnings
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -58,6 +60,49 @@ def _zarr_v3() -> bool:
ZarrFormat = Literal[2, 3]


class FillValueCoder:
"""Handle custom logic to safely encode and decode fill values in Zarr.
Possibly redundant with logic in xarray/coding/variables.py but needs to be
isolated from NetCDF-specific logic.
"""

@classmethod
def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any:
if dtype.kind in "S":
# byte string
return base64.standard_b64encode(value).decode()
elif dtype.kind in "b":
# boolean
return bool(value)
elif dtype.kind in "iu":
# todo: do we want to check for decimals?
return int(value)
elif dtype.kind in "f":
return base64.standard_b64encode(struct.pack("<d", float(value))).decode()
elif dtype.kind in "U":
return str(value)
else:
raise ValueError(f"Failed to encode fill_value. Unsupported dtype {dtype}")

@classmethod
def decode(cls, value: int | float | str | bytes, dtype: str | np.dtype[Any]):
if dtype == "string":
# zarr V3 string type
return str(value)
elif dtype == "bytes":
# zarr V3 bytes type
return base64.standard_b64decode(value)
np_dtype = np.dtype(dtype)
if np_dtype.kind in "f":
return struct.unpack("<d", base64.standard_b64decode(value))[0]
elif np_dtype.kind in "b":
return bool(value)
elif np_dtype.kind in "iu":
return int(value)
else:
raise ValueError(f"Failed to decode fill_value. Unsupported dtype {dtype}")


def encode_zarr_attr_value(value):
"""
Encode a attribute value as something that can be serialized as json
Expand All @@ -71,6 +116,10 @@ def encode_zarr_attr_value(value):
"""
if isinstance(value, np.ndarray):
encoded = value.tolist()
# elif isinstance(value, bytes):
# try to match how Zarr encodes bytes
# return [int.from_bytes(value)]
# return value.decode("utf-8")
# this checks if it's a scalar number
elif isinstance(value, np.generic):
encoded = value.item()
Expand Down Expand Up @@ -675,6 +724,11 @@ def open_store_variable(self, name, zarr_array=None):
# TODO: it feels a bit hacky to hijack CF decoding for this purpose
if zarr_array.fill_value is not None:
attributes["_FillValue"] = zarr_array.fill_value
elif "_FillValue" in attributes:
original_zarr_dtype = zarr_array.metadata.data_type
attributes["_FillValue"] = FillValueCoder.decode(
attributes["_FillValue"], original_zarr_dtype.value
)

return Variable(dimensions, data, attributes, encoding)

Expand Down Expand Up @@ -884,6 +938,11 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
v.encoding = {}
else:
fill_value = None
if "_FillValue" in attrs:
# replace with encoded fill value
attrs["_FillValue"] = FillValueCoder.encode(
attrs["_FillValue"], dtype
)

zarr_array = None
zarr_shape = None
Expand Down

0 comments on commit 34c4c24

Please sign in to comment.