Skip to content

Commit

Permalink
override units for datetime64/timedelta64 variables to preserve int…
Browse files Browse the repository at this point in the history
…eger dtype (#8201)

* remove `dtype` from encoding for datetime64/timedelta64 variables to prevent unnecessary casts

* adapt tests

* add whats-new.rst entry

* Update xarray/coding/times.py

Co-authored-by: Spencer Clark <[email protected]>

* Update doc/whats-new.rst

Co-authored-by: Spencer Clark <[email protected]>

* add test per review suggestion, replace .kind-check with np.issubdtype-check

* align timedelta64 check with datetime64 check

* override units instead of dtype

* remove print statement

* warn in case of serialization to floating point, too

* align if-else

* Add instructions to warnings

* Fix test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use warnings.catch_warnings

* Update doc/whats-new.rst

Co-authored-by: Spencer Clark <[email protected]>

---------

Co-authored-by: Spencer Clark <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 24, 2023
1 parent 77eaa8b commit a4f80b2
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 45 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ Bug fixes
- ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords
(:issue:`6528`, :pull:`8114`)
By `Maximilian Roos <https://github.com/max-sixty>`_.
- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
110 changes: 76 additions & 34 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,22 @@ def cast_to_int_if_safe(num) -> np.ndarray:
return num


def _division(deltas, delta, floor):
if floor:
# calculate int64 floor division
# to preserve integer dtype if possible (GH 4045, GH7817).
num = deltas // delta.astype(np.int64)
num = num.astype(np.int64, copy=False)
else:
num = deltas / delta
return num


def encode_cf_datetime(
dates, units: str | None = None, calendar: str | None = None
dates,
units: str | None = None,
calendar: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[np.ndarray, str, str]:
"""Given an array of datetime objects, returns the tuple `(num, units,
calendar)` suitable for a CF compliant time variable.
Expand Down Expand Up @@ -689,6 +703,12 @@ def encode_cf_datetime(
time_units, ref_date = _unpack_time_units_and_ref_date(units)
time_delta = _time_units_to_timedelta64(time_units)

# Wrap the dates in a DatetimeIndex to do the subtraction to ensure
# an OverflowError is raised if the ref_date is too far away from
# dates to be encoded (GH 2272).
dates_as_index = pd.DatetimeIndex(dates.ravel())
time_deltas = dates_as_index - ref_date

# retrieve needed units to faithfully encode to int64
needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units)
if data_units != units:
Expand All @@ -697,26 +717,32 @@ def encode_cf_datetime(
if ref_delta > np.timedelta64(0, "ns"):
needed_units = _infer_time_units_from_diff(ref_delta)

# Wrap the dates in a DatetimeIndex to do the subtraction to ensure
# an OverflowError is raised if the ref_date is too far away from
# dates to be encoded (GH 2272).
dates_as_index = pd.DatetimeIndex(dates.ravel())
time_deltas = dates_as_index - ref_date

# needed time delta to encode faithfully to int64
needed_time_delta = _time_units_to_timedelta64(needed_units)
if time_delta <= needed_time_delta:
# calculate int64 floor division
# to preserve integer dtype if possible (GH 4045, GH7817).
num = time_deltas // time_delta.astype(np.int64)
num = num.astype(np.int64, copy=False)
else:
emit_user_level_warning(
f"Times can't be serialized faithfully with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. "
f"Serializing timeseries to floating point."
)
num = time_deltas / time_delta

floor_division = True
if time_delta > needed_time_delta:
floor_division = False
if dtype is None:
emit_user_level_warning(
f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. "
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
f"Set encoding['dtype'] to floating point dtype to silence this warning."
)
elif np.issubdtype(dtype, np.integer):
new_units = f"{needed_units} since {format_timestamp(ref_date)}"
emit_user_level_warning(
f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
f"Serializing with units {new_units!r} instead. "
f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
f"Set encoding['units'] to {new_units!r} to silence this warning ."
)
units = new_units
time_delta = needed_time_delta
floor_division = True

num = _division(time_deltas, time_delta, floor_division)
num = num.values.reshape(dates.shape)

except (OutOfBoundsDatetime, OverflowError, ValueError):
Expand All @@ -728,7 +754,9 @@ def encode_cf_datetime(
return (num, units, calendar)


def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]:
def encode_cf_timedelta(
timedeltas, units: str | None = None, dtype: np.dtype | None = None
) -> tuple[np.ndarray, str]:
data_units = infer_timedelta_units(timedeltas)

if units is None:
Expand All @@ -744,18 +772,29 @@ def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarra

# needed time delta to encode faithfully to int64
needed_time_delta = _time_units_to_timedelta64(needed_units)
if time_delta <= needed_time_delta:
# calculate int64 floor division
# to preserve integer dtype if possible
num = time_deltas // time_delta.astype(np.int64)
num = num.astype(np.int64, copy=False)
else:
emit_user_level_warning(
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. "
f"Serializing timedeltas to floating point."
)
num = time_deltas / time_delta

floor_division = True
if time_delta > needed_time_delta:
floor_division = False
if dtype is None:
emit_user_level_warning(
f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. Serializing timeseries to floating point instead. "
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
f"Set encoding['dtype'] to floating point dtype to silence this warning."
)
elif np.issubdtype(dtype, np.integer):
emit_user_level_warning(
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
f"Serializing with units {needed_units!r} instead. "
f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
f"Set encoding['units'] to {needed_units!r} to silence this warning ."
)
units = needed_units
time_delta = needed_time_delta
floor_division = True

num = _division(time_deltas, time_delta, floor_division)
num = num.values.reshape(timedeltas.shape)
return (num, units)

Expand All @@ -772,7 +811,8 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:

units = encoding.pop("units", None)
calendar = encoding.pop("calendar", None)
(data, units, calendar) = encode_cf_datetime(data, units, calendar)
dtype = encoding.get("dtype", None)
(data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)

safe_setitem(attrs, "units", units, name=name)
safe_setitem(attrs, "calendar", calendar, name=name)
Expand Down Expand Up @@ -807,7 +847,9 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(variable.data.dtype, np.timedelta64):
dims, data, attrs, encoding = unpack_for_encoding(variable)

data, units = encode_cf_timedelta(data, encoding.pop("units", None))
data, units = encode_cf_timedelta(
data, encoding.pop("units", None), encoding.get("dtype", None)
)
safe_setitem(attrs, "units", units, name=name)

return Variable(dims, data, attrs, encoding, fastpath=True)
Expand Down
65 changes: 54 additions & 11 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from xarray.coding.variables import SerializationWarning
from xarray.conventions import _update_bounds_attributes, cf_encoder
from xarray.core.common import contains_cftime_datetimes
from xarray.testing import assert_allclose, assert_equal, assert_identical
from xarray.testing import assert_equal, assert_identical
from xarray.tests import (
FirstElementAccessibleArray,
arm_xfail,
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype(
pytest.skip("Nanosecond frequency is not valid for cftime dates.")
times = date_range("2000", periods=3, freq=freq)
units = f"{encoding_units} since 2000-01-01"
encoded, _, _ = coding.times.encode_cf_datetime(times, units)
encoded, _units, _ = coding.times.encode_cf_datetime(times, units)

numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units)
encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit)
Expand Down Expand Up @@ -1212,6 +1212,7 @@ def test_contains_cftime_lazy() -> None:
("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False),
("1677-09-21T00:12:43.145225", "us", np.int64, None, False),
("1970-01-01T00:00:01.000001", "us", np.int64, None, False),
("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True),
],
)
def test_roundtrip_datetime64_nanosecond_precision(
Expand Down Expand Up @@ -1261,14 +1262,52 @@ def test_roundtrip_datetime64_nanosecond_precision_warning() -> None:
]
units = "days since 1970-01-10T01:01:00"
needed_units = "hours"
encoding = dict(_FillValue=20, units=units)
new_units = f"{needed_units} since 1970-01-10T01:01:00"

encoding = dict(dtype=None, _FillValue=20, units=units)
var = Variable(["time"], times, encoding=encoding)
wmsg = (
f"Times can't be serialized faithfully with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. "
)
with pytest.warns(UserWarning, match=wmsg):
with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."):
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.float64
assert encoded_var.attrs["units"] == units
assert encoded_var.attrs["_FillValue"] == 20.0

decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)

encoding = dict(dtype="int64", _FillValue=20, units=units)
var = Variable(["time"], times, encoding=encoding)
with pytest.warns(
UserWarning, match=f"Serializing with units {new_units!r} instead."
):
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.int64
assert encoded_var.attrs["units"] == new_units
assert encoded_var.attrs["_FillValue"] == 20

decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)

encoding = dict(dtype="float64", _FillValue=20, units=units)
var = Variable(["time"], times, encoding=encoding)
with warnings.catch_warnings():
warnings.simplefilter("error")
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.float64
assert encoded_var.attrs["units"] == units
assert encoded_var.attrs["_FillValue"] == 20.0

decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)

encoding = dict(dtype="int64", _FillValue=20, units=new_units)
var = Variable(["time"], times, encoding=encoding)
with warnings.catch_warnings():
warnings.simplefilter("error")
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.int64
assert encoded_var.attrs["units"] == new_units
assert encoded_var.attrs["_FillValue"] == 20

decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)
Expand Down Expand Up @@ -1309,14 +1348,18 @@ def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None:
needed_units = "hours"
wmsg = (
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. "
f"Serializing with units {needed_units!r} instead."
)
encoding = dict(_FillValue=20, units=units)
encoding = dict(dtype=np.int64, _FillValue=20, units=units)
var = Variable(["time"], timedelta_values, encoding=encoding)
with pytest.warns(UserWarning, match=wmsg):
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.int64
assert encoded_var.attrs["units"] == needed_units
assert encoded_var.attrs["_FillValue"] == 20
decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_allclose(var, decoded_var)
assert_identical(var, decoded_var)
assert decoded_var.encoding["dtype"] == np.int64


def test_roundtrip_float_times() -> None:
Expand Down

0 comments on commit a4f80b2

Please sign in to comment.