diff --git a/docs/api.rst b/docs/api.rst index 4bcb96aa..e902c069 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -21,7 +21,7 @@ DataArray --------- .. autosummary:: :toctree: generated/ - :template: autosummary/accessor_property.rst + :template: autosummary/accessor_attribute.rst DataArray.pint.magnitude DataArray.pint.units diff --git a/pint_xarray/accessors.py b/pint_xarray/accessors.py index ed22e58a..e86fcd13 100644 --- a/pint_xarray/accessors.py +++ b/pint_xarray/accessors.py @@ -1,12 +1,12 @@ # TODO is it possible to import pint-xarray from within xarray if pint is present? -import numpy as np +import itertools + import pint from pint.quantity import Quantity from pint.unit import Unit from xarray import ( DataArray, Dataset, - Variable, register_dataarray_accessor, register_dataset_accessor, ) @@ -35,6 +35,40 @@ def is_dict_like(obj): return hasattr(obj, "keys") and hasattr(obj, "__getitem__") +def zip_mappings(*mappings, fill_value=None): + """ zip mappings by combining values for common keys into a tuple + + Works like itertools.zip_longest, so if a key is missing from a + mapping, it is replaced by ``fill_value``. + + Parameters + ---------- + *mappings : dict-like + The mappings to zip + fill_value + The value to use if a key is missing from a mapping. + + Returns + ------- + zipped : dict-like + The zipped mapping + """ + keys = set(itertools.chain.from_iterable(mapping.keys() for mapping in mappings)) + + # TODO: could this be made more efficient using itertools.groupby? + zipped = { + key: tuple(mapping.get(key, fill_value) for mapping in mappings) for key in keys + } + return zipped + + +def units_to_str_or_none(mapping): + return { + key: str(value) if isinstance(value, Unit) else value + for key, value in mapping.items() + } + + # based on xarray.core.utils.either_dict_or_kwargs # https://github.com/pydata/xarray/blob/v0.15.1/xarray/core/utils.py#L249-L268 def either_dict_or_kwargs(positional, keywords, method_name): @@ -53,49 +87,6 @@ def either_dict_or_kwargs(positional, keywords, method_name): return keywords -def _array_attach_units(data, unit, convert_from=None): - """ - Internal utility function for attaching units to a numpy-like array, - converting them, or throwing the correct error. - """ - - if isinstance(data, Quantity): - if not convert_from: - raise ValueError( - f"Cannot attach unit {unit} to quantity: data " - f"already has units {data.units}" - ) - elif isinstance(convert_from, Unit): - data = data.magnitude - elif convert_from is True: # intentionally accept exactly true - if data.check(unit): - convert_from = data.units - data = data.magnitude - else: - raise ValueError( - "Cannot convert quantity from {data.units} " "to {unit}" - ) - else: - raise ValueError("Cannot convert from invalid unit {convert_from}") - - # to make sure we also encounter the case of "equal if converted" - if convert_from is not None: - quantity = (data * convert_from).to( - unit if isinstance(unit, Unit) else unit.dimensionless - ) - else: - try: - quantity = data * unit - except np.core._exceptions.UFuncTypeError: - # from @keewis in xarray.tests.test_units - unsure what this checks? - if unit != 1: - raise - - quantity = data - - return quantity - - def _get_registry(unit_registry, registry_kwargs): if unit_registry is None: if registry_kwargs is None: @@ -106,11 +97,13 @@ def _get_registry(unit_registry, registry_kwargs): return unit_registry -def _decide_units(units, registry, attrs): - if units is None: +def _decide_units(units, registry, unit_attribute): + if units is None and unit_attribute is None: + # or warn and return None? + raise ValueError("no units given") + elif units is None: # TODO option to read and decode units according to CF conventions (see MetPy)? - attr_units = attrs["units"] - units = registry.parse_expression(attr_units) + units = registry.parse_expression(unit_attribute).units elif isinstance(units, Unit): # TODO do we have to check what happens if someone passes a Unit instance # without creating a unit registry? @@ -121,18 +114,6 @@ def _decide_units(units, registry, attrs): return units -def _quantify_variable(var, units): - new_data = _array_attach_units(var.data, units, convert_from=None) - new_var = Variable(dims=var.dims, data=new_data, attrs=var.attrs) - return new_var - - -def _dequantify_variable(var): - new_var = Variable(dims=var.dims, data=var.data.magnitude, attrs=var.attrs) - new_var.attrs["units"] = str(var.data.units) - return new_var - - @register_dataarray_accessor("pint") class PintDataArrayAccessor: """ @@ -144,7 +125,9 @@ class PintDataArrayAccessor: def __init__(self, da): self.da = da - def quantify(self, units=None, unit_registry=None, registry_kwargs=None): + def quantify( + self, units=None, unit_registry=None, registry_kwargs=None, **unit_kwargs + ): """ Attaches units to the DataArray. @@ -154,77 +137,110 @@ def quantify(self, units=None, unit_registry=None, registry_kwargs=None): `.attrs`. Will raise a ValueError if the DataArray already contains a unit-aware array. + .. note:: + Be aware that unless you're using ``dask`` this will load + the data into memory. To avoid that, consider converting + to ``dask`` first (e.g. using ``chunk``). + + As units in dimension coordinates are not supported until + ``xarray`` changes the way it implements indexes, these + units will be set as attributes. + Parameters ---------- - units : pint.Unit or str, optional - Physical units to use for this DataArray. If not provided, will try - to read them from `DataArray.attrs['units']` using pint's parser. - unit_registry : `pint.UnitRegistry`, optional + units : pint.Unit or str or mapping of hashable to pint.Unit or str, optional + Physical units to use for this DataArray. If a str or + pint.Unit, will be used as the DataArray's units. If a + dict-like, it should map a variable name to the desired + unit (use the DataArray's name to refer to its data). If + not provided, will try to read them from + ``DataArray.attrs['units']`` using pint's parser. The + ``"units"`` attribute will be removed from all variables + except from dimension coordinates. + unit_registry : pint.UnitRegistry, optional Unit registry to be used for the units attached to this DataArray. If not given then a default registry will be created. registry_kwargs : dict, optional - Keyword arguments to be passed to `pint.UnitRegistry`. + Keyword arguments to be passed to the unit registry. + **unit_kwargs + Keyword argument form of units. Returns ------- - quantified - DataArray whose wrapped array data will now be a Quantity - array with the specified units. + quantified : DataArray + DataArray whose wrapped array data will now be a Quantity + array with the specified units. Examples -------- - >>> da.pint.quantify(units='Hz') + >>> da = xr.DataArray( + ... data=[0.4, 0.9, 1.7, 4.8, 3.2, 9.1], + ... dims="frequency", + ... coords={"wavelength": [1e-4, 2e-4, 4e-4, 6e-4, 1e-3, 2e-3]}, + ... ) + >>> da.pint.quantify(units="Hz") Quantity([ 0.4, 0.9, 1.7, 4.8, 3.2, 9.1], 'Hz') Coordinates: * wavelength (wavelength) np.array 1e-4, 2e-4, 4e-4, 6e-4, 1e-3, 2e-3 """ - # TODO should also quantify coordinates (once explicit indexes ready) - if isinstance(self.da.data, Quantity): raise ValueError( f"Cannot attach unit {units} to quantity: data " f"already has units {self.da.data.units}" ) + if isinstance(units, (str, pint.Unit)): + if self.da.name in unit_kwargs: + raise ValueError( + f"ambiguous values given for {repr(self.da.name)}:" + f" {repr(units)} and {repr(unit_kwargs[self.da.name])}" + ) + unit_kwargs[self.da.name] = units + units = None + + units = either_dict_or_kwargs(units, unit_kwargs, ".quantify") + registry = _get_registry(unit_registry, registry_kwargs) - units = _decide_units(units, registry, self.da.attrs) + unit_attrs = conversion.extract_unit_attributes(self.da) + new_obj = conversion.strip_unit_attributes(self.da) + + units = { + name: _decide_units(unit, registry, unit_attribute) + for name, (unit, unit_attribute) in zip_mappings(units, unit_attrs).items() + if unit is not None or unit_attribute is not None + } - quantity = _array_attach_units(self.da.data, units, convert_from=None) + # TODO: remove once indexes support units + dim_units = {name: unit for name, unit in units.items() if name in self.da.dims} + for name in dim_units.keys(): + units.pop(name) + new_obj = conversion.attach_unit_attributes(new_obj, dim_units) - # TODO should we (temporarily) remove the attrs here so that they don't become inconsistent? - return DataArray( - dims=self.da.dims, data=quantity, coords=self.da.coords, attrs=self.da.attrs - ) + return conversion.attach_units(new_obj, units) def dequantify(self): """ - Removes units from the DataArray and it's coordinates. + Removes units from the DataArray and its coordinates. - Will replace `.attrs['units']` on each variable with a string - representation of the `pint.Unit` instance. + Will replace ``.attrs['units']`` on each variable with a string + representation of the ``pint.Unit`` instance. Returns ------- - dequantified - DataArray whose array data is unitless, and of the type - that was previously wrapped by `pint.Quantity`. + dequantified : DataArray + DataArray whose array data is unitless, and of the type + that was previously wrapped by `pint.Quantity`. """ - if not isinstance(self.da.data, Quantity): - raise ValueError( - "Cannot remove units from data that does not have" " units" - ) - - # TODO also dequantify coords (once explicit indexes ready) - da = DataArray( - dims=self.da.dims, - data=self.da.pint.magnitude, - coords=self.da.coords, - attrs=self.da.attrs, + units = units_to_str_or_none(conversion.extract_units(self.da)) + new_obj = conversion.attach_unit_attributes( + conversion.strip_units(self.da), units, ) - da.attrs["units"] = str(self.da.data.units) - return da + + return new_obj @property def magnitude(self): @@ -236,7 +252,7 @@ def units(self): @units.setter def units(self, units): - quantity = _array_attach_units(self.da.data, units) + quantity = conversion.array_attach_units(self.da.data, units) self.da = DataArray( dim=self.da.dims, data=quantity, coords=self.da.coords, attrs=self.da.attrs ) @@ -260,10 +276,10 @@ def to(self, units=None, **unit_kwargs): Parameters ---------- units : str or pint.Unit or mapping of hashable to str or pint.Unit, optional - The units to convert to. If a unit name or - :py:class`pint.Unit` object, convert the DataArray's - data. If a dict-like, it has to map a variable name to a - unit name or :py:class:`pint.Unit` object. + The units to convert to. If a unit name or ``pint.Unit`` + object, convert the DataArray's data. If a dict-like, it + has to map a variable name to a unit name or ``pint.Unit`` + object. **unit_kwargs The kwargs form of ``units``. Can only be used for variable names that are strings and valid python identifiers. @@ -400,69 +416,116 @@ class PintDatasetAccessor: def __init__(self, ds): self.ds = ds - def quantify(self, units=None, unit_registry=None, registry_kwargs=None): + def quantify( + self, units=None, unit_registry=None, registry_kwargs=None, **unit_kwargs + ): """ Attaches units to each variable in the Dataset. - Units can be specified as a pint.Unit or as a string, which will - be parsed by the given unit registry. If no units are specified then - the units will be parsed from the `'units'` entry of the DataArray's - `.attrs`. Will raise a ValueError if any of the DataArrays already - contain a unit-aware array. + Units can be specified as a ``pint.Unit`` or as a + string, which will be parsed by the given unit registry. If no + units are specified then the units will be parsed from the + ``"units"`` entry of the Dataset variable's ``.attrs``. Will + raise a ValueError if any of the variables already contain a + unit-aware array. + + .. note:: + Be aware that unless you're using ``dask`` this will load + the data into memory. To avoid that, consider converting + to ``dask`` first (e.g. using ``chunk``). + + As units in dimension coordinates are not supported until + ``xarray`` changes the way it implements indexes, these + units will be set as attributes. Parameters ---------- - units : mapping from variable names to pint.Unit or str, optional - Physical units to use for particular DataArrays in this Dataset. If - not provided, will try to read them from - `Dataset[var].attrs['units']` using pint's parser. - unit_registry : `pint.UnitRegistry`, optional - Unit registry to be used for the units attached to each DataArray - in this Dataset. If not given then a default registry will be - created. + units : mapping of hashable to pint.Unit or str, optional + Physical units to use for particular DataArrays in this + Dataset. It should map variable names to units (unit names + or ``pint.Unit`` objects). If not provided, will try to + read them from ``Dataset[var].attrs['units']`` using + pint's parser. The ``"units"`` attribute will be removed + from all variables except from dimension coordinates. + unit_registry : pint.UnitRegistry, optional + Unit registry to be used for the units attached to each + DataArray in this Dataset. If not given then a default + registry will be created. registry_kwargs : dict, optional Keyword arguments to be passed to `pint.UnitRegistry`. + **unit_kwargs + Keyword argument form of ``units``. Returns ------- - quantified - Dataset whose variables will now contain Quantity - arrays with units. - """ + quantified : Dataset + Dataset whose variables will now contain Quantity arrays + with units. - for var in self.ds.data_vars: - if isinstance(self.ds[var].data, Quantity): - raise ValueError( - f"Cannot attach unit to quantity: data " - f"variable {var} already has units " - f"{self.ds[var].data.units}" - ) + Examples + -------- + >>> ds = xr.Dataset( + ... {"a": ("x", [0, 3, 2], {"units": "m"}), "b": ("x", 5, -2, 1)}, + ... coords={"x": [0, 1, 2], "u": ("x", [-1, 0, 1], {"units": "s"})}, + ... ) + >>> ds.pint.quantify() + + Dimensions: (x: 3) + Coordinates: + * x (x) int64 0 1 2 + u (x) int64 + Data variables: + a (x) int64 + b (x) int64 5 -2 1 + >>> ds.pint.quantify({"b": "dm"}) + + Dimensions: (x: 3) + Coordinates: + * x (x) int64 0 1 2 + u (x) int64 + Data variables: + a (x) int64 + b (x) int64 + """ + units = either_dict_or_kwargs(units, unit_kwargs, ".quantify") registry = _get_registry(unit_registry, registry_kwargs) - if units is None: - units = {name: None for name in self.ds} + unit_attrs = conversion.extract_unit_attributes(self.ds) + new_obj = conversion.strip_unit_attributes(self.ds) units = { - name: _decide_units(units.get(name, None), registry, var.attrs) - for name, var in self.ds.data_vars.items() + name: _decide_units(unit, registry, attr) + for name, (unit, attr) in zip_mappings(units, unit_attrs).items() + if unit is not None or attr is not None } - quantified_vars = { - name: _quantify_variable(var, units[name]) - for name, var in self.ds.data_vars.items() - } + # TODO: remove once indexes support units + dim_units = {name: unit for name, unit in units.items() if name in new_obj.dims} + for name in dim_units.keys(): + units.pop(name) + new_obj = conversion.attach_unit_attributes(new_obj, dim_units) - # TODO should also quantify coordinates (once explicit indexes ready) - # TODO should we (temporarily) remove the attrs here so that they don't become inconsistent? - return Dataset( - data_vars=quantified_vars, coords=self.ds.coords, attrs=self.ds.attrs - ) + return conversion.attach_units(new_obj, units) def dequantify(self): - dequantified_vars = { - name: da.pint.to_base_units() for name, da in self.ds.items() - } - return Dataset(dequantified_vars, coords=self.ds.coords, attrs=self.ds.attrs) + """ + Removes units from the Dataset and its coordinates. + + Will replace ``.attrs['units']`` on each variable with a string + representation of the ``pint.Unit`` instance. + + Returns + ------- + dequantified : Dataset + Dataset whose data variables are unitless, and of the type + that was previously wrapped by ``pint.Quantity``. + """ + units = units_to_str_or_none(conversion.extract_units(self.ds)) + new_obj = conversion.attach_unit_attributes( + conversion.strip_units(self.ds), units + ) + return new_obj def to(self, units=None, **unit_kwargs): """ convert the quantities in a DataArray diff --git a/pint_xarray/conversion.py b/pint_xarray/conversion.py index 58252105..98465eb7 100644 --- a/pint_xarray/conversion.py +++ b/pint_xarray/conversion.py @@ -1,3 +1,5 @@ +import itertools + import pint from xarray import DataArray, Dataset, Variable @@ -130,6 +132,30 @@ def attach_units(obj, units, registry=None): return new_obj +def attach_unit_attributes(obj, units, attr="units"): + new_obj = obj.copy() + if isinstance(obj, DataArray): + for name, var in itertools.chain([(obj.name, new_obj)], new_obj.coords.items()): + unit = units.get(name) + if unit is None: + continue + + var.attrs[attr] = unit + elif isinstance(obj, Dataset): + for name, var in new_obj.variables.items(): + unit = units.get(name) + if unit is None: + continue + + var.attrs[attr] = unit + elif isinstance(obj, Variable): + new_obj.attrs[attr] = units.get(None) + else: + raise ValueError(f"cannot attach unit attributes to {obj!r}: unknown type") + + return new_obj + + def convert_units(obj, units): if not isinstance(units, dict): units = {None: units} @@ -196,6 +222,22 @@ def extract_units(obj): return units +def extract_unit_attributes(obj, attr="units"): + if isinstance(obj, DataArray): + variables = itertools.chain([(obj.name, obj)], obj.coords.items()) + units = {name: var.attrs.get(attr, None) for name, var in variables} + elif isinstance(obj, Dataset): + units = {name: var.attrs.get(attr, None) for name, var in obj.variables.items()} + elif isinstance(obj, Variable): + units = {None: obj.attrs.get(attr, None)} + else: + raise ValueError( + f"cannot retrieve unit attributes from unknown type: {type(obj)}" + ) + + return units + + def strip_units(obj): if isinstance(obj, Variable): data = array_strip_units(obj.data) @@ -220,3 +262,22 @@ def strip_units(obj): raise ValueError("cannot strip units from {obj!r}: unknown type") return new_obj + + +def strip_unit_attributes(obj, attr="units"): + new_obj = obj.copy() + if isinstance(obj, DataArray): + variables = itertools.chain([(new_obj.name, new_obj)], new_obj.coords.items()) + for _, var in variables: + var.attrs.pop(attr, None) + elif isinstance(obj, Dataset): + for var in new_obj.variables.values(): + var.attrs.pop(attr, None) + elif isinstance(obj, Variable): + new_obj.attrs.pop(attr, None) + else: + raise ValueError( + f"cannot retrieve unit attributes from unknown type: {type(obj)}" + ) + + return new_obj diff --git a/pint_xarray/tests/test_accessors.py b/pint_xarray/tests/test_accessors.py index e8e5b227..648059fe 100644 --- a/pint_xarray/tests/test_accessors.py +++ b/pint_xarray/tests/test_accessors.py @@ -2,33 +2,55 @@ import pytest import xarray as xr from numpy.testing import assert_array_equal -from pint import UnitRegistry +from pint import Unit, UnitRegistry from pint.errors import UndefinedUnitError from xarray.testing import assert_equal +from .. import conversion from .utils import raises_regex +pytestmark = [ + pytest.mark.filterwarnings("error::pint.UnitStrippedWarning"), +] + # make sure scalars are converted to 0d arrays so quantities can # always be treated like ndarrays unit_registry = UnitRegistry(force_ndarray=True) Quantity = unit_registry.Quantity -@pytest.fixture() +def assert_all_str_or_none(mapping): + __tracebackhide__ = True + + compared = { + key: isinstance(value, str) or value is None for key, value in mapping.items() + } + not_passing = {key: value for key, value in mapping.items() if not compared[key]} + check = all(compared.values()) + + assert check, f"Not all values are str or None: {not_passing}" + + +@pytest.fixture def example_unitless_da(): array = np.linspace(0, 10, 20) x = np.arange(20) - da = xr.DataArray(data=array, dims="x", coords={"x": x}) - da.attrs["units"] = "m" - da.coords["x"].attrs["units"] = "s" + u = np.linspace(0, 1, 20) + da = xr.DataArray( + data=array, + dims="x", + coords={"x": ("x", x), "u": ("x", u, {"units": "hour"})}, + attrs={"units": "m"}, + ) return da @pytest.fixture() def example_quantity_da(): array = np.linspace(0, 10, 20) * unit_registry.m - x = np.arange(20) * unit_registry.s - return xr.DataArray(data=array, dims="x", coords={"x": x}) + x = np.arange(20) + u = np.linspace(0, 1, 20) * unit_registry.hour + return xr.DataArray(data=array, dims="x", coords={"x": ("x", x), "u": ("x", u)}) class TestQuantifyDataArray: @@ -52,6 +74,9 @@ def test_attach_units_from_attrs(self, example_unitless_da): assert_array_equal(result.data.magnitude, orig.data) assert str(result.data.units) == "meter" + remaining_attrs = conversion.extract_unit_attributes(result) + assert {k: v for k, v in remaining_attrs.items() if v is not None} == {} + def test_attach_units_given_unit_objs(self, example_unitless_da): orig = example_unitless_da ureg = UnitRegistry(force_ndarray=True) @@ -81,14 +106,15 @@ def test_strip_units(self, example_quantity_da): assert isinstance(result.data, np.ndarray) assert isinstance(result.coords["x"].data, np.ndarray) - def test_error_if_no_units(self, example_unitless_da): - with raises_regex(ValueError, "does not have units"): - example_unitless_da.pint.dequantify() - def test_attrs_reinstated(self, example_quantity_da): da = example_quantity_da result = da.pint.dequantify() - assert result.attrs["units"] == "meter" + + units = conversion.extract_units(da) + attrs = conversion.extract_unit_attributes(result) + + assert units == attrs + assert_all_str_or_none(attrs) def test_roundtrip_data(self, example_unitless_da): orig = example_unitless_da @@ -162,6 +188,9 @@ def test_attach_units_from_attrs(self, example_unitless_ds): assert_array_equal(result["users"].data.magnitude, orig["users"].data) assert str(result["users"].data.units) == "dimensionless" + remaining_attrs = conversion.extract_unit_attributes(result) + assert {k: v for k, v in remaining_attrs.items() if v is not None} == {} + def test_attach_units_given_unit_objs(self, example_unitless_ds): orig = example_unitless_ds orig["users"].attrs.clear() @@ -172,7 +201,7 @@ def test_attach_units_given_unit_objs(self, example_unitless_ds): def test_error_when_already_units(self, example_quantity_ds): with raises_regex(ValueError, "already has units"): - example_quantity_ds.pint.quantify() + example_quantity_ds.pint.quantify({"funds": "pounds"}) def test_error_on_nonsense_units(self, example_unitless_ds): ds = example_unitless_ds @@ -180,9 +209,39 @@ def test_error_on_nonsense_units(self, example_unitless_ds): ds.pint.quantify(units={"users": "aecjhbav"}) -@pytest.mark.skip(reason="Not yet implemented") class TestDequantifyDataSet: - ... + def test_strip_units(self, example_quantity_ds): + result = example_quantity_ds.pint.dequantify() + + assert all( + isinstance(var.data, np.ndarray) for var in result.variables.values() + ) + + def test_attrs_reinstated(self, example_quantity_ds): + ds = example_quantity_ds + result = ds.pint.dequantify() + + units = conversion.extract_units(ds) + # workaround for Unit("dimensionless") != str(Unit("dimensionless")) + units = { + key: str(value) if isinstance(value, Unit) else value + for key, value in units.items() + } + + attrs = conversion.extract_unit_attributes(result) + + assert units == attrs + assert_all_str_or_none(attrs) + + def test_roundtrip_data(self, example_unitless_ds): + orig = example_unitless_ds + quantified = orig.pint.quantify() + + result = quantified.pint.dequantify() + assert_equal(result, orig) + + result = quantified.pint.dequantify().pint.quantify() + assert_equal(quantified, result) @pytest.mark.skip(reason="Not yet implemented") diff --git a/pint_xarray/tests/test_conversion.py b/pint_xarray/tests/test_conversion.py index e77366d2..f79dc405 100644 --- a/pint_xarray/tests/test_conversion.py +++ b/pint_xarray/tests/test_conversion.py @@ -12,6 +12,10 @@ pytestmark = pytest.mark.filterwarnings("error::pint.UnitStrippedWarning") +def filter_none_values(mapping): + return {k: v for k, v in mapping.items() if v is not None} + + class TestArrayFunctions: @pytest.mark.parametrize( "registry", @@ -176,6 +180,29 @@ def test_attach_units(self, obj, units): assert conversion.extract_units(actual) == units + @pytest.mark.parametrize( + ["obj", "units"], + ( + pytest.param( + DataArray(dims="x", coords={"x": [], "u": ("x", [])}), + {None: "hPa", "x": "m"}, + id="DataArray", + ), + pytest.param( + Dataset( + data_vars={"a": ("x", []), "b": ("x", [])}, + coords={"x": [], "u": ("x", [])}, + ), + {"a": "K", "b": "hPa", "u": "m"}, + id="Dataset", + ), + pytest.param(Variable("x", []), {None: "hPa"}, id="Variable",), + ), + ) + def test_attach_unit_attributes(self, obj, units): + actual = conversion.attach_unit_attributes(obj, units) + assert units == filter_none_values(conversion.extract_unit_attributes(actual)) + @pytest.mark.parametrize( "variant", ( @@ -309,6 +336,44 @@ def test_extract_units(self, typename, units): assert conversion.extract_units(obj) == units + @pytest.mark.parametrize( + ["obj", "expected"], + ( + pytest.param( + DataArray( + coords={ + "x": ("x", [], {"units": "m"}), + "u": ("x", [], {"units": "s"}), + }, + attrs={"units": "hPa"}, + dims="x", + ), + {"x": "m", "u": "s", None: "hPa"}, + id="DataArray", + ), + pytest.param( + Dataset( + data_vars={ + "a": ("x", [], {"units": "K"}), + "b": ("x", [], {"units": "hPa"}), + }, + coords={ + "x": ("x", [], {"units": "m"}), + "u": ("x", [], {"units": "s"}), + }, + ), + {"a": "K", "b": "hPa", "x": "m", "u": "s"}, + id="Dataset", + ), + pytest.param( + Variable("x", [], {"units": "hPa"}), {None: "hPa"}, id="Variable", + ), + ), + ) + def test_extract_unit_attributes(self, obj, expected): + actual = conversion.extract_unit_attributes(obj) + assert expected == actual + @pytest.mark.parametrize( "obj", ( @@ -344,3 +409,45 @@ def test_strip_units(self, obj): actual = conversion.strip_units(obj) assert conversion.extract_units(actual) == expected_units + + @pytest.mark.parametrize( + ["obj", "expected"], + ( + pytest.param( + DataArray( + coords={ + "x": ("x", [], {"units": "m"}), + "u": ("x", [], {"units": "s"}), + }, + attrs={"units": "hPa"}, + dims="x", + ), + {"x": "m", "u": "s", None: "hPa"}, + id="DataArray", + ), + pytest.param( + Dataset( + data_vars={ + "a": ("x", [], {"units": "K"}), + "b": ("x", [], {"units": "hPa"}), + }, + coords={ + "x": ("x", [], {"units": "m"}), + "u": ("x", [], {"units": "s"}), + }, + ), + {"a": "K", "b": "hPa", "x": "m", "u": "s"}, + id="Dataset", + ), + pytest.param( + Variable("x", [], {"units": "hPa"}), {None: "hPa"}, id="Variable", + ), + ), + ) + def test_strip_unit_attributes(self, obj, expected): + actual = conversion.strip_unit_attributes(obj) + expected = {} + + assert ( + filter_none_values(conversion.extract_unit_attributes(actual)) == expected + )