From 37d351034c9b8ff5158ff0bfea3652e8bcc5ce0d Mon Sep 17 00:00:00 2001 From: tvo Date: Sat, 2 Nov 2024 19:46:20 -0600 Subject: [PATCH] Allow wrapping astropy.units.Quantity --- xarray/core/variable.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 13053faff58..4216e574312 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -320,10 +320,17 @@ def convert_non_numpy_type(data): else: data = np.asarray(data) + _is_array_like = isinstance(data, np.ndarray | np.generic) + _is_nep18 = hasattr(data, "__array_function__") + _has_array_api = hasattr(data, "__array_namespace__") + _has_unit = hasattr(data, "_unit") + + # Allow `astropy.units.Quantity` + if _is_array_like and (_is_nep18 or _has_array_api) and _has_unit: + return cast("T_DuckArray", data) + # immediately return array-like types except `numpy.ndarray` subclasses and `numpy` scalars - if not isinstance(data, np.ndarray | np.generic) and ( - hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") - ): + if not _is_array_like and (_is_nep18 or _has_array_api): return cast("T_DuckArray", data) # validate whether the data is valid data types. Also, explicitly cast `numpy`