diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index f840557ab5d..a7687368884 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -118,7 +118,8 @@ Internal Changes
~~~~~~~~~~~~~~~~
- Added integration tests against `pint `_.
- (:pull:`3238`, :pull:`3447`, :pull:`3508`) by `Justus Magin `_.
+ (:pull:`3238`, :pull:`3447`, :pull:`3493`, :pull:`3508`)
+ by `Justus Magin `_.
.. note::
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index fd9e9b039ac..509a50d23ff 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -222,7 +222,9 @@ def convert_units(obj, to):
if name != obj.name
}
- new_obj = xr.DataArray(name=name, data=data, coords=coords, attrs=obj.attrs)
+ new_obj = xr.DataArray(
+ name=name, data=data, coords=coords, attrs=obj.attrs, dims=obj.dims
+ )
elif isinstance(obj, unit_registry.Quantity):
units = to.get(None)
new_obj = obj.to(units) if units is not None else obj
@@ -307,19 +309,689 @@ def __repr__(self):
class function:
- def __init__(self, name):
- self.name = name
- self.func = getattr(np, name)
+ def __init__(self, name_or_function, *args, **kwargs):
+ if callable(name_or_function):
+ self.name = name_or_function.__name__
+ self.func = name_or_function
+ else:
+ self.name = name_or_function
+ self.func = getattr(np, name_or_function)
+ if self.func is None:
+ raise AttributeError(
+ f"module 'numpy' has no attribute named '{self.name}'"
+ )
+
+ self.args = args
+ self.kwargs = kwargs
def __call__(self, *args, **kwargs):
- return self.func(*args, **kwargs)
+ all_args = list(self.args) + list(args)
+ all_kwargs = {**self.kwargs, **kwargs}
+
+ return self.func(*all_args, **all_kwargs)
def __repr__(self):
return f"function_{self.name}"
+def test_apply_ufunc_dataarray(dtype):
+ func = function(
+ xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1}
+ )
+
+ array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.m
+ x = np.arange(20) * unit_registry.s
+ data_array = xr.DataArray(data=array, dims="x", coords={"x": x})
+
+ expected = attach_units(func(strip_units(data_array)), extract_units(data_array))
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(
+ reason="pint does not implement `np.result_type` and align strips units"
+)
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize(
+ "variant",
+ (
+ "data",
+ pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
+ "coords",
+ ),
+)
+@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan)))
+def test_align_dataarray(fill_value, variant, unit, error, dtype):
+ original_unit = unit_registry.m
+
+ variants = {
+ "data": (unit, original_unit, original_unit),
+ "dims": (original_unit, unit, original_unit),
+ "coords": (original_unit, original_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = variants.get(variant)
+
+ array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit
+ array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * data_unit
+ x = np.arange(2) * original_unit
+ x_a1 = np.array([10, 5]) * original_unit
+ x_a2 = np.array([10, 5]) * coord_unit
+
+ y1 = np.arange(5) * original_unit
+ y2 = np.arange(2, 7) * dim_unit
+
+ data_array1 = xr.DataArray(
+ data=array1, coords={"x": x, "x_a": ("x", x_a1), "y": y1}, dims=("x", "y")
+ )
+ data_array2 = xr.DataArray(
+ data=array2, coords={"x": x, "x_a": ("x", x_a2), "y": y2}, dims=("x", "y")
+ )
+
+ fill_value = fill_value * data_unit
+ func = function(xr.align, join="outer", fill_value=fill_value)
+ if error is not None:
+ with pytest.raises(error):
+ func(data_array1, data_array2)
+
+ return
+
+ stripped_kwargs = {
+ key: strip_units(
+ convert_units(value, {None: original_unit})
+ if isinstance(value, unit_registry.Quantity)
+ else value
+ )
+ for key, value in func.kwargs.items()
+ }
+ units = extract_units(data_array1)
+ # FIXME: should the expected_b have the same units as data_array1
+ # or data_array2?
+ expected_a, expected_b = tuple(
+ attach_units(elem, units)
+ for elem in func(
+ strip_units(data_array1),
+ strip_units(convert_units(data_array2, units)),
+ **stripped_kwargs,
+ )
+ )
+ result_a, result_b = func(data_array1, data_array2)
+
+ assert_equal_with_units(expected_a, result_a)
+ assert_equal_with_units(expected_b, result_b)
+
+
+@pytest.mark.xfail(
+ reason="pint does not implement `np.result_type` and align strips units"
+)
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize(
+ "variant",
+ (
+ "data",
+ pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
+ "coords",
+ ),
+)
+@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan)))
+def test_align_dataset(fill_value, unit, variant, error, dtype):
+ original_unit = unit_registry.m
+
+ variants = {
+ "data": (unit, original_unit, original_unit),
+ "dims": (original_unit, unit, original_unit),
+ "coords": (original_unit, original_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = variants.get(variant)
+
+ array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit
+ array2 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit
+
+ x = np.arange(2) * original_unit
+ x_a1 = np.array([10, 5]) * original_unit
+ x_a2 = np.array([10, 5]) * coord_unit
+
+ y1 = np.arange(5) * original_unit
+ y2 = np.arange(2, 7) * dim_unit
+
+ ds1 = xr.Dataset(
+ data_vars={"a": (("x", "y"), array1)},
+ coords={"x": x, "x_a": ("x", x_a1), "y": y1},
+ )
+ ds2 = xr.Dataset(
+ data_vars={"a": (("x", "y"), array2)},
+ coords={"x": x, "x_a": ("x", x_a2), "y": y2},
+ )
+
+ fill_value = fill_value * data_unit
+ func = function(xr.align, join="outer", fill_value=fill_value)
+ if error is not None:
+ with pytest.raises(error):
+ func(ds1, ds2)
+
+ return
+
+ stripped_kwargs = {
+ key: strip_units(
+ convert_units(value, {None: original_unit})
+ if isinstance(value, unit_registry.Quantity)
+ else value
+ )
+ for key, value in func.kwargs.items()
+ }
+ units = extract_units(ds1)
+ # FIXME: should the expected_b have the same units as ds1 or ds2?
+ expected_a, expected_b = tuple(
+ attach_units(elem, units)
+ for elem in func(
+ strip_units(ds1), strip_units(convert_units(ds2, units)), **stripped_kwargs
+ )
+ )
+ result_a, result_b = func(ds1, ds2)
+
+ assert_equal_with_units(expected_a, result_a)
+ assert_equal_with_units(expected_b, result_b)
+
+
+def test_broadcast_dataarray(dtype):
+ array1 = np.linspace(0, 10, 2) * unit_registry.Pa
+ array2 = np.linspace(0, 10, 3) * unit_registry.Pa
+
+ a = xr.DataArray(data=array1, dims="x")
+ b = xr.DataArray(data=array2, dims="y")
+
+ expected_a, expected_b = tuple(
+ attach_units(elem, extract_units(a))
+ for elem in xr.broadcast(strip_units(a), strip_units(b))
+ )
+ result_a, result_b = xr.broadcast(a, b)
+
+ assert_equal_with_units(expected_a, result_a)
+ assert_equal_with_units(expected_b, result_b)
+
+
+def test_broadcast_dataset(dtype):
+ array1 = np.linspace(0, 10, 2) * unit_registry.Pa
+ array2 = np.linspace(0, 10, 3) * unit_registry.Pa
+
+ ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("y", array2)})
+
+ (expected,) = tuple(
+ attach_units(elem, extract_units(ds)) for elem in xr.broadcast(strip_units(ds))
+ )
+ (result,) = xr.broadcast(ds)
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="`combine_by_coords` strips units")
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize(
+ "variant",
+ (
+ "data",
+ pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
+ "coords",
+ ),
+)
+def test_combine_by_coords(variant, unit, error, dtype):
+ original_unit = unit_registry.m
+
+ variants = {
+ "data": (unit, original_unit, original_unit),
+ "dims": (original_unit, unit, original_unit),
+ "coords": (original_unit, original_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = variants.get(variant)
+
+ array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
+ array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
+ x = np.arange(1, 4) * 10 * original_unit
+ y = np.arange(2) * original_unit
+ z = np.arange(3) * original_unit
+
+ other_array1 = np.ones_like(array1) * data_unit
+ other_array2 = np.ones_like(array2) * data_unit
+ other_x = np.arange(1, 4) * 10 * dim_unit
+ other_y = np.arange(2, 4) * dim_unit
+ other_z = np.arange(3, 6) * coord_unit
+
+ ds = xr.Dataset(
+ data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
+ coords={"x": x, "y": y, "z": ("x", z)},
+ )
+ other = xr.Dataset(
+ data_vars={"a": (("y", "x"), other_array1), "b": (("y", "x"), other_array2)},
+ coords={"x": other_x, "y": other_y, "z": ("x", other_z)},
+ )
+
+ if error is not None:
+ with pytest.raises(error):
+ xr.combine_by_coords([ds, other])
+
+ return
+
+ units = extract_units(ds)
+ expected = attach_units(
+ xr.combine_by_coords(
+ [strip_units(ds), strip_units(convert_units(other, units))]
+ ),
+ units,
+ )
+ result = xr.combine_by_coords([ds, other])
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="blocked by `where`")
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize(
+ "variant",
+ (
+ "data",
+ pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
+ "coords",
+ ),
+)
+def test_combine_nested(variant, unit, error, dtype):
+ original_unit = unit_registry.m
+
+ variants = {
+ "data": (unit, original_unit, original_unit),
+ "dims": (original_unit, unit, original_unit),
+ "coords": (original_unit, original_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = variants.get(variant)
+
+ array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
+ array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
+
+ x = np.arange(1, 4) * 10 * original_unit
+ y = np.arange(2) * original_unit
+ z = np.arange(3) * original_unit
+
+ ds1 = xr.Dataset(
+ data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
+ coords={"x": x, "y": y, "z": ("x", z)},
+ )
+ ds2 = xr.Dataset(
+ data_vars={
+ "a": (("y", "x"), np.ones_like(array1) * data_unit),
+ "b": (("y", "x"), np.ones_like(array2) * data_unit),
+ },
+ coords={
+ "x": np.arange(3) * dim_unit,
+ "y": np.arange(2, 4) * dim_unit,
+ "z": ("x", np.arange(-3, 0) * coord_unit),
+ },
+ )
+ ds3 = xr.Dataset(
+ data_vars={
+ "a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit),
+ "b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit),
+ },
+ coords={
+ "x": np.arange(3, 6) * dim_unit,
+ "y": np.arange(4, 6) * dim_unit,
+ "z": ("x", np.arange(3, 6) * coord_unit),
+ },
+ )
+ ds4 = xr.Dataset(
+ data_vars={
+ "a": (("y", "x"), -1 * np.ones_like(array1) * data_unit),
+ "b": (("y", "x"), -1 * np.ones_like(array2) * data_unit),
+ },
+ coords={
+ "x": np.arange(6, 9) * dim_unit,
+ "y": np.arange(6, 8) * dim_unit,
+ "z": ("x", np.arange(6, 9) * coord_unit),
+ },
+ )
+
+ func = function(xr.combine_nested, concat_dim=["x", "y"])
+ if error is not None:
+ with pytest.raises(error):
+ func([[ds1, ds2], [ds3, ds4]])
+
+ return
+
+ units = extract_units(ds1)
+ convert_and_strip = lambda ds: strip_units(convert_units(ds, units))
+ expected = attach_units(
+ func(
+ [
+ [strip_units(ds1), convert_and_strip(ds2)],
+ [convert_and_strip(ds3), convert_and_strip(ds4)],
+ ]
+ ),
+ units,
+ )
+ result = func([[ds1, ds2], [ds3, ds4]])
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="`concat` strips units")
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize(
+ "variant",
+ (
+ "data",
+ pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
+ ),
+)
+def test_concat_dataarray(variant, unit, error, dtype):
+ original_unit = unit_registry.m
+
+ variants = {"data": (unit, original_unit), "dims": (original_unit, unit)}
+ data_unit, dims_unit = variants.get(variant)
+
+ array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
+ array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit
+ x1 = np.arange(5, 15) * original_unit
+ x2 = np.arange(5) * dims_unit
+
+ arr1 = xr.DataArray(data=array1, coords={"x": x1}, dims="x")
+ arr2 = xr.DataArray(data=array2, coords={"x": x2}, dims="x")
+
+ if error is not None:
+ with pytest.raises(error):
+ xr.concat([arr1, arr2], dim="x")
+
+ return
+
+ expected = attach_units(
+ xr.concat([strip_units(arr1), strip_units(arr2)], dim="x"), extract_units(arr1)
+ )
+ result = xr.concat([arr1, arr2], dim="x")
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="`concat` strips units")
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize(
+ "variant",
+ (
+ "data",
+ pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
+ ),
+)
+def test_concat_dataset(variant, unit, error, dtype):
+ original_unit = unit_registry.m
+
+ variants = {"data": (unit, original_unit), "dims": (original_unit, unit)}
+ data_unit, dims_unit = variants.get(variant)
+
+ array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
+ array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit
+ x1 = np.arange(5, 15) * original_unit
+ x2 = np.arange(5) * dims_unit
+
+ ds1 = xr.Dataset(data_vars={"a": ("x", array1)}, coords={"x": x1})
+ ds2 = xr.Dataset(data_vars={"a": ("x", array2)}, coords={"x": x2})
+
+ if error is not None:
+ with pytest.raises(error):
+ xr.concat([ds1, ds2], dim="x")
+
+ return
+
+ expected = attach_units(
+ xr.concat([strip_units(ds1), strip_units(ds2)], dim="x"), extract_units(ds1)
+ )
+ result = xr.concat([ds1, ds2], dim="x")
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="blocked by `where`")
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize(
+ "variant",
+ (
+ "data",
+ pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
+ "coords",
+ ),
+)
+def test_merge_dataarray(variant, unit, error, dtype):
+ original_unit = unit_registry.m
+
+ variants = {
+ "data": (unit, original_unit, original_unit),
+ "dims": (original_unit, unit, original_unit),
+ "coords": (original_unit, original_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = variants.get(variant)
+
+ array1 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * original_unit
+ array2 = np.linspace(1, 2, 2 * 4).reshape(2, 4).astype(dtype) * data_unit
+ array3 = np.linspace(0, 2, 3 * 4).reshape(3, 4).astype(dtype) * data_unit
+
+ x = np.arange(2) * original_unit
+ y = np.arange(3) * original_unit
+ z = np.arange(4) * original_unit
+ u = np.linspace(10, 20, 2) * original_unit
+ v = np.linspace(10, 20, 3) * original_unit
+ w = np.linspace(10, 20, 4) * original_unit
+
+ arr1 = xr.DataArray(
+ name="a",
+ data=array1,
+ coords={"x": x, "y": y, "u": ("x", u), "v": ("y", v)},
+ dims=("x", "y"),
+ )
+ arr2 = xr.DataArray(
+ name="b",
+ data=array2,
+ coords={
+ "x": np.arange(2, 4) * dim_unit,
+ "z": z,
+ "u": ("x", np.linspace(20, 30, 2) * coord_unit),
+ "w": ("z", w),
+ },
+ dims=("x", "z"),
+ )
+ arr3 = xr.DataArray(
+ name="c",
+ data=array3,
+ coords={
+ "y": np.arange(3, 6) * dim_unit,
+ "z": np.arange(4, 8) * dim_unit,
+ "v": ("y", np.linspace(10, 20, 3) * coord_unit),
+ "w": ("z", np.linspace(10, 20, 4) * coord_unit),
+ },
+ dims=("y", "z"),
+ )
+
+ func = function(xr.merge)
+ if error is not None:
+ with pytest.raises(error):
+ func([arr1, arr2, arr3])
+
+ return
+
+ units = {name: original_unit for name in list("abcuvwxyz")}
+ convert_and_strip = lambda arr: strip_units(convert_units(arr, units))
+ expected = attach_units(
+ func([strip_units(arr1), convert_and_strip(arr2), convert_and_strip(arr3)]),
+ units,
+ )
+ result = func([arr1, arr2, arr3])
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="blocked by `where`")
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize(
+ "variant",
+ (
+ "data",
+ pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
+ "coords",
+ ),
+)
+def test_merge_dataset(variant, unit, error, dtype):
+ original_unit = unit_registry.m
+
+ variants = {
+ "data": (unit, original_unit, original_unit),
+ "dims": (original_unit, unit, original_unit),
+ "coords": (original_unit, original_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = variants.get(variant)
+
+ array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
+ array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
+
+ x = np.arange(11, 14) * original_unit
+ y = np.arange(2) * original_unit
+ z = np.arange(3) * original_unit
+
+ ds1 = xr.Dataset(
+ data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
+ coords={"x": x, "y": y, "z": ("x", z)},
+ )
+ ds2 = xr.Dataset(
+ data_vars={
+ "a": (("y", "x"), np.ones_like(array1) * data_unit),
+ "b": (("y", "x"), np.ones_like(array2) * data_unit),
+ },
+ coords={
+ "x": np.arange(3) * dim_unit,
+ "y": np.arange(2, 4) * dim_unit,
+ "z": ("x", np.arange(-3, 0) * coord_unit),
+ },
+ )
+ ds3 = xr.Dataset(
+ data_vars={
+ "a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit),
+ "b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit),
+ },
+ coords={
+ "x": np.arange(3, 6) * dim_unit,
+ "y": np.arange(4, 6) * dim_unit,
+ "z": ("x", np.arange(3, 6) * coord_unit),
+ },
+ )
+
+ func = function(xr.merge)
+ if error is not None:
+ with pytest.raises(error):
+ func([ds1, ds2, ds3])
+
+ return
+
+ units = extract_units(ds1)
+ convert_and_strip = lambda ds: strip_units(convert_units(ds, units))
+ expected = attach_units(
+ func([strip_units(ds1), convert_and_strip(ds2), convert_and_strip(ds3)]), units
+ )
+ result = func([ds1, ds2, ds3])
+
+ assert_equal_with_units(expected, result)
+
+
@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
-def test_replication(func, dtype):
+def test_replication_dataarray(func, dtype):
array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
data_array = xr.DataArray(data=array, dims="x")
@@ -330,8 +1002,33 @@ def test_replication(func, dtype):
assert_equal_with_units(expected, result)
+@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
+def test_replication_dataset(func, dtype):
+ array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
+ array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa
+ x = np.arange(20).astype(dtype) * unit_registry.m
+ y = np.arange(10).astype(dtype) * unit_registry.m
+ z = y.to(unit_registry.mm)
+
+ ds = xr.Dataset(
+ data_vars={"a": ("x", array1), "b": ("y", array2)},
+ coords={"x": x, "y": y, "z": ("y", z)},
+ )
+
+ numpy_func = getattr(np, func.__name__)
+ expected = ds.copy(
+ data={name: numpy_func(array.data) for name, array in ds.data_vars.items()}
+ )
+ result = func(ds)
+
+ assert_equal_with_units(expected, result)
+
+
@pytest.mark.xfail(
- reason="np.full_like on Variable strips the unit and pint does not allow mixed args"
+ reason=(
+ "pint is undecided on how `full_like` should work, so incorrect errors "
+ "may be expected: hgrecco/pint#882"
+ )
)
@pytest.mark.parametrize(
"unit,error",
@@ -344,8 +1041,9 @@ def test_replication(func, dtype):
pytest.param(unit_registry.ms, None, id="compatible_unit"),
pytest.param(unit_registry.s, None, id="identical_unit"),
),
+ ids=repr,
)
-def test_replication_full_like(unit, error, dtype):
+def test_replication_full_like_dataarray(unit, error, dtype):
array = np.linspace(0, 5, 10) * unit_registry.s
data_array = xr.DataArray(data=array, dims="x")
@@ -360,6 +1058,163 @@ def test_replication_full_like(unit, error, dtype):
assert_equal_with_units(expected, result)
+@pytest.mark.xfail(
+ reason=(
+ "pint is undecided on how `full_like` should work, so incorrect errors "
+ "may be expected: hgrecco/pint#882"
+ )
+)
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.ms, None, id="compatible_unit"),
+ pytest.param(unit_registry.s, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+def test_replication_full_like_dataset(unit, error, dtype):
+ array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
+ array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa
+ x = np.arange(20).astype(dtype) * unit_registry.m
+ y = np.arange(10).astype(dtype) * unit_registry.m
+ z = y.to(unit_registry.mm)
+
+ ds = xr.Dataset(
+ data_vars={"a": ("x", array1), "b": ("y", array2)},
+ coords={"x": x, "y": y, "z": ("y", z)},
+ )
+
+ fill_value = -1 * unit
+ if error is not None:
+ with pytest.raises(error):
+ xr.full_like(ds, fill_value=fill_value)
+
+ return
+
+ expected = ds.copy(
+ data={
+ name: np.full_like(array, fill_value=fill_value)
+ for name, array in ds.data_vars.items()
+ }
+ )
+ result = xr.full_like(ds, fill_value=fill_value)
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="`where` strips units")
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize("fill_value", (np.nan, 10.2))
+def test_where_dataarray(fill_value, unit, error, dtype):
+ array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
+
+ x = xr.DataArray(data=array, dims="x")
+ cond = x < 5 * unit_registry.m
+ # FIXME: this should work without wrapping in array()
+ fill_value = np.array(fill_value) * unit
+
+ if error is not None:
+ with pytest.raises(error):
+ xr.where(cond, x, fill_value)
+
+ return
+
+ fill_value_ = (
+ fill_value.to(unit_registry.m)
+ if isinstance(fill_value, unit_registry.Quantity)
+ and fill_value.check(unit_registry.m)
+ else fill_value
+ )
+ expected = attach_units(
+ xr.where(cond, strip_units(x), strip_units(fill_value_)), extract_units(x)
+ )
+ result = xr.where(cond, x, fill_value)
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="`where` strips units")
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.mm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ ids=repr,
+)
+@pytest.mark.parametrize("fill_value", (np.nan, 10.2))
+def test_where_dataset(fill_value, unit, error, dtype):
+ array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
+ array2 = np.linspace(-5, 0, 10).astype(dtype) * unit_registry.m
+ x = np.arange(10) * unit_registry.s
+
+ ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x})
+ cond = ds.x < 5 * unit_registry.s
+ # FIXME: this should work without wrapping in array()
+ fill_value = np.array(fill_value) * unit
+
+ if error is not None:
+ with pytest.raises(error):
+ xr.where(cond, ds, fill_value)
+
+ return
+
+ fill_value_ = (
+ fill_value.to(unit_registry.m)
+ if isinstance(fill_value, unit_registry.Quantity)
+ and fill_value.check(unit_registry.m)
+ else fill_value
+ )
+ expected = attach_units(
+ xr.where(cond, strip_units(ds), strip_units(fill_value_)), extract_units(ds)
+ )
+ result = xr.where(cond, ds, fill_value)
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(reason="pint does not implement `np.einsum`")
+def test_dot_dataarray(dtype):
+ array1 = (
+ np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype)
+ * unit_registry.m
+ / unit_registry.s
+ )
+ array2 = (
+ np.linspace(10, 20, 10 * 20).reshape(10, 20).astype(dtype) * unit_registry.s
+ )
+
+ arr1 = xr.DataArray(data=array1, dims=("x", "y"))
+ arr2 = xr.DataArray(data=array2, dims=("y", "z"))
+
+ expected = array1.dot(array2)
+ result = xr.dot(arr1, arr2)
+
+ assert_equal_with_units(expected, result)
+
+
class TestDataArray:
@pytest.mark.filterwarnings("error:::pint[.*]")
@pytest.mark.parametrize(