diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 04fe88e9993..d2a4b32a71f 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -104,7 +104,7 @@ Internal Changes
~~~~~~~~~~~~~~~~
- Added integration tests against `pint `_.
- (:pull:`3238`) by `Justus Magin `_.
+ (:pull:`3238`, :pull:`3447`) by `Justus Magin `_.
.. note::
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index 80063f8b4bc..8eed1f0dbe3 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -123,14 +123,19 @@ def extract_units(obj):
def strip_units(obj):
if isinstance(obj, xr.Dataset):
- data_vars = {name: strip_units(value) for name, value in obj.data_vars.items()}
- coords = {name: strip_units(value) for name, value in obj.coords.items()}
+ data_vars = {
+ strip_units(name): strip_units(value)
+ for name, value in obj.data_vars.items()
+ }
+ coords = {
+ strip_units(name): strip_units(value) for name, value in obj.coords.items()
+ }
new_obj = xr.Dataset(data_vars=data_vars, coords=coords)
elif isinstance(obj, xr.DataArray):
data = array_strip_units(obj.data)
coords = {
- name: (
+ strip_units(name): (
(value.dims, array_strip_units(value.data))
if isinstance(value.data, Quantity)
else value # to preserve multiindexes
@@ -138,9 +143,13 @@ def strip_units(obj):
for name, value in obj.coords.items()
}
- new_obj = xr.DataArray(name=obj.name, data=data, coords=coords, dims=obj.dims)
- elif hasattr(obj, "magnitude"):
+ new_obj = xr.DataArray(
+ name=strip_units(obj.name), data=data, coords=coords, dims=obj.dims
+ )
+ elif isinstance(obj, unit_registry.Quantity):
new_obj = obj.magnitude
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(strip_units(elem) for elem in obj)
else:
new_obj = obj
@@ -191,6 +200,38 @@ def attach_units(obj, units):
return new_obj
+def convert_units(obj, to):
+ if isinstance(obj, xr.Dataset):
+ data_vars = {
+ name: convert_units(array, to) for name, array in obj.data_vars.items()
+ }
+ coords = {name: convert_units(array, to) for name, array in obj.coords.items()}
+
+ new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
+ elif isinstance(obj, xr.DataArray):
+ name = obj.name
+
+ new_units = (
+ to.get(name, None) or to.get("data", None) or to.get(None, None) or 1
+ )
+ data = convert_units(obj.data, {None: new_units})
+
+ coords = {
+ name: (array.dims, convert_units(array.data, to))
+ for name, array in obj.coords.items()
+ if name != obj.name
+ }
+
+ new_obj = xr.DataArray(name=name, data=data, coords=coords, attrs=obj.attrs)
+ elif isinstance(obj, unit_registry.Quantity):
+ units = to.get(None)
+ new_obj = obj.to(units) if units is not None else obj
+ else:
+ new_obj = obj
+
+ return new_obj
+
+
def assert_equal_with_units(a, b):
# works like xr.testing.assert_equal, but also explicitly checks units
# so, it is more like assert_identical
@@ -1632,3 +1673,1696 @@ def test_grouped_operations(self, func, dtype):
result = func(data_array.groupby("y"))
assert_equal_with_units(expected, result)
+
+
+class TestDataset:
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="same_unit"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "shared",
+ (
+ "nothing",
+ pytest.param(
+ "dims",
+ marks=pytest.mark.xfail(reason="reindex does not work with pint yet"),
+ ),
+ pytest.param(
+ "coords",
+ marks=pytest.mark.xfail(reason="reindex does not work with pint yet"),
+ ),
+ ),
+ )
+ def test_init(self, shared, unit, error, dtype):
+ original_unit = unit_registry.m
+ scaled_unit = unit_registry.mm
+
+ a = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa
+ b = np.linspace(-1, 0, 12).astype(dtype) * unit_registry.Pa
+
+ raw_x = np.arange(a.shape[0])
+ x = raw_x * original_unit
+ x2 = x.to(scaled_unit)
+
+ raw_y = np.arange(b.shape[0])
+ y = raw_y * unit
+ y_units = unit if isinstance(y, unit_registry.Quantity) else None
+ if isinstance(y, unit_registry.Quantity):
+ if y.check(scaled_unit):
+ y2 = y.to(scaled_unit)
+ else:
+ y2 = y * 1000
+ y2_units = y2.units
+ else:
+ y2 = y * 1000
+ y2_units = None
+
+ variants = {
+ "nothing": ({"x": x, "x2": ("x", x2)}, {"y": y, "y2": ("y", y2)}),
+ "dims": (
+ {"x": x, "x2": ("x", strip_units(x2))},
+ {"x": y, "y2": ("x", strip_units(y2))},
+ ),
+ "coords": ({"x": raw_x, "y": ("x", x2)}, {"x": raw_y, "y": ("x", y2)}),
+ }
+ coords_a, coords_b = variants.get(shared)
+
+ dims_a, dims_b = ("x", "y") if shared == "nothing" else ("x", "x")
+
+ arr1 = xr.DataArray(data=a, coords=coords_a, dims=dims_a)
+ arr2 = xr.DataArray(data=b, coords=coords_b, dims=dims_b)
+ if error is not None and shared != "nothing":
+ with pytest.raises(error):
+ xr.Dataset(data_vars={"a": arr1, "b": arr2})
+
+ return
+
+ result = xr.Dataset(data_vars={"a": arr1, "b": arr2})
+
+ expected_units = {
+ "a": a.units,
+ "b": b.units,
+ "x": x.units,
+ "x2": x2.units,
+ "y": y_units,
+ "y2": y2_units,
+ }
+ expected = attach_units(
+ xr.Dataset(data_vars={"a": strip_units(arr1), "b": strip_units(arr2)}),
+ expected_units,
+ )
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.parametrize(
+ "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr"))
+ )
+ @pytest.mark.parametrize(
+ "variant",
+ (
+ pytest.param(
+ "with_dims",
+ marks=pytest.mark.xfail(reason="units in indexes are not supported"),
+ ),
+ pytest.param("with_coords"),
+ pytest.param("without_coords"),
+ ),
+ )
+ @pytest.mark.filterwarnings("error:::pint[.*]")
+ def test_repr(self, func, variant, dtype):
+ array1 = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.Pa
+ array2 = np.linspace(0, 1, 10, dtype=dtype) * unit_registry.degK
+
+ x = np.arange(len(array1)) * unit_registry.s
+ y = x.to(unit_registry.ms)
+
+ variants = {
+ "with_dims": {"x": x},
+ "with_coords": {"y": ("x", y)},
+ "without_coords": {},
+ }
+
+ data_array = xr.Dataset(
+ data_vars={"a": ("x", array1), "b": ("x", array2)},
+ coords=variants.get(variant),
+ )
+
+ # FIXME: this just checks that the repr does not raise
+ # warnings or errors, but does not check the result
+ func(data_array)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(
+ function("all"),
+ marks=pytest.mark.xfail(reason="not implemented by pint"),
+ ),
+ pytest.param(
+ function("any"),
+ marks=pytest.mark.xfail(reason="not implemented by pint"),
+ ),
+ function("argmax"),
+ function("argmin"),
+ function("max"),
+ function("min"),
+ function("mean"),
+ pytest.param(
+ function("median"),
+ marks=pytest.mark.xfail(
+ reason="np.median does not work with dataset yet"
+ ),
+ ),
+ pytest.param(
+ function("sum"),
+ marks=pytest.mark.xfail(
+ reason="np.result_type not implemented by pint"
+ ),
+ ),
+ pytest.param(
+ function("prod"),
+ marks=pytest.mark.xfail(reason="not implemented by pint"),
+ ),
+ function("std"),
+ function("var"),
+ function("cumsum"),
+ pytest.param(
+ function("cumprod"),
+ marks=pytest.mark.xfail(
+ reason="pint does not support cumprod on non-dimensionless yet"
+ ),
+ ),
+ pytest.param(
+ method("all"), marks=pytest.mark.xfail(reason="not implemented by pint")
+ ),
+ pytest.param(
+ method("any"), marks=pytest.mark.xfail(reason="not implemented by pint")
+ ),
+ method("argmax"),
+ method("argmin"),
+ method("max"),
+ method("min"),
+ method("mean"),
+ method("median"),
+ pytest.param(
+ method("sum"),
+ marks=pytest.mark.xfail(
+ reason="np.result_type not implemented by pint"
+ ),
+ ),
+ pytest.param(
+ method("prod"),
+ marks=pytest.mark.xfail(reason="not implemented by pint"),
+ ),
+ method("std"),
+ method("var"),
+ method("cumsum"),
+ pytest.param(
+ method("cumprod"),
+ marks=pytest.mark.xfail(
+ reason="pint does not support cumprod on non-dimensionless yet"
+ ),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_aggregation(self, func, dtype):
+ unit_a = unit_registry.Pa
+ unit_b = unit_registry.kg / unit_registry.m ** 3
+ a = xr.DataArray(data=np.linspace(0, 1, 10).astype(dtype) * unit_a, dims="x")
+ b = xr.DataArray(data=np.linspace(-1, 0, 10).astype(dtype) * unit_b, dims="x")
+ x = xr.DataArray(data=np.arange(10).astype(dtype) * unit_registry.m, dims="x")
+ y = xr.DataArray(
+ data=np.arange(10, 20).astype(dtype) * unit_registry.s, dims="x"
+ )
+
+ ds = xr.Dataset(data_vars={"a": a, "b": b}, coords={"x": x, "y": y})
+
+ result = func(ds)
+ expected = attach_units(
+ func(strip_units(ds)),
+ {"a": array_extract_units(func(a)), "b": array_extract_units(func(b))},
+ )
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.parametrize("property", ("imag", "real"))
+ def test_numpy_properties(self, property, dtype):
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(
+ data=np.linspace(0, 1, 10) * unit_registry.Pa, dims="x"
+ ),
+ "b": xr.DataArray(
+ data=np.linspace(-1, 0, 15) * unit_registry.Pa, dims="y"
+ ),
+ },
+ coords={
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(15) * unit_registry.s,
+ },
+ )
+ units = extract_units(ds)
+
+ result = getattr(ds, property)
+ expected = attach_units(getattr(strip_units(ds), property), units)
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("astype", float),
+ method("conj"),
+ method("argsort"),
+ method("conjugate"),
+ method("round"),
+ pytest.param(
+ method("rank", dim="x"),
+ marks=pytest.mark.xfail(reason="pint does not implement rank yet"),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_numpy_methods(self, func, dtype):
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(
+ data=np.linspace(1, -1, 10) * unit_registry.Pa, dims="x"
+ ),
+ "b": xr.DataArray(
+ data=np.linspace(-1, 1, 15) * unit_registry.Pa, dims="y"
+ ),
+ },
+ coords={
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(15) * unit_registry.s,
+ },
+ )
+ units = {
+ "a": array_extract_units(func(ds.a)),
+ "b": array_extract_units(func(ds.b)),
+ "x": unit_registry.m,
+ "y": unit_registry.s,
+ }
+
+ result = func(ds)
+ expected = attach_units(func(strip_units(ds)), units)
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.parametrize("func", (method("clip", min=3, max=8),), ids=repr)
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_numpy_methods_with_args(self, func, unit, error, dtype):
+ data_unit = unit_registry.m
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=np.arange(10) * data_unit, dims="x"),
+ "b": xr.DataArray(data=np.arange(15) * data_unit, dims="y"),
+ },
+ coords={
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(15) * unit_registry.s,
+ },
+ )
+ units = extract_units(ds)
+
+ def strip(value):
+ return (
+ value.magnitude if isinstance(value, unit_registry.Quantity) else value
+ )
+
+ def convert(value, to):
+ if isinstance(value, unit_registry.Quantity) and value.check(to):
+ return value.to(to)
+
+ return value
+
+ scalar_types = (int, float)
+ kwargs = {
+ key: (value * unit if isinstance(value, scalar_types) else value)
+ for key, value in func.kwargs.items()
+ }
+
+ stripped_kwargs = {
+ key: strip(convert(value, data_unit)) for key, value in kwargs.items()
+ }
+
+ if error is not None:
+ with pytest.raises(error):
+ func(ds, **kwargs)
+
+ return
+
+ result = func(ds, **kwargs)
+ expected = attach_units(func(strip_units(ds), **stripped_kwargs), units)
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.parametrize(
+ "func", (method("isnull"), method("notnull"), method("count")), ids=repr
+ )
+ def test_missing_value_detection(self, func, dtype):
+ array1 = (
+ np.array(
+ [
+ [1.4, 2.3, np.nan, 7.2],
+ [np.nan, 9.7, np.nan, np.nan],
+ [2.1, np.nan, np.nan, 4.6],
+ [9.9, np.nan, 7.2, 9.1],
+ ]
+ )
+ * unit_registry.degK
+ )
+ array2 = (
+ np.array(
+ [
+ [np.nan, 5.7, 12.0, 7.2],
+ [np.nan, 12.4, np.nan, 4.2],
+ [9.8, np.nan, 4.6, 1.4],
+ [7.2, np.nan, 6.3, np.nan],
+ [8.4, 3.9, np.nan, np.nan],
+ ]
+ )
+ * unit_registry.Pa
+ )
+
+ x = np.arange(array1.shape[0]) * unit_registry.m
+ y = np.arange(array1.shape[1]) * unit_registry.m
+ z = np.arange(array2.shape[0]) * unit_registry.m
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("z", "x")),
+ },
+ coords={"x": x, "y": y, "z": z},
+ )
+
+ expected = func(strip_units(ds))
+ result = func(ds)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="ffill and bfill lose the unit")
+ @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr)
+ def test_missing_value_filling(self, func, dtype):
+ array1 = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.degK
+ )
+ array2 = (
+ np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
+ * unit_registry.Pa
+ )
+
+ x = np.arange(len(array1))
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": x},
+ )
+
+ expected = attach_units(
+ func(strip_units(ds), dim="x"),
+ {"a": unit_registry.degK, "b": unit_registry.Pa},
+ )
+ result = func(ds, dim="x")
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="fillna drops the unit")
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(
+ 1,
+ DimensionalityError,
+ id="no_unit",
+ marks=pytest.mark.xfail(reason="blocked by the failing `where`"),
+ ),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "fill_value",
+ (
+ pytest.param(
+ -1,
+ id="python scalar",
+ marks=pytest.mark.xfail(
+ reason="python scalar cannot be converted using astype()"
+ ),
+ ),
+ pytest.param(np.array(-1), id="numpy scalar"),
+ pytest.param(np.array([-1]), id="numpy array"),
+ ),
+ )
+ def test_fillna(self, fill_value, unit, error, dtype):
+ array1 = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.m
+ )
+ array2 = (
+ np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
+ * unit_registry.m
+ )
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ }
+ )
+
+ if error is not None:
+ with pytest.raises(error):
+ ds.fillna(value=fill_value * unit)
+
+ return
+
+ result = ds.fillna(value=fill_value * unit)
+ expected = attach_units(
+ strip_units(ds).fillna(value=fill_value),
+ {"a": unit_registry.m, "b": unit_registry.m},
+ )
+
+ assert_equal_with_units(expected, result)
+
+ def test_dropna(self, dtype):
+ array1 = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.degK
+ )
+ array2 = (
+ np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
+ * unit_registry.Pa
+ )
+ x = np.arange(len(array1))
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": x},
+ )
+
+ expected = attach_units(
+ strip_units(ds).dropna(dim="x"),
+ {"a": unit_registry.degK, "b": unit_registry.Pa},
+ )
+ result = ds.dropna(dim="x")
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="pint does not implement `numpy.isin`")
+ @pytest.mark.parametrize(
+ "unit",
+ (
+ pytest.param(1, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, id="dimensionless"),
+ pytest.param(unit_registry.s, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, id="compatible_unit"),
+ pytest.param(unit_registry.m, id="same_unit"),
+ ),
+ )
+ def test_isin(self, unit, dtype):
+ array1 = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.m
+ )
+ array2 = (
+ np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
+ * unit_registry.m
+ )
+ x = np.arange(len(array1))
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": x},
+ )
+
+ raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype)
+ values = raw_values * unit
+
+ if (
+ isinstance(values, unit_registry.Quantity)
+ and values.check(unit_registry.m)
+ and unit != unit_registry.m
+ ):
+ raw_values = values.to(unit_registry.m).magnitude
+
+ expected = strip_units(ds).isin(raw_values)
+ if not isinstance(values, unit_registry.Quantity) or not values.check(
+ unit_registry.m
+ ):
+ expected.a[:] = False
+ expected.b[:] = False
+ result = ds.isin(values)
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.parametrize(
+ "variant",
+ (
+ pytest.param(
+ "masking",
+ marks=pytest.mark.xfail(
+ reason="np.result_type not implemented by quantity"
+ ),
+ ),
+ pytest.param(
+ "replacing_scalar",
+ marks=pytest.mark.xfail(
+ reason="python scalar not convertible using astype"
+ ),
+ ),
+ pytest.param(
+ "replacing_array",
+ marks=pytest.mark.xfail(
+ reason="replacing using an array drops the units"
+ ),
+ ),
+ pytest.param(
+ "dropping",
+ marks=pytest.mark.xfail(reason="nan not compatible with quantity"),
+ ),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="same_unit"),
+ ),
+ )
+ def test_where(self, variant, unit, error, dtype):
+ def _strip_units(mapping):
+ return {key: array_strip_units(value) for key, value in mapping.items()}
+
+ original_unit = unit_registry.m
+ array1 = np.linspace(0, 1, 10).astype(dtype) * original_unit
+ array2 = np.linspace(-1, 0, 10).astype(dtype) * original_unit
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": np.arange(len(array1))},
+ )
+
+ condition = ds < 0.5 * original_unit
+ other = np.linspace(-2, -1, 10).astype(dtype) * unit
+ variant_kwargs = {
+ "masking": {"cond": condition},
+ "replacing_scalar": {"cond": condition, "other": -1 * unit},
+ "replacing_array": {"cond": condition, "other": other},
+ "dropping": {"cond": condition, "drop": True},
+ }
+ kwargs = variant_kwargs.get(variant)
+ kwargs_without_units = _strip_units(kwargs)
+
+ if variant not in ("masking", "dropping") and error is not None:
+ with pytest.raises(error):
+ ds.where(**kwargs)
+
+ return
+
+ expected = attach_units(
+ strip_units(ds).where(**kwargs_without_units),
+ {"a": original_unit, "b": original_unit},
+ )
+ result = ds.where(**kwargs)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="interpolate strips units")
+ def test_interpolate_na(self, dtype):
+ array1 = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.degK
+ )
+ array2 = (
+ np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
+ * unit_registry.Pa
+ )
+ x = np.arange(len(array1))
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": x},
+ )
+
+ expected = attach_units(
+ strip_units(ds).interpolate_na(dim="x"),
+ {"a": unit_registry.degK, "b": unit_registry.Pa},
+ )
+ result = ds.interpolate_na(dim="x")
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="uses Dataset.where, which currently fails")
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="same_unit"),
+ ),
+ )
+ def test_combine_first(self, unit, error, dtype):
+ array1 = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.degK
+ )
+ array2 = (
+ np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
+ * unit_registry.Pa
+ )
+ x = np.arange(len(array1))
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": x},
+ )
+ other_array1 = np.ones_like(array1) * unit
+ other_array2 = -1 * np.ones_like(array2) * unit
+ other = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=other_array1, dims="x"),
+ "b": xr.DataArray(data=other_array2, dims="x"),
+ },
+ coords={"x": np.arange(array1.shape[0])},
+ )
+
+ if error is not None:
+ with pytest.raises(error):
+ ds.combine_first(other)
+
+ return
+
+ expected = attach_units(
+ strip_units(ds).combine_first(strip_units(other)),
+ {"a": unit_registry.m, "b": unit_registry.m},
+ )
+ result = ds.combine_first(other)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "unit",
+ (
+ pytest.param(1, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, id="dimensionless"),
+ pytest.param(unit_registry.s, id="incompatible_unit"),
+ pytest.param(
+ unit_registry.cm,
+ id="compatible_unit",
+ marks=pytest.mark.xfail(reason="identical does not check units yet"),
+ ),
+ pytest.param(unit_registry.m, id="identical_unit"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "variation",
+ (
+ "data",
+ pytest.param(
+ "dims", marks=pytest.mark.xfail(reason="units in indexes not supported")
+ ),
+ "coords",
+ ),
+ )
+ @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr)
+ def test_comparisons(self, func, variation, unit, dtype):
+ array1 = np.linspace(0, 5, 10).astype(dtype)
+ array2 = np.linspace(-5, 0, 10).astype(dtype)
+
+ coord = np.arange(len(array1)).astype(dtype)
+
+ original_unit = unit_registry.m
+ quantity1 = array1 * original_unit
+ quantity2 = array2 * original_unit
+ x = coord * original_unit
+ y = coord * original_unit
+
+ units = {
+ "data": (unit, original_unit, original_unit),
+ "dims": (original_unit, unit, original_unit),
+ "coords": (original_unit, original_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = units.get(variation)
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=quantity1, dims="x"),
+ "b": xr.DataArray(data=quantity2, dims="x"),
+ },
+ coords={"x": x, "y": ("x", y)},
+ )
+
+ other = attach_units(
+ strip_units(ds),
+ {
+ "a": (data_unit, original_unit if quantity1.check(data_unit) else None),
+ "b": (data_unit, original_unit if quantity2.check(data_unit) else None),
+ "x": (dim_unit, original_unit if x.check(dim_unit) else None),
+ "y": (coord_unit, original_unit if y.check(coord_unit) else None),
+ },
+ )
+
+ # TODO: test dim coord once indexes leave units intact
+ # also, express this in terms of calls on the raw data array
+ # and then check the units
+ equal_arrays = (
+ np.all(ds.a.data == other.a.data)
+ and np.all(ds.b.data == other.b.data)
+ and (np.all(x == other.x.data) or True) # dims can't be checked yet
+ and np.all(y == other.y.data)
+ )
+ equal_units = (
+ data_unit == original_unit
+ and coord_unit == original_unit
+ and dim_unit == original_unit
+ )
+ expected = equal_arrays and (func.name != "identical" or equal_units)
+ result = func(ds, other)
+
+ assert expected == result
+
+ @pytest.mark.parametrize(
+ "unit",
+ (
+ pytest.param(1, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, id="dimensionless"),
+ pytest.param(unit_registry.s, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, id="compatible_unit"),
+ pytest.param(unit_registry.m, id="identical_unit"),
+ ),
+ )
+ def test_broadcast_equals(self, unit, dtype):
+ left_array1 = np.ones(shape=(2, 3), dtype=dtype) * unit_registry.m
+ left_array2 = np.zeros(shape=(2, 6), dtype=dtype) * unit_registry.m
+
+ right_array1 = array_attach_units(
+ np.ones(shape=(2,), dtype=dtype),
+ unit,
+ convert_from=unit_registry.m if left_array1.check(unit) else None,
+ )
+ right_array2 = array_attach_units(
+ np.ones(shape=(2,), dtype=dtype),
+ unit,
+ convert_from=unit_registry.m if left_array2.check(unit) else None,
+ )
+
+ left = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=left_array1, dims=("x", "y")),
+ "b": xr.DataArray(data=left_array2, dims=("x", "z")),
+ }
+ )
+ right = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=right_array1, dims="x"),
+ "b": xr.DataArray(data=right_array2, dims="x"),
+ }
+ )
+
+ expected = np.all(left_array1 == right_array1[:, None]) and np.all(
+ left_array2 == right_array2[:, None]
+ )
+ result = left.broadcast_equals(right)
+
+ assert expected == result
+
+ @pytest.mark.parametrize(
+ "func",
+ (method("unstack"), method("reset_index", "v"), method("reorder_levels")),
+ ids=repr,
+ )
+ def test_stacking_stacked(self, func, dtype):
+ array1 = (
+ np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
+ )
+ array2 = (
+ np.linspace(-10, 0, 5 * 10 * 15).reshape(5, 10, 15).astype(dtype)
+ * unit_registry.m
+ )
+
+ x = np.arange(array1.shape[0])
+ y = np.arange(array1.shape[1])
+ z = np.arange(array2.shape[2])
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "y", "z")),
+ },
+ coords={"x": x, "y": y, "z": z},
+ )
+
+ stacked = ds.stack(v=("x", "y"))
+
+ expected = attach_units(
+ func(strip_units(stacked)), {"a": unit_registry.m, "b": unit_registry.m}
+ )
+ result = func(stacked)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="tries to subscript scalar quantities")
+ def test_to_stacked_array(self, dtype):
+ labels = np.arange(5).astype(dtype) * unit_registry.s
+ arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels}
+
+ ds = xr.Dataset(
+ data_vars={
+ name: xr.DataArray(data=array, dims="x")
+ for name, array in arrays.items()
+ }
+ )
+
+ func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"])
+
+ result = func(ds).rename(None)
+ expected = attach_units(
+ func(strip_units(ds)).rename(None),
+ {None: unit_registry.m, "y": unit_registry.s},
+ )
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("transpose", "y", "x", "z1", "z2"),
+ method("stack", a=("x", "y")),
+ method("set_index", x="x2"),
+ pytest.param(
+ method("shift", x=2), marks=pytest.mark.xfail(reason="sets all to nan")
+ ),
+ pytest.param(
+ method("roll", x=2, roll_coords=False),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ method("sortby", "x2"),
+ ),
+ ids=repr,
+ )
+ def test_stacking_reordering(self, func, dtype):
+ array1 = (
+ np.linspace(0, 10, 2 * 5 * 10).reshape(2, 5, 10).astype(dtype)
+ * unit_registry.Pa
+ )
+ array2 = (
+ np.linspace(0, 10, 2 * 5 * 15).reshape(2, 5, 15).astype(dtype)
+ * unit_registry.degK
+ )
+
+ x = np.arange(array1.shape[0])
+ y = np.arange(array1.shape[1])
+ z1 = np.arange(array1.shape[2])
+ z2 = np.arange(array2.shape[2])
+
+ x2 = np.linspace(0, 1, array1.shape[0])[::-1]
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y", "z1")),
+ "b": xr.DataArray(data=array2, dims=("x", "y", "z2")),
+ },
+ coords={"x": x, "y": y, "z1": z1, "z2": z2, "x2": ("x", x2)},
+ )
+
+ expected = attach_units(
+ func(strip_units(ds)), {"a": unit_registry.Pa, "b": unit_registry.degK}
+ )
+ result = func(ds)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="indexes strip units")
+ @pytest.mark.parametrize(
+ "indices",
+ (
+ pytest.param(4, id="single index"),
+ pytest.param([5, 2, 9, 1], id="multiple indices"),
+ ),
+ )
+ def test_isel(self, indices, dtype):
+ array1 = np.arange(10).astype(dtype) * unit_registry.s
+ array2 = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa
+
+ x = np.arange(len(array1)) * unit_registry.m
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": x},
+ )
+
+ expected = attach_units(
+ strip_units(ds).isel(x=indices),
+ {"a": unit_registry.s, "b": unit_registry.Pa, "x": unit_registry.m},
+ )
+ result = ds.isel(x=indices)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(
+ reason="xarray does not support duck arrays in dimension coordinates"
+ )
+ @pytest.mark.parametrize(
+ "values",
+ (
+ pytest.param(12, id="single_value"),
+ pytest.param([10, 5, 13], id="list_of_values"),
+ pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "units,error",
+ (
+ pytest.param(1, KeyError, id="no_units"),
+ pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"),
+ pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"),
+ pytest.param(unit_registry.ms, KeyError, id="compatible_unit"),
+ pytest.param(unit_registry.s, None, id="same_unit"),
+ ),
+ )
+ def test_sel(self, values, units, error, dtype):
+ array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK
+ array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa
+ x = np.arange(len(array1)) * unit_registry.s
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": x},
+ )
+
+ values_with_units = values * units
+
+ if error is not None:
+ with pytest.raises(error):
+ ds.sel(x=values_with_units)
+
+ return
+
+ expected = attach_units(
+ strip_units(ds).sel(x=values),
+ {"a": unit_registry.degK, "b": unit_registry.Pa, "x": unit_registry.s},
+ )
+ result = ds.sel(x=values_with_units)
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(
+ reason="xarray does not support duck arrays in dimension coordinates"
+ )
+ @pytest.mark.parametrize(
+ "values",
+ (
+ pytest.param(12, id="single value"),
+ pytest.param([10, 5, 13], id="list of multiple values"),
+ pytest.param(np.array([9, 3, 7, 12]), id="array of multiple values"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "units,error",
+ (
+ pytest.param(1, KeyError, id="no_units"),
+ pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"),
+ pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"),
+ pytest.param(unit_registry.ms, KeyError, id="compatible_unit"),
+ pytest.param(unit_registry.s, None, id="same_unit"),
+ ),
+ )
+ def test_loc(self, values, units, error, dtype):
+ array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK
+ array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa
+ x = np.arange(len(array1)) * unit_registry.s
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims="x"),
+ "b": xr.DataArray(data=array2, dims="x"),
+ },
+ coords={"x": x},
+ )
+
+ values_with_units = values * units
+
+ if error is not None:
+ with pytest.raises(error):
+ ds.loc[{"x": values_with_units}]
+
+ return
+
+ expected = attach_units(
+ strip_units(ds).loc[{"x": values}],
+ {"a": unit_registry.degK, "b": unit_registry.Pa, "x": unit_registry.s},
+ )
+ result = ds.loc[{"x": values_with_units}]
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(
+ reason="indexes strip units and head / tail / thin only support integers"
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("head", x=7, y=3, z=6),
+ method("tail", x=7, y=3, z=6),
+ method("thin", x=7, y=3, z=6),
+ ),
+ ids=repr,
+ )
+ def test_head_tail_thin(self, func, unit, error, dtype):
+ array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
+ array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa
+
+ coords = {
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(5) * unit_registry.m,
+ "z": np.arange(8) * unit_registry.m,
+ }
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "z")),
+ },
+ coords=coords,
+ )
+
+ kwargs = {name: value * unit for name, value in func.kwargs.items()}
+
+ if error is not None:
+ with pytest.raises(error):
+ func(ds, **kwargs)
+
+ return
+
+ expected = attach_units(func(strip_units(ds)), extract_units(ds))
+ result = func(ds, **kwargs)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "shape",
+ (
+ pytest.param((10, 20), id="nothing squeezable"),
+ pytest.param((10, 20, 1), id="last dimension squeezable"),
+ pytest.param((10, 1, 20), id="middle dimension squeezable"),
+ pytest.param((1, 10, 20), id="first dimension squeezable"),
+ pytest.param((1, 10, 1, 20), id="first and last dimension squeezable"),
+ ),
+ )
+ def test_squeeze(self, shape, dtype):
+ names = "xyzt"
+ coords = {
+ name: np.arange(length).astype(dtype)
+ * (unit_registry.m if name != "t" else unit_registry.s)
+ for name, length in zip(names, shape)
+ }
+ array1 = (
+ np.linspace(0, 1, 10 * 20).astype(dtype).reshape(shape) * unit_registry.degK
+ )
+ array2 = (
+ np.linspace(1, 2, 10 * 20).astype(dtype).reshape(shape) * unit_registry.Pa
+ )
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=tuple(names[: len(shape)])),
+ "b": xr.DataArray(data=array2, dims=tuple(names[: len(shape)])),
+ },
+ coords=coords,
+ )
+ units = extract_units(ds)
+
+ expected = attach_units(strip_units(ds).squeeze(), units)
+
+ result = ds.squeeze()
+ assert_equal_with_units(result, expected)
+
+ # try squeezing the dimensions separately
+ names = tuple(dim for dim, coord in coords.items() if len(coord) == 1)
+ for name in names:
+ expected = attach_units(strip_units(ds).squeeze(dim=name), units)
+ result = ds.squeeze(dim=name)
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.xfail(reason="ignores units")
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_interp(self, unit, error):
+ array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
+ array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa
+
+ coords = {
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(5) * unit_registry.m,
+ "z": np.arange(8) * unit_registry.s,
+ }
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "z")),
+ },
+ coords=coords,
+ )
+
+ new_coords = (np.arange(10) + 0.5) * unit
+
+ if error is not None:
+ with pytest.raises(error):
+ ds.interp(x=new_coords)
+
+ return
+
+ expected = attach_units(
+ strip_units(ds).interp(x=strip_units(new_coords)), extract_units(ds)
+ )
+ result = ds.interp(x=new_coords)
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.xfail(reason="ignores units")
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_interp_like(self, unit, error, dtype):
+ array1 = (
+ np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
+ )
+ array2 = (
+ np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
+ )
+
+ coords = {
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(5) * unit_registry.m,
+ "z": np.arange(8) * unit_registry.m,
+ }
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "z")),
+ },
+ coords=coords,
+ )
+
+ other = xr.Dataset(
+ data_vars={
+ "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")),
+ "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")),
+ },
+ coords={
+ "x": (np.arange(20) + 0.3) * unit,
+ "y": (np.arange(10) - 0.2) * unit,
+ "z": (np.arange(15) + 0.4) * unit,
+ },
+ )
+
+ if error is not None:
+ with pytest.raises(error):
+ ds.interp_like(other)
+
+ return
+
+ expected = attach_units(
+ strip_units(ds).interp_like(strip_units(other)), extract_units(ds)
+ )
+ result = ds.interp_like(other)
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.xfail(
+ reason="pint does not implement np.result_type in __array_function__ yet"
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_reindex(self, unit, error):
+ array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
+ array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa
+
+ coords = {
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(5) * unit_registry.m,
+ "z": np.arange(8) * unit_registry.s,
+ }
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "z")),
+ },
+ coords=coords,
+ )
+
+ new_coords = (np.arange(10) + 0.5) * unit
+
+ if error is not None:
+ with pytest.raises(error):
+ ds.interp(x=new_coords)
+
+ return
+
+ expected = attach_units(
+ strip_units(ds).reindex(x=strip_units(new_coords)), extract_units(ds)
+ )
+ result = ds.reindex(x=new_coords)
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.xfail(
+ reason="pint does not implement np.result_type in __array_function__ yet"
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_reindex_like(self, unit, error, dtype):
+ array1 = (
+ np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
+ )
+ array2 = (
+ np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
+ )
+
+ coords = {
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(5) * unit_registry.m,
+ "z": np.arange(8) * unit_registry.m,
+ }
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "z")),
+ },
+ coords=coords,
+ )
+
+ other = xr.Dataset(
+ data_vars={
+ "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")),
+ "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")),
+ },
+ coords={
+ "x": (np.arange(20) + 0.3) * unit,
+ "y": (np.arange(10) - 0.2) * unit,
+ "z": (np.arange(15) + 0.4) * unit,
+ },
+ )
+
+ if error is not None:
+ with pytest.raises(error):
+ ds.reindex_like(other)
+
+ return
+
+ expected = attach_units(
+ strip_units(ds).reindex_like(strip_units(other)), extract_units(ds)
+ )
+ result = ds.reindex_like(other)
+
+ assert_equal_with_units(result, expected)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("diff", dim="x"),
+ method("differentiate", coord="x"),
+ method("integrate", coord="x"),
+ pytest.param(
+ method("quantile", q=[0.25, 0.75]),
+ marks=pytest.mark.xfail(
+ reason="pint does not implement nanpercentile yet"
+ ),
+ ),
+ pytest.param(
+ method("reduce", func=np.sum, dim="x"),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ pytest.param(
+ method("apply", np.fabs),
+ marks=pytest.mark.xfail(reason="fabs strips units"),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_computation(self, func, dtype):
+ array1 = (
+ np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
+ )
+ array2 = (
+ np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
+ )
+ x = np.arange(10) * unit_registry.m
+ y = np.arange(5) * unit_registry.m
+ z = np.arange(8) * unit_registry.m
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "z")),
+ },
+ coords={"x": x, "y": y, "z": z},
+ )
+
+ units = extract_units(ds)
+
+ expected = attach_units(func(strip_units(ds)), units)
+ result = func(ds)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(
+ method("groupby", "x"), marks=pytest.mark.xfail(reason="strips units")
+ ),
+ pytest.param(
+ method("groupby_bins", "x", bins=4),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ method("coarsen", x=2),
+ pytest.param(
+ method("rolling", x=3), marks=pytest.mark.xfail(reason="strips units")
+ ),
+ pytest.param(
+ method("rolling_exp", x=3),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_computation_objects(self, func, dtype):
+ array1 = (
+ np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
+ )
+ array2 = (
+ np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype)
+ * unit_registry.Pa
+ )
+ x = np.arange(10) * unit_registry.m
+ y = np.arange(5) * unit_registry.m
+ z = np.arange(8) * unit_registry.m
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "y", "z")),
+ },
+ coords={"x": x, "y": y, "z": z},
+ )
+ units = extract_units(ds)
+
+ args = [] if func.name != "groupby" else ["y"]
+ reduce_func = method("mean", *args)
+ expected = attach_units(reduce_func(func(strip_units(ds))), units)
+ result = reduce_func(func(ds))
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="strips units")
+ def test_resample(self, dtype):
+ array1 = (
+ np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
+ )
+ array2 = (
+ np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
+ )
+ t = pd.date_range("10-09-2010", periods=array1.shape[0], freq="1y")
+ y = np.arange(5) * unit_registry.m
+ z = np.arange(8) * unit_registry.m
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("time", "y")),
+ "b": xr.DataArray(data=array2, dims=("time", "z")),
+ },
+ coords={"time": t, "y": y, "z": z},
+ )
+ units = extract_units(ds)
+
+ func = method("resample", time="6m")
+
+ expected = attach_units(func(strip_units(ds)).mean(), units)
+ result = func(ds).mean()
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(
+ method("assign", c=lambda ds: 10 * ds.b),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ pytest.param(
+ method("assign_coords", v=("x", np.arange(10) * unit_registry.s)),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ pytest.param(method("first")),
+ pytest.param(method("last")),
+ pytest.param(
+ method("quantile", q=[0.25, 0.5, 0.75], dim="x"),
+ marks=pytest.mark.xfail(
+ reason="dataset groupby does not implement quantile"
+ ),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_grouped_operations(self, func, dtype):
+ array1 = (
+ np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
+ )
+ array2 = (
+ np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype)
+ * unit_registry.Pa
+ )
+ x = np.arange(10) * unit_registry.m
+ y = np.arange(5) * unit_registry.m
+ z = np.arange(8) * unit_registry.m
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "y", "z")),
+ },
+ coords={"x": x, "y": y, "z": z},
+ )
+ units = extract_units(ds)
+ units.update({"c": unit_registry.Pa, "v": unit_registry.s})
+
+ stripped_kwargs = {
+ name: strip_units(value) for name, value in func.kwargs.items()
+ }
+ expected = attach_units(
+ func(strip_units(ds).groupby("y"), **stripped_kwargs), units
+ )
+ result = func(ds.groupby("y"))
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("pipe", lambda ds: ds * 10),
+ method("assign", d=lambda ds: ds.b * 10),
+ method("assign_coords", y2=("y", np.arange(5) * unit_registry.mm)),
+ method("assign_attrs", attr1="value"),
+ method("rename", x2="x_mm"),
+ method("rename_vars", c="temperature"),
+ method("rename_dims", x="offset_x"),
+ method("swap_dims", {"x": "x2"}),
+ method("expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1),
+ method("drop", labels="x"),
+ method("drop_dims", "z"),
+ method("set_coords", names="c"),
+ method("reset_coords", names="x2"),
+ method("copy"),
+ ),
+ ids=repr,
+ )
+ def test_content_manipulation(self, func, dtype):
+ array1 = (
+ np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype)
+ * unit_registry.m ** 3
+ )
+ array2 = (
+ np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype)
+ * unit_registry.Pa
+ )
+ array3 = np.linspace(0, 10, 10).astype(dtype) * unit_registry.degK
+
+ x = np.arange(10) * unit_registry.m
+ x2 = x.to(unit_registry.mm)
+ y = np.arange(5) * unit_registry.m
+ z = np.arange(8) * unit_registry.m
+
+ ds = xr.Dataset(
+ data_vars={
+ "a": xr.DataArray(data=array1, dims=("x", "y")),
+ "b": xr.DataArray(data=array2, dims=("x", "y", "z")),
+ "c": xr.DataArray(data=array3, dims="x"),
+ },
+ coords={"x": x, "y": y, "z": z, "x2": ("x", x2)},
+ )
+ units = extract_units(ds)
+ units.update(
+ {
+ "y2": unit_registry.mm,
+ "x_mm": unit_registry.mm,
+ "offset_x": unit_registry.m,
+ "d": unit_registry.Pa,
+ "temperature": unit_registry.degK,
+ }
+ )
+
+ stripped_kwargs = {
+ key: strip_units(value) for key, value in func.kwargs.items()
+ }
+ expected = attach_units(func(strip_units(ds), **stripped_kwargs), units)
+ result = func(ds)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="blocked by reindex")
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, xr.MergeError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, xr.MergeError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, xr.MergeError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, xr.MergeError, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ @pytest.mark.parametrize("variant", ("data", "dims", "coords"))
+ def test_merge(self, variant, unit, error, dtype):
+ original_data_unit = unit_registry.m
+ original_dim_unit = unit_registry.m
+ original_coord_unit = unit_registry.m
+
+ variants = {
+ "data": (unit, original_dim_unit, original_coord_unit),
+ "dims": (original_data_unit, unit, original_coord_unit),
+ "coords": (original_data_unit, original_dim_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = variants.get(variant)
+
+ left_array = np.arange(10).astype(dtype) * original_data_unit
+ right_array = np.arange(-5, 5).astype(dtype) * data_unit
+
+ left_dim = np.arange(10, 20) * original_dim_unit
+ right_dim = np.arange(5, 15) * dim_unit
+
+ left_coord = np.arange(-10, 0) * original_coord_unit
+ right_coord = np.arange(-15, -5) * coord_unit
+
+ left = xr.Dataset(
+ data_vars={"a": ("x", left_array)},
+ coords={"x": left_dim, "y": ("x", left_coord)},
+ )
+ right = xr.Dataset(
+ data_vars={"a": ("x", right_array)},
+ coords={"x": right_dim, "y": ("x", right_coord)},
+ )
+
+ units = extract_units(left)
+
+ if error is not None:
+ with pytest.raises(error):
+ left.merge(right)
+
+ return
+
+ converted = convert_units(right, units)
+ expected = attach_units(strip_units(left).merge(strip_units(converted)), units)
+ result = left.merge(right)
+
+ assert_equal_with_units(expected, result)