diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0906058469d..4423e7e4255 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -104,7 +104,7 @@ Internal Changes ~~~~~~~~~~~~~~~~ - Added integration tests against `pint `_. - (:pull:`3238`) by `Justus Magin `_. + (:pull:`3238`, :pull:`3447`) by `Justus Magin `_. .. note:: diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 80063f8b4bc..8eed1f0dbe3 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -123,14 +123,19 @@ def extract_units(obj): def strip_units(obj): if isinstance(obj, xr.Dataset): - data_vars = {name: strip_units(value) for name, value in obj.data_vars.items()} - coords = {name: strip_units(value) for name, value in obj.coords.items()} + data_vars = { + strip_units(name): strip_units(value) + for name, value in obj.data_vars.items() + } + coords = { + strip_units(name): strip_units(value) for name, value in obj.coords.items() + } new_obj = xr.Dataset(data_vars=data_vars, coords=coords) elif isinstance(obj, xr.DataArray): data = array_strip_units(obj.data) coords = { - name: ( + strip_units(name): ( (value.dims, array_strip_units(value.data)) if isinstance(value.data, Quantity) else value # to preserve multiindexes @@ -138,9 +143,13 @@ def strip_units(obj): for name, value in obj.coords.items() } - new_obj = xr.DataArray(name=obj.name, data=data, coords=coords, dims=obj.dims) - elif hasattr(obj, "magnitude"): + new_obj = xr.DataArray( + name=strip_units(obj.name), data=data, coords=coords, dims=obj.dims + ) + elif isinstance(obj, unit_registry.Quantity): new_obj = obj.magnitude + elif isinstance(obj, (list, tuple)): + return type(obj)(strip_units(elem) for elem in obj) else: new_obj = obj @@ -191,6 +200,38 @@ def attach_units(obj, units): return new_obj +def convert_units(obj, to): + if isinstance(obj, xr.Dataset): + data_vars = { + name: convert_units(array, to) for name, array in obj.data_vars.items() + } + coords = {name: convert_units(array, to) for name, array in obj.coords.items()} + + new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) + elif isinstance(obj, xr.DataArray): + name = obj.name + + new_units = ( + to.get(name, None) or to.get("data", None) or to.get(None, None) or 1 + ) + data = convert_units(obj.data, {None: new_units}) + + coords = { + name: (array.dims, convert_units(array.data, to)) + for name, array in obj.coords.items() + if name != obj.name + } + + new_obj = xr.DataArray(name=name, data=data, coords=coords, attrs=obj.attrs) + elif isinstance(obj, unit_registry.Quantity): + units = to.get(None) + new_obj = obj.to(units) if units is not None else obj + else: + new_obj = obj + + return new_obj + + def assert_equal_with_units(a, b): # works like xr.testing.assert_equal, but also explicitly checks units # so, it is more like assert_identical @@ -1632,3 +1673,1696 @@ def test_grouped_operations(self, func, dtype): result = func(data_array.groupby("y")) assert_equal_with_units(expected, result) + + +class TestDataset: + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="same_unit"), + ), + ) + @pytest.mark.parametrize( + "shared", + ( + "nothing", + pytest.param( + "dims", + marks=pytest.mark.xfail(reason="reindex does not work with pint yet"), + ), + pytest.param( + "coords", + marks=pytest.mark.xfail(reason="reindex does not work with pint yet"), + ), + ), + ) + def test_init(self, shared, unit, error, dtype): + original_unit = unit_registry.m + scaled_unit = unit_registry.mm + + a = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa + b = np.linspace(-1, 0, 12).astype(dtype) * unit_registry.Pa + + raw_x = np.arange(a.shape[0]) + x = raw_x * original_unit + x2 = x.to(scaled_unit) + + raw_y = np.arange(b.shape[0]) + y = raw_y * unit + y_units = unit if isinstance(y, unit_registry.Quantity) else None + if isinstance(y, unit_registry.Quantity): + if y.check(scaled_unit): + y2 = y.to(scaled_unit) + else: + y2 = y * 1000 + y2_units = y2.units + else: + y2 = y * 1000 + y2_units = None + + variants = { + "nothing": ({"x": x, "x2": ("x", x2)}, {"y": y, "y2": ("y", y2)}), + "dims": ( + {"x": x, "x2": ("x", strip_units(x2))}, + {"x": y, "y2": ("x", strip_units(y2))}, + ), + "coords": ({"x": raw_x, "y": ("x", x2)}, {"x": raw_y, "y": ("x", y2)}), + } + coords_a, coords_b = variants.get(shared) + + dims_a, dims_b = ("x", "y") if shared == "nothing" else ("x", "x") + + arr1 = xr.DataArray(data=a, coords=coords_a, dims=dims_a) + arr2 = xr.DataArray(data=b, coords=coords_b, dims=dims_b) + if error is not None and shared != "nothing": + with pytest.raises(error): + xr.Dataset(data_vars={"a": arr1, "b": arr2}) + + return + + result = xr.Dataset(data_vars={"a": arr1, "b": arr2}) + + expected_units = { + "a": a.units, + "b": b.units, + "x": x.units, + "x2": x2.units, + "y": y_units, + "y2": y2_units, + } + expected = attach_units( + xr.Dataset(data_vars={"a": strip_units(arr1), "b": strip_units(arr2)}), + expected_units, + ) + assert_equal_with_units(result, expected) + + @pytest.mark.parametrize( + "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr")) + ) + @pytest.mark.parametrize( + "variant", + ( + pytest.param( + "with_dims", + marks=pytest.mark.xfail(reason="units in indexes are not supported"), + ), + pytest.param("with_coords"), + pytest.param("without_coords"), + ), + ) + @pytest.mark.filterwarnings("error:::pint[.*]") + def test_repr(self, func, variant, dtype): + array1 = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.Pa + array2 = np.linspace(0, 1, 10, dtype=dtype) * unit_registry.degK + + x = np.arange(len(array1)) * unit_registry.s + y = x.to(unit_registry.ms) + + variants = { + "with_dims": {"x": x}, + "with_coords": {"y": ("x", y)}, + "without_coords": {}, + } + + data_array = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("x", array2)}, + coords=variants.get(variant), + ) + + # FIXME: this just checks that the repr does not raise + # warnings or errors, but does not check the result + func(data_array) + + @pytest.mark.parametrize( + "func", + ( + pytest.param( + function("all"), + marks=pytest.mark.xfail(reason="not implemented by pint"), + ), + pytest.param( + function("any"), + marks=pytest.mark.xfail(reason="not implemented by pint"), + ), + function("argmax"), + function("argmin"), + function("max"), + function("min"), + function("mean"), + pytest.param( + function("median"), + marks=pytest.mark.xfail( + reason="np.median does not work with dataset yet" + ), + ), + pytest.param( + function("sum"), + marks=pytest.mark.xfail( + reason="np.result_type not implemented by pint" + ), + ), + pytest.param( + function("prod"), + marks=pytest.mark.xfail(reason="not implemented by pint"), + ), + function("std"), + function("var"), + function("cumsum"), + pytest.param( + function("cumprod"), + marks=pytest.mark.xfail( + reason="pint does not support cumprod on non-dimensionless yet" + ), + ), + pytest.param( + method("all"), marks=pytest.mark.xfail(reason="not implemented by pint") + ), + pytest.param( + method("any"), marks=pytest.mark.xfail(reason="not implemented by pint") + ), + method("argmax"), + method("argmin"), + method("max"), + method("min"), + method("mean"), + method("median"), + pytest.param( + method("sum"), + marks=pytest.mark.xfail( + reason="np.result_type not implemented by pint" + ), + ), + pytest.param( + method("prod"), + marks=pytest.mark.xfail(reason="not implemented by pint"), + ), + method("std"), + method("var"), + method("cumsum"), + pytest.param( + method("cumprod"), + marks=pytest.mark.xfail( + reason="pint does not support cumprod on non-dimensionless yet" + ), + ), + ), + ids=repr, + ) + def test_aggregation(self, func, dtype): + unit_a = unit_registry.Pa + unit_b = unit_registry.kg / unit_registry.m ** 3 + a = xr.DataArray(data=np.linspace(0, 1, 10).astype(dtype) * unit_a, dims="x") + b = xr.DataArray(data=np.linspace(-1, 0, 10).astype(dtype) * unit_b, dims="x") + x = xr.DataArray(data=np.arange(10).astype(dtype) * unit_registry.m, dims="x") + y = xr.DataArray( + data=np.arange(10, 20).astype(dtype) * unit_registry.s, dims="x" + ) + + ds = xr.Dataset(data_vars={"a": a, "b": b}, coords={"x": x, "y": y}) + + result = func(ds) + expected = attach_units( + func(strip_units(ds)), + {"a": array_extract_units(func(a)), "b": array_extract_units(func(b))}, + ) + + assert_equal_with_units(result, expected) + + @pytest.mark.parametrize("property", ("imag", "real")) + def test_numpy_properties(self, property, dtype): + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray( + data=np.linspace(0, 1, 10) * unit_registry.Pa, dims="x" + ), + "b": xr.DataArray( + data=np.linspace(-1, 0, 15) * unit_registry.Pa, dims="y" + ), + }, + coords={ + "x": np.arange(10) * unit_registry.m, + "y": np.arange(15) * unit_registry.s, + }, + ) + units = extract_units(ds) + + result = getattr(ds, property) + expected = attach_units(getattr(strip_units(ds), property), units) + + assert_equal_with_units(result, expected) + + @pytest.mark.parametrize( + "func", + ( + method("astype", float), + method("conj"), + method("argsort"), + method("conjugate"), + method("round"), + pytest.param( + method("rank", dim="x"), + marks=pytest.mark.xfail(reason="pint does not implement rank yet"), + ), + ), + ids=repr, + ) + def test_numpy_methods(self, func, dtype): + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray( + data=np.linspace(1, -1, 10) * unit_registry.Pa, dims="x" + ), + "b": xr.DataArray( + data=np.linspace(-1, 1, 15) * unit_registry.Pa, dims="y" + ), + }, + coords={ + "x": np.arange(10) * unit_registry.m, + "y": np.arange(15) * unit_registry.s, + }, + ) + units = { + "a": array_extract_units(func(ds.a)), + "b": array_extract_units(func(ds.b)), + "x": unit_registry.m, + "y": unit_registry.s, + } + + result = func(ds) + expected = attach_units(func(strip_units(ds)), units) + + assert_equal_with_units(result, expected) + + @pytest.mark.parametrize("func", (method("clip", min=3, max=8),), ids=repr) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_numpy_methods_with_args(self, func, unit, error, dtype): + data_unit = unit_registry.m + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=np.arange(10) * data_unit, dims="x"), + "b": xr.DataArray(data=np.arange(15) * data_unit, dims="y"), + }, + coords={ + "x": np.arange(10) * unit_registry.m, + "y": np.arange(15) * unit_registry.s, + }, + ) + units = extract_units(ds) + + def strip(value): + return ( + value.magnitude if isinstance(value, unit_registry.Quantity) else value + ) + + def convert(value, to): + if isinstance(value, unit_registry.Quantity) and value.check(to): + return value.to(to) + + return value + + scalar_types = (int, float) + kwargs = { + key: (value * unit if isinstance(value, scalar_types) else value) + for key, value in func.kwargs.items() + } + + stripped_kwargs = { + key: strip(convert(value, data_unit)) for key, value in kwargs.items() + } + + if error is not None: + with pytest.raises(error): + func(ds, **kwargs) + + return + + result = func(ds, **kwargs) + expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) + + assert_equal_with_units(result, expected) + + @pytest.mark.parametrize( + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func, dtype): + array1 = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.degK + ) + array2 = ( + np.array( + [ + [np.nan, 5.7, 12.0, 7.2], + [np.nan, 12.4, np.nan, 4.2], + [9.8, np.nan, 4.6, 1.4], + [7.2, np.nan, 6.3, np.nan], + [8.4, 3.9, np.nan, np.nan], + ] + ) + * unit_registry.Pa + ) + + x = np.arange(array1.shape[0]) * unit_registry.m + y = np.arange(array1.shape[1]) * unit_registry.m + z = np.arange(array2.shape[0]) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("z", "x")), + }, + coords={"x": x, "y": y, "z": z}, + ) + + expected = func(strip_units(ds)) + result = func(ds) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="ffill and bfill lose the unit") + @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr) + def test_missing_value_filling(self, func, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.degK + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.Pa + ) + + x = np.arange(len(array1)) + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + expected = attach_units( + func(strip_units(ds), dim="x"), + {"a": unit_registry.degK, "b": unit_registry.Pa}, + ) + result = func(ds, dim="x") + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="fillna drops the unit") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail(reason="blocked by the failing `where`"), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "fill_value", + ( + pytest.param( + -1, + id="python scalar", + marks=pytest.mark.xfail( + reason="python scalar cannot be converted using astype()" + ), + ), + pytest.param(np.array(-1), id="numpy scalar"), + pytest.param(np.array([-1]), id="numpy array"), + ), + ) + def test_fillna(self, fill_value, unit, error, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.m + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.m + ) + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + } + ) + + if error is not None: + with pytest.raises(error): + ds.fillna(value=fill_value * unit) + + return + + result = ds.fillna(value=fill_value * unit) + expected = attach_units( + strip_units(ds).fillna(value=fill_value), + {"a": unit_registry.m, "b": unit_registry.m}, + ) + + assert_equal_with_units(expected, result) + + def test_dropna(self, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.degK + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.Pa + ) + x = np.arange(len(array1)) + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + expected = attach_units( + strip_units(ds).dropna(dim="x"), + {"a": unit_registry.degK, "b": unit_registry.Pa}, + ) + result = ds.dropna(dim="x") + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="pint does not implement `numpy.isin`") + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="same_unit"), + ), + ) + def test_isin(self, unit, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.m + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.m + ) + x = np.arange(len(array1)) + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype) + values = raw_values * unit + + if ( + isinstance(values, unit_registry.Quantity) + and values.check(unit_registry.m) + and unit != unit_registry.m + ): + raw_values = values.to(unit_registry.m).magnitude + + expected = strip_units(ds).isin(raw_values) + if not isinstance(values, unit_registry.Quantity) or not values.check( + unit_registry.m + ): + expected.a[:] = False + expected.b[:] = False + result = ds.isin(values) + + assert_equal_with_units(result, expected) + + @pytest.mark.parametrize( + "variant", + ( + pytest.param( + "masking", + marks=pytest.mark.xfail( + reason="np.result_type not implemented by quantity" + ), + ), + pytest.param( + "replacing_scalar", + marks=pytest.mark.xfail( + reason="python scalar not convertible using astype" + ), + ), + pytest.param( + "replacing_array", + marks=pytest.mark.xfail( + reason="replacing using an array drops the units" + ), + ), + pytest.param( + "dropping", + marks=pytest.mark.xfail(reason="nan not compatible with quantity"), + ), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="same_unit"), + ), + ) + def test_where(self, variant, unit, error, dtype): + def _strip_units(mapping): + return {key: array_strip_units(value) for key, value in mapping.items()} + + original_unit = unit_registry.m + array1 = np.linspace(0, 1, 10).astype(dtype) * original_unit + array2 = np.linspace(-1, 0, 10).astype(dtype) * original_unit + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": np.arange(len(array1))}, + ) + + condition = ds < 0.5 * original_unit + other = np.linspace(-2, -1, 10).astype(dtype) * unit + variant_kwargs = { + "masking": {"cond": condition}, + "replacing_scalar": {"cond": condition, "other": -1 * unit}, + "replacing_array": {"cond": condition, "other": other}, + "dropping": {"cond": condition, "drop": True}, + } + kwargs = variant_kwargs.get(variant) + kwargs_without_units = _strip_units(kwargs) + + if variant not in ("masking", "dropping") and error is not None: + with pytest.raises(error): + ds.where(**kwargs) + + return + + expected = attach_units( + strip_units(ds).where(**kwargs_without_units), + {"a": original_unit, "b": original_unit}, + ) + result = ds.where(**kwargs) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="interpolate strips units") + def test_interpolate_na(self, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.degK + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.Pa + ) + x = np.arange(len(array1)) + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + expected = attach_units( + strip_units(ds).interpolate_na(dim="x"), + {"a": unit_registry.degK, "b": unit_registry.Pa}, + ) + result = ds.interpolate_na(dim="x") + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="uses Dataset.where, which currently fails") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="same_unit"), + ), + ) + def test_combine_first(self, unit, error, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.degK + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.Pa + ) + x = np.arange(len(array1)) + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + other_array1 = np.ones_like(array1) * unit + other_array2 = -1 * np.ones_like(array2) * unit + other = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=other_array1, dims="x"), + "b": xr.DataArray(data=other_array2, dims="x"), + }, + coords={"x": np.arange(array1.shape[0])}, + ) + + if error is not None: + with pytest.raises(error): + ds.combine_first(other) + + return + + expected = attach_units( + strip_units(ds).combine_first(strip_units(other)), + {"a": unit_registry.m, "b": unit_registry.m}, + ) + result = ds.combine_first(other) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + id="compatible_unit", + marks=pytest.mark.xfail(reason="identical does not check units yet"), + ), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "variation", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="units in indexes not supported") + ), + "coords", + ), + ) + @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) + def test_comparisons(self, func, variation, unit, dtype): + array1 = np.linspace(0, 5, 10).astype(dtype) + array2 = np.linspace(-5, 0, 10).astype(dtype) + + coord = np.arange(len(array1)).astype(dtype) + + original_unit = unit_registry.m + quantity1 = array1 * original_unit + quantity2 = array2 * original_unit + x = coord * original_unit + y = coord * original_unit + + units = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = units.get(variation) + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=quantity1, dims="x"), + "b": xr.DataArray(data=quantity2, dims="x"), + }, + coords={"x": x, "y": ("x", y)}, + ) + + other = attach_units( + strip_units(ds), + { + "a": (data_unit, original_unit if quantity1.check(data_unit) else None), + "b": (data_unit, original_unit if quantity2.check(data_unit) else None), + "x": (dim_unit, original_unit if x.check(dim_unit) else None), + "y": (coord_unit, original_unit if y.check(coord_unit) else None), + }, + ) + + # TODO: test dim coord once indexes leave units intact + # also, express this in terms of calls on the raw data array + # and then check the units + equal_arrays = ( + np.all(ds.a.data == other.a.data) + and np.all(ds.b.data == other.b.data) + and (np.all(x == other.x.data) or True) # dims can't be checked yet + and np.all(y == other.y.data) + ) + equal_units = ( + data_unit == original_unit + and coord_unit == original_unit + and dim_unit == original_unit + ) + expected = equal_arrays and (func.name != "identical" or equal_units) + result = func(ds, other) + + assert expected == result + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + left_array1 = np.ones(shape=(2, 3), dtype=dtype) * unit_registry.m + left_array2 = np.zeros(shape=(2, 6), dtype=dtype) * unit_registry.m + + right_array1 = array_attach_units( + np.ones(shape=(2,), dtype=dtype), + unit, + convert_from=unit_registry.m if left_array1.check(unit) else None, + ) + right_array2 = array_attach_units( + np.ones(shape=(2,), dtype=dtype), + unit, + convert_from=unit_registry.m if left_array2.check(unit) else None, + ) + + left = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=left_array1, dims=("x", "y")), + "b": xr.DataArray(data=left_array2, dims=("x", "z")), + } + ) + right = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=right_array1, dims="x"), + "b": xr.DataArray(data=right_array2, dims="x"), + } + ) + + expected = np.all(left_array1 == right_array1[:, None]) and np.all( + left_array2 == right_array2[:, None] + ) + result = left.broadcast_equals(right) + + assert expected == result + + @pytest.mark.parametrize( + "func", + (method("unstack"), method("reset_index", "v"), method("reorder_levels")), + ids=repr, + ) + def test_stacking_stacked(self, func, dtype): + array1 = ( + np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m + ) + array2 = ( + np.linspace(-10, 0, 5 * 10 * 15).reshape(5, 10, 15).astype(dtype) + * unit_registry.m + ) + + x = np.arange(array1.shape[0]) + y = np.arange(array1.shape[1]) + z = np.arange(array2.shape[2]) + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "y", "z")), + }, + coords={"x": x, "y": y, "z": z}, + ) + + stacked = ds.stack(v=("x", "y")) + + expected = attach_units( + func(strip_units(stacked)), {"a": unit_registry.m, "b": unit_registry.m} + ) + result = func(stacked) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="tries to subscript scalar quantities") + def test_to_stacked_array(self, dtype): + labels = np.arange(5).astype(dtype) * unit_registry.s + arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels} + + ds = xr.Dataset( + data_vars={ + name: xr.DataArray(data=array, dims="x") + for name, array in arrays.items() + } + ) + + func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"]) + + result = func(ds).rename(None) + expected = attach_units( + func(strip_units(ds)).rename(None), + {None: unit_registry.m, "y": unit_registry.s}, + ) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize( + "func", + ( + method("transpose", "y", "x", "z1", "z2"), + method("stack", a=("x", "y")), + method("set_index", x="x2"), + pytest.param( + method("shift", x=2), marks=pytest.mark.xfail(reason="sets all to nan") + ), + pytest.param( + method("roll", x=2, roll_coords=False), + marks=pytest.mark.xfail(reason="strips units"), + ), + method("sortby", "x2"), + ), + ids=repr, + ) + def test_stacking_reordering(self, func, dtype): + array1 = ( + np.linspace(0, 10, 2 * 5 * 10).reshape(2, 5, 10).astype(dtype) + * unit_registry.Pa + ) + array2 = ( + np.linspace(0, 10, 2 * 5 * 15).reshape(2, 5, 15).astype(dtype) + * unit_registry.degK + ) + + x = np.arange(array1.shape[0]) + y = np.arange(array1.shape[1]) + z1 = np.arange(array1.shape[2]) + z2 = np.arange(array2.shape[2]) + + x2 = np.linspace(0, 1, array1.shape[0])[::-1] + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y", "z1")), + "b": xr.DataArray(data=array2, dims=("x", "y", "z2")), + }, + coords={"x": x, "y": y, "z1": z1, "z2": z2, "x2": ("x", x2)}, + ) + + expected = attach_units( + func(strip_units(ds)), {"a": unit_registry.Pa, "b": unit_registry.degK} + ) + result = func(ds) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="indexes strip units") + @pytest.mark.parametrize( + "indices", + ( + pytest.param(4, id="single index"), + pytest.param([5, 2, 9, 1], id="multiple indices"), + ), + ) + def test_isel(self, indices, dtype): + array1 = np.arange(10).astype(dtype) * unit_registry.s + array2 = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa + + x = np.arange(len(array1)) * unit_registry.m + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + expected = attach_units( + strip_units(ds).isel(x=indices), + {"a": unit_registry.s, "b": unit_registry.Pa, "x": unit_registry.m}, + ) + result = ds.isel(x=indices) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail( + reason="xarray does not support duck arrays in dimension coordinates" + ) + @pytest.mark.parametrize( + "values", + ( + pytest.param(12, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), + ) + @pytest.mark.parametrize( + "units,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.ms, KeyError, id="compatible_unit"), + pytest.param(unit_registry.s, None, id="same_unit"), + ), + ) + def test_sel(self, values, units, error, dtype): + array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK + array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa + x = np.arange(len(array1)) * unit_registry.s + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + values_with_units = values * units + + if error is not None: + with pytest.raises(error): + ds.sel(x=values_with_units) + + return + + expected = attach_units( + strip_units(ds).sel(x=values), + {"a": unit_registry.degK, "b": unit_registry.Pa, "x": unit_registry.s}, + ) + result = ds.sel(x=values_with_units) + assert_equal_with_units(expected, result) + + @pytest.mark.xfail( + reason="xarray does not support duck arrays in dimension coordinates" + ) + @pytest.mark.parametrize( + "values", + ( + pytest.param(12, id="single value"), + pytest.param([10, 5, 13], id="list of multiple values"), + pytest.param(np.array([9, 3, 7, 12]), id="array of multiple values"), + ), + ) + @pytest.mark.parametrize( + "units,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.ms, KeyError, id="compatible_unit"), + pytest.param(unit_registry.s, None, id="same_unit"), + ), + ) + def test_loc(self, values, units, error, dtype): + array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK + array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa + x = np.arange(len(array1)) * unit_registry.s + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + values_with_units = values * units + + if error is not None: + with pytest.raises(error): + ds.loc[{"x": values_with_units}] + + return + + expected = attach_units( + strip_units(ds).loc[{"x": values}], + {"a": unit_registry.degK, "b": unit_registry.Pa, "x": unit_registry.s}, + ) + result = ds.loc[{"x": values_with_units}] + assert_equal_with_units(expected, result) + + @pytest.mark.xfail( + reason="indexes strip units and head / tail / thin only support integers" + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("head", x=7, y=3, z=6), + method("tail", x=7, y=3, z=6), + method("thin", x=7, y=3, z=6), + ), + ids=repr, + ) + def test_head_tail_thin(self, func, unit, error, dtype): + array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK + array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa + + coords = { + "x": np.arange(10) * unit_registry.m, + "y": np.arange(5) * unit_registry.m, + "z": np.arange(8) * unit_registry.m, + } + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "z")), + }, + coords=coords, + ) + + kwargs = {name: value * unit for name, value in func.kwargs.items()} + + if error is not None: + with pytest.raises(error): + func(ds, **kwargs) + + return + + expected = attach_units(func(strip_units(ds)), extract_units(ds)) + result = func(ds, **kwargs) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize( + "shape", + ( + pytest.param((10, 20), id="nothing squeezable"), + pytest.param((10, 20, 1), id="last dimension squeezable"), + pytest.param((10, 1, 20), id="middle dimension squeezable"), + pytest.param((1, 10, 20), id="first dimension squeezable"), + pytest.param((1, 10, 1, 20), id="first and last dimension squeezable"), + ), + ) + def test_squeeze(self, shape, dtype): + names = "xyzt" + coords = { + name: np.arange(length).astype(dtype) + * (unit_registry.m if name != "t" else unit_registry.s) + for name, length in zip(names, shape) + } + array1 = ( + np.linspace(0, 1, 10 * 20).astype(dtype).reshape(shape) * unit_registry.degK + ) + array2 = ( + np.linspace(1, 2, 10 * 20).astype(dtype).reshape(shape) * unit_registry.Pa + ) + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=tuple(names[: len(shape)])), + "b": xr.DataArray(data=array2, dims=tuple(names[: len(shape)])), + }, + coords=coords, + ) + units = extract_units(ds) + + expected = attach_units(strip_units(ds).squeeze(), units) + + result = ds.squeeze() + assert_equal_with_units(result, expected) + + # try squeezing the dimensions separately + names = tuple(dim for dim, coord in coords.items() if len(coord) == 1) + for name in names: + expected = attach_units(strip_units(ds).squeeze(dim=name), units) + result = ds.squeeze(dim=name) + assert_equal_with_units(result, expected) + + @pytest.mark.xfail(reason="ignores units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_interp(self, unit, error): + array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK + array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa + + coords = { + "x": np.arange(10) * unit_registry.m, + "y": np.arange(5) * unit_registry.m, + "z": np.arange(8) * unit_registry.s, + } + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "z")), + }, + coords=coords, + ) + + new_coords = (np.arange(10) + 0.5) * unit + + if error is not None: + with pytest.raises(error): + ds.interp(x=new_coords) + + return + + expected = attach_units( + strip_units(ds).interp(x=strip_units(new_coords)), extract_units(ds) + ) + result = ds.interp(x=new_coords) + + assert_equal_with_units(result, expected) + + @pytest.mark.xfail(reason="ignores units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_interp_like(self, unit, error, dtype): + array1 = ( + np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) + array2 = ( + np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa + ) + + coords = { + "x": np.arange(10) * unit_registry.m, + "y": np.arange(5) * unit_registry.m, + "z": np.arange(8) * unit_registry.m, + } + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "z")), + }, + coords=coords, + ) + + other = xr.Dataset( + data_vars={ + "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")), + "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")), + }, + coords={ + "x": (np.arange(20) + 0.3) * unit, + "y": (np.arange(10) - 0.2) * unit, + "z": (np.arange(15) + 0.4) * unit, + }, + ) + + if error is not None: + with pytest.raises(error): + ds.interp_like(other) + + return + + expected = attach_units( + strip_units(ds).interp_like(strip_units(other)), extract_units(ds) + ) + result = ds.interp_like(other) + + assert_equal_with_units(result, expected) + + @pytest.mark.xfail( + reason="pint does not implement np.result_type in __array_function__ yet" + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_reindex(self, unit, error): + array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK + array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa + + coords = { + "x": np.arange(10) * unit_registry.m, + "y": np.arange(5) * unit_registry.m, + "z": np.arange(8) * unit_registry.s, + } + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "z")), + }, + coords=coords, + ) + + new_coords = (np.arange(10) + 0.5) * unit + + if error is not None: + with pytest.raises(error): + ds.interp(x=new_coords) + + return + + expected = attach_units( + strip_units(ds).reindex(x=strip_units(new_coords)), extract_units(ds) + ) + result = ds.reindex(x=new_coords) + + assert_equal_with_units(result, expected) + + @pytest.mark.xfail( + reason="pint does not implement np.result_type in __array_function__ yet" + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_reindex_like(self, unit, error, dtype): + array1 = ( + np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) + array2 = ( + np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa + ) + + coords = { + "x": np.arange(10) * unit_registry.m, + "y": np.arange(5) * unit_registry.m, + "z": np.arange(8) * unit_registry.m, + } + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "z")), + }, + coords=coords, + ) + + other = xr.Dataset( + data_vars={ + "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")), + "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")), + }, + coords={ + "x": (np.arange(20) + 0.3) * unit, + "y": (np.arange(10) - 0.2) * unit, + "z": (np.arange(15) + 0.4) * unit, + }, + ) + + if error is not None: + with pytest.raises(error): + ds.reindex_like(other) + + return + + expected = attach_units( + strip_units(ds).reindex_like(strip_units(other)), extract_units(ds) + ) + result = ds.reindex_like(other) + + assert_equal_with_units(result, expected) + + @pytest.mark.parametrize( + "func", + ( + method("diff", dim="x"), + method("differentiate", coord="x"), + method("integrate", coord="x"), + pytest.param( + method("quantile", q=[0.25, 0.75]), + marks=pytest.mark.xfail( + reason="pint does not implement nanpercentile yet" + ), + ), + pytest.param( + method("reduce", func=np.sum, dim="x"), + marks=pytest.mark.xfail(reason="strips units"), + ), + pytest.param( + method("apply", np.fabs), + marks=pytest.mark.xfail(reason="fabs strips units"), + ), + ), + ids=repr, + ) + def test_computation(self, func, dtype): + array1 = ( + np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) + array2 = ( + np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa + ) + x = np.arange(10) * unit_registry.m + y = np.arange(5) * unit_registry.m + z = np.arange(8) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "z")), + }, + coords={"x": x, "y": y, "z": z}, + ) + + units = extract_units(ds) + + expected = attach_units(func(strip_units(ds)), units) + result = func(ds) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize( + "func", + ( + pytest.param( + method("groupby", "x"), marks=pytest.mark.xfail(reason="strips units") + ), + pytest.param( + method("groupby_bins", "x", bins=4), + marks=pytest.mark.xfail(reason="strips units"), + ), + method("coarsen", x=2), + pytest.param( + method("rolling", x=3), marks=pytest.mark.xfail(reason="strips units") + ), + pytest.param( + method("rolling_exp", x=3), + marks=pytest.mark.xfail(reason="strips units"), + ), + ), + ids=repr, + ) + def test_computation_objects(self, func, dtype): + array1 = ( + np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) + array2 = ( + np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) + * unit_registry.Pa + ) + x = np.arange(10) * unit_registry.m + y = np.arange(5) * unit_registry.m + z = np.arange(8) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "y", "z")), + }, + coords={"x": x, "y": y, "z": z}, + ) + units = extract_units(ds) + + args = [] if func.name != "groupby" else ["y"] + reduce_func = method("mean", *args) + expected = attach_units(reduce_func(func(strip_units(ds))), units) + result = reduce_func(func(ds)) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="strips units") + def test_resample(self, dtype): + array1 = ( + np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) + array2 = ( + np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa + ) + t = pd.date_range("10-09-2010", periods=array1.shape[0], freq="1y") + y = np.arange(5) * unit_registry.m + z = np.arange(8) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("time", "y")), + "b": xr.DataArray(data=array2, dims=("time", "z")), + }, + coords={"time": t, "y": y, "z": z}, + ) + units = extract_units(ds) + + func = method("resample", time="6m") + + expected = attach_units(func(strip_units(ds)).mean(), units) + result = func(ds).mean() + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize( + "func", + ( + pytest.param( + method("assign", c=lambda ds: 10 * ds.b), + marks=pytest.mark.xfail(reason="strips units"), + ), + pytest.param( + method("assign_coords", v=("x", np.arange(10) * unit_registry.s)), + marks=pytest.mark.xfail(reason="strips units"), + ), + pytest.param(method("first")), + pytest.param(method("last")), + pytest.param( + method("quantile", q=[0.25, 0.5, 0.75], dim="x"), + marks=pytest.mark.xfail( + reason="dataset groupby does not implement quantile" + ), + ), + ), + ids=repr, + ) + def test_grouped_operations(self, func, dtype): + array1 = ( + np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) + array2 = ( + np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) + * unit_registry.Pa + ) + x = np.arange(10) * unit_registry.m + y = np.arange(5) * unit_registry.m + z = np.arange(8) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "y", "z")), + }, + coords={"x": x, "y": y, "z": z}, + ) + units = extract_units(ds) + units.update({"c": unit_registry.Pa, "v": unit_registry.s}) + + stripped_kwargs = { + name: strip_units(value) for name, value in func.kwargs.items() + } + expected = attach_units( + func(strip_units(ds).groupby("y"), **stripped_kwargs), units + ) + result = func(ds.groupby("y")) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize( + "func", + ( + method("pipe", lambda ds: ds * 10), + method("assign", d=lambda ds: ds.b * 10), + method("assign_coords", y2=("y", np.arange(5) * unit_registry.mm)), + method("assign_attrs", attr1="value"), + method("rename", x2="x_mm"), + method("rename_vars", c="temperature"), + method("rename_dims", x="offset_x"), + method("swap_dims", {"x": "x2"}), + method("expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1), + method("drop", labels="x"), + method("drop_dims", "z"), + method("set_coords", names="c"), + method("reset_coords", names="x2"), + method("copy"), + ), + ids=repr, + ) + def test_content_manipulation(self, func, dtype): + array1 = ( + np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) + * unit_registry.m ** 3 + ) + array2 = ( + np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) + * unit_registry.Pa + ) + array3 = np.linspace(0, 10, 10).astype(dtype) * unit_registry.degK + + x = np.arange(10) * unit_registry.m + x2 = x.to(unit_registry.mm) + y = np.arange(5) * unit_registry.m + z = np.arange(8) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "y", "z")), + "c": xr.DataArray(data=array3, dims="x"), + }, + coords={"x": x, "y": y, "z": z, "x2": ("x", x2)}, + ) + units = extract_units(ds) + units.update( + { + "y2": unit_registry.mm, + "x_mm": unit_registry.mm, + "offset_x": unit_registry.m, + "d": unit_registry.Pa, + "temperature": unit_registry.degK, + } + ) + + stripped_kwargs = { + key: strip_units(value) for key, value in func.kwargs.items() + } + expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) + result = func(ds) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="blocked by reindex") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, xr.MergeError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, xr.MergeError, id="dimensionless" + ), + pytest.param(unit_registry.s, xr.MergeError, id="incompatible_unit"), + pytest.param(unit_registry.cm, xr.MergeError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize("variant", ("data", "dims", "coords")) + def test_merge(self, variant, unit, error, dtype): + original_data_unit = unit_registry.m + original_dim_unit = unit_registry.m + original_coord_unit = unit_registry.m + + variants = { + "data": (unit, original_dim_unit, original_coord_unit), + "dims": (original_data_unit, unit, original_coord_unit), + "coords": (original_data_unit, original_dim_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + left_array = np.arange(10).astype(dtype) * original_data_unit + right_array = np.arange(-5, 5).astype(dtype) * data_unit + + left_dim = np.arange(10, 20) * original_dim_unit + right_dim = np.arange(5, 15) * dim_unit + + left_coord = np.arange(-10, 0) * original_coord_unit + right_coord = np.arange(-15, -5) * coord_unit + + left = xr.Dataset( + data_vars={"a": ("x", left_array)}, + coords={"x": left_dim, "y": ("x", left_coord)}, + ) + right = xr.Dataset( + data_vars={"a": ("x", right_array)}, + coords={"x": right_dim, "y": ("x", right_coord)}, + ) + + units = extract_units(left) + + if error is not None: + with pytest.raises(error): + left.merge(right) + + return + + converted = convert_units(right, units) + expected = attach_units(strip_units(left).merge(strip_units(converted)), units) + result = left.merge(right) + + assert_equal_with_units(expected, result)