From 03086723dd026970660f8d671575bd54bbef619a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Thu, 9 Jan 2025 09:24:08 +0100 Subject: [PATCH] refactor timedelta decoding to _numbers_to_timedelta and res-use it within decode_cf_timedelta --- xarray/coding/times.py | 91 ++++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index ec1bca46a11..579585850c4 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -458,41 +458,12 @@ def _decode_datetime_with_pandas( elif flat_num_dates.dtype.kind in "f": flat_num_dates = flat_num_dates.astype(np.float64) - # keep NaT/nan mask - nan = np.isnan(flat_num_dates) | (flat_num_dates == np.iinfo(np.int64).min) - - # in case we need to change the unit, we fix the numbers here - # this should be safe, as errors would have been raised above - ns_time_unit = _NS_PER_TIME_DELTA[time_unit] - ns_ref_date_unit = _NS_PER_TIME_DELTA[ref_date.unit] - if ns_time_unit > ns_ref_date_unit: - flat_num_dates *= np.int64(ns_time_unit / ns_ref_date_unit) - time_unit = ref_date.unit - - # estimate fitting resolution for floating point values - # this iterates until all floats are fractionless or time_unit == "ns" - if flat_num_dates.dtype.kind == "f" and time_unit != "ns": - flat_num_dates, new_time_unit = _check_higher_resolution( - flat_num_dates, time_unit - ) - if time_unit != new_time_unit: - msg = ( - f"Can't decode floating point datetime to {time_unit!r} without " - f"precision loss, decoding to {new_time_unit!r} instead. " - f"To silence this warning use time_unit={new_time_unit!r} in call to " - f"decoding function." - ) - emit_user_level_warning(msg, SerializationWarning) - time_unit = new_time_unit - - # Cast input ordinals to integers and properly handle NaN/NaT - # to prevent casting NaN to int - flat_num_dates_int = np.zeros_like(flat_num_dates, dtype=np.int64) - flat_num_dates_int[nan] = np.iinfo(np.int64).min - flat_num_dates_int[~nan] = flat_num_dates[~nan].astype(np.int64) + timedeltas = _numbers_to_timedelta( + flat_num_dates, time_unit, ref_date.unit, "datetime" + ) - # cast to timedelta64[time_unit] and add to ref_date - return ref_date + flat_num_dates_int.astype(f"timedelta64[{time_unit}]") + # add timedeltas to ref_date + return ref_date + timedeltas def decode_cf_datetime( @@ -590,6 +561,49 @@ def to_datetime_unboxed(value, **kwargs): return result +def _numbers_to_timedelta( + flat_num: np.ndarray, + time_unit: NPDatetimeUnitOptions, + ref_unit: PDDatetimeUnitOptions, + datatype: str, +) -> np.ndarray: + """Transform numbers to np.timedelta64.""" + # keep NaT/nan mask + nan = np.isnan(flat_num) | (flat_num == np.iinfo(np.int64).min) + + # in case we need to change the unit, we fix the numbers here + # this should be safe, as errors would have been raised above + ns_time_unit = _NS_PER_TIME_DELTA[time_unit] + ns_ref_date_unit = _NS_PER_TIME_DELTA[ref_unit] + if ns_time_unit > ns_ref_date_unit: + flat_num *= np.int64(ns_time_unit / ns_ref_date_unit) + time_unit = ref_unit + + # estimate fitting resolution for floating point values + # this iterates until all floats are fractionless or time_unit == "ns" + if flat_num.dtype.kind == "f" and time_unit != "ns": + flat_num_dates, new_time_unit = _check_higher_resolution(flat_num, time_unit) + if time_unit != new_time_unit: + msg = ( + f"Can't decode floating point {datatype} to {time_unit!r} without " + f"precision loss, decoding to {new_time_unit!r} instead. " + f"To silence this warning use time_unit={new_time_unit!r} in call to " + f"decoding function." + ) + emit_user_level_warning(msg, SerializationWarning) + time_unit = new_time_unit + + # Cast input ordinals to integers and properly handle NaN/NaT + # to prevent casting NaN to int + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + flat_num = flat_num.astype(np.int64) + flat_num[nan] = np.iinfo(np.int64).min + + # cast to wanted type + return flat_num.astype(f"timedelta64[{time_unit}]") + + def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray: # todo: check, if this works as intended """Given an array of numeric timedeltas in netCDF format, convert it into a @@ -597,14 +611,15 @@ def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray: """ num_timedeltas = np.asarray(num_timedeltas) unit = _netcdf_to_numpy_timeunit(units) + + timedeltas = _numbers_to_timedelta(num_timedeltas, unit, "s", "timedelta") + as_unit = unit if unit not in {"s", "ms", "us", "ns"}: # default to ns, when not specified as_unit = "ns" - result = ( - pd.to_timedelta(ravel(num_timedeltas), unit=unit).as_unit(as_unit).to_numpy() - ) - return reshape(result, num_timedeltas.shape) + result = pd.to_timedelta(ravel(timedeltas)).as_unit(as_unit).to_numpy() + return reshape(result, timedeltas.shape) def _unit_timedelta_cftime(units: str) -> timedelta: