From a4f80b23d32e9c3986e3342182fe382d8081c3c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Sun, 24 Sep 2023 17:05:25 +0200 Subject: [PATCH] override `units` for datetime64/timedelta64 variables to preserve integer dtype (#8201) * remove `dtype` from encoding for datetime64/timedelta64 variables to prevent unnecessary casts * adapt tests * add whats-new.rst entry * Update xarray/coding/times.py Co-authored-by: Spencer Clark * Update doc/whats-new.rst Co-authored-by: Spencer Clark * add test per review suggestion, replace .kind-check with np.issubdtype-check * align timedelta64 check with datetime64 check * override units instead of dtype * remove print statement * warn in case of serialization to floating point, too * align if-else * Add instructions to warnings * Fix test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use warnings.catch_warnings * Update doc/whats-new.rst Co-authored-by: Spencer Clark --------- Co-authored-by: Spencer Clark Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 2 + xarray/coding/times.py | 110 +++++++++++++++++++++--------- xarray/tests/test_coding_times.py | 65 +++++++++++++++--- 3 files changed, 132 insertions(+), 45 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 67429ed7e18..5f18e999cc0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -90,6 +90,8 @@ Bug fixes - ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords (:issue:`6528`, :pull:`8114`) By `Maximilian Roos <https://github.com/max-sixty>`_. +- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). + By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 79efbecfb7c..2822f02dd8d 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -656,8 +656,22 @@ def cast_to_int_if_safe(num) -> np.ndarray: return num +def _division(deltas, delta, floor): + if floor: + # calculate int64 floor division + # to preserve integer dtype if possible (GH 4045, GH7817). + num = deltas // delta.astype(np.int64) + num = num.astype(np.int64, copy=False) + else: + num = deltas / delta + return num + + def encode_cf_datetime( - dates, units: str | None = None, calendar: str | None = None + dates, + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, ) -> tuple[np.ndarray, str, str]: """Given an array of datetime objects, returns the tuple `(num, units, calendar)` suitable for a CF compliant time variable. @@ -689,6 +703,12 @@ def encode_cf_datetime( time_units, ref_date = _unpack_time_units_and_ref_date(units) time_delta = _time_units_to_timedelta64(time_units) + # Wrap the dates in a DatetimeIndex to do the subtraction to ensure + # an OverflowError is raised if the ref_date is too far away from + # dates to be encoded (GH 2272). + dates_as_index = pd.DatetimeIndex(dates.ravel()) + time_deltas = dates_as_index - ref_date + # retrieve needed units to faithfully encode to int64 needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units) if data_units != units: @@ -697,26 +717,32 @@ def encode_cf_datetime( if ref_delta > np.timedelta64(0, "ns"): needed_units = _infer_time_units_from_diff(ref_delta) - # Wrap the dates in a DatetimeIndex to do the subtraction to ensure - # an OverflowError is raised if the ref_date is too far away from - # dates to be encoded (GH 2272). 
- dates_as_index = pd.DatetimeIndex(dates.ravel()) - time_deltas = dates_as_index - ref_date - # needed time delta to encode faithfully to int64 needed_time_delta = _time_units_to_timedelta64(needed_units) - if time_delta <= needed_time_delta: - # calculate int64 floor division - # to preserve integer dtype if possible (GH 4045, GH7817). - num = time_deltas // time_delta.astype(np.int64) - num = num.astype(np.int64, copy=False) - else: - emit_user_level_warning( - f"Times can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " - f"Serializing timeseries to floating point." - ) - num = time_deltas / time_delta + + floor_division = True + if time_delta > needed_time_delta: + floor_division = False + if dtype is None: + emit_user_level_warning( + f"Times can't be serialized faithfully to int64 with requested units {units!r}. " + f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. " + f"Set encoding['dtype'] to integer dtype to serialize to int64. " + f"Set encoding['dtype'] to floating point dtype to silence this warning." + ) + elif np.issubdtype(dtype, np.integer): + new_units = f"{needed_units} since {format_timestamp(ref_date)}" + emit_user_level_warning( + f"Times can't be serialized faithfully to int64 with requested units {units!r}. " + f"Serializing with units {new_units!r} instead. " + f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. " + f"Set encoding['units'] to {new_units!r} to silence this warning ." 
+ ) + units = new_units + time_delta = needed_time_delta + floor_division = True + + num = _division(time_deltas, time_delta, floor_division) num = num.values.reshape(dates.shape) except (OutOfBoundsDatetime, OverflowError, ValueError): @@ -728,7 +754,9 @@ def encode_cf_datetime( return (num, units, calendar) -def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]: +def encode_cf_timedelta( + timedeltas, units: str | None = None, dtype: np.dtype | None = None +) -> tuple[np.ndarray, str]: data_units = infer_timedelta_units(timedeltas) if units is None: @@ -744,18 +772,29 @@ def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarra # needed time delta to encode faithfully to int64 needed_time_delta = _time_units_to_timedelta64(needed_units) - if time_delta <= needed_time_delta: - # calculate int64 floor division - # to preserve integer dtype if possible - num = time_deltas // time_delta.astype(np.int64) - num = num.astype(np.int64, copy=False) - else: - emit_user_level_warning( - f"Timedeltas can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " - f"Serializing timedeltas to floating point." - ) - num = time_deltas / time_delta + + floor_division = True + if time_delta > needed_time_delta: + floor_division = False + if dtype is None: + emit_user_level_warning( + f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. " + f"Resolution of {needed_units!r} needed. Serializing timeseries to floating point instead. " + f"Set encoding['dtype'] to integer dtype to serialize to int64. " + f"Set encoding['dtype'] to floating point dtype to silence this warning." + ) + elif np.issubdtype(dtype, np.integer): + emit_user_level_warning( + f"Timedeltas can't be serialized faithfully with requested units {units!r}. " + f"Serializing with units {needed_units!r} instead. 
" + f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. " + f"Set encoding['units'] to {needed_units!r} to silence this warning ." + ) + units = needed_units + time_delta = needed_time_delta + floor_division = True + + num = _division(time_deltas, time_delta, floor_division) num = num.values.reshape(timedeltas.shape) return (num, units) @@ -772,7 +811,8 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: units = encoding.pop("units", None) calendar = encoding.pop("calendar", None) - (data, units, calendar) = encode_cf_datetime(data, units, calendar) + dtype = encoding.get("dtype", None) + (data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype) safe_setitem(attrs, "units", units, name=name) safe_setitem(attrs, "calendar", calendar, name=name) @@ -807,7 +847,9 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: if np.issubdtype(variable.data.dtype, np.timedelta64): dims, data, attrs, encoding = unpack_for_encoding(variable) - data, units = encode_cf_timedelta(data, encoding.pop("units", None)) + data, units = encode_cf_timedelta( + data, encoding.pop("units", None), encoding.get("dtype", None) + ) safe_setitem(attrs, "units", units, name=name) return Variable(dims, data, attrs, encoding, fastpath=True) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 079e432b565..5f76a4a2ca8 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -30,7 +30,7 @@ from xarray.coding.variables import SerializationWarning from xarray.conventions import _update_bounds_attributes, cf_encoder from xarray.core.common import contains_cftime_datetimes -from xarray.testing import assert_allclose, assert_equal, assert_identical +from xarray.testing import assert_equal, assert_identical from xarray.tests import ( FirstElementAccessibleArray, arm_xfail, @@ -1036,7 +1036,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype( 
pytest.skip("Nanosecond frequency is not valid for cftime dates.") times = date_range("2000", periods=3, freq=freq) units = f"{encoding_units} since 2000-01-01" - encoded, _, _ = coding.times.encode_cf_datetime(times, units) + encoded, _units, _ = coding.times.encode_cf_datetime(times, units) numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units) encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit) @@ -1212,6 +1212,7 @@ def test_contains_cftime_lazy() -> None: ("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False), ("1677-09-21T00:12:43.145225", "us", np.int64, None, False), ("1970-01-01T00:00:01.000001", "us", np.int64, None, False), + ("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True), ], ) def test_roundtrip_datetime64_nanosecond_precision( @@ -1261,14 +1262,52 @@ def test_roundtrip_datetime64_nanosecond_precision_warning() -> None: ] units = "days since 1970-01-10T01:01:00" needed_units = "hours" - encoding = dict(_FillValue=20, units=units) + new_units = f"{needed_units} since 1970-01-10T01:01:00" + + encoding = dict(dtype=None, _FillValue=20, units=units) var = Variable(["time"], times, encoding=encoding) - wmsg = ( - f"Times can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " - ) - with pytest.warns(UserWarning, match=wmsg): + with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns( + UserWarning, match=f"Serializing with units {new_units!r} instead." 
+ ): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="float64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=new_units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 decoded_var = conventions.decode_cf_variable("foo", encoded_var) assert_identical(var, decoded_var) @@ -1309,14 +1348,18 @@ def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None: needed_units = "hours" wmsg = ( f"Timedeltas can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " + f"Serializing with units {needed_units!r} instead." 
) - encoding = dict(_FillValue=20, units=units) + encoding = dict(dtype=np.int64, _FillValue=20, units=units) var = Variable(["time"], timedelta_values, encoding=encoding) with pytest.warns(UserWarning, match=wmsg): encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == needed_units + assert encoded_var.attrs["_FillValue"] == 20 decoded_var = conventions.decode_cf_variable("foo", encoded_var) - assert_allclose(var, decoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["dtype"] == np.int64 def test_roundtrip_float_times() -> None: