diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 447aaf5b0bf..a4602c1edad 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -119,6 +119,8 @@ Bug fixes - Fix bug in time parsing failing to fall back to cftime. This was causing time variables with a time unit of `'msecs'` to fail to parse. (:pull:`3998`) By `Ryan May <https://github.com/dopplershift>`_. +- Fix weighted mean when passing boolean weights (:issue:`4074`). + By `Mathias Hauser <https://github.com/mathause>`_. - Fix html repr in untrusted notebooks: fallback to plain text repr. (:pull:`4053`) By `Benoit Bovy <https://github.com/benbovy>`_. @@ -186,7 +188,7 @@ New Features - Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` and :py:meth:`Dataset.weighted` methods. See :ref:`comput.weighted`. (:issue:`422`, :pull:`2922`). - By `Mathias Hauser <https://github.com/mathause>`_ + By `Mathias Hauser <https://github.com/mathause>`_. - The new jupyter notebook repr (``Dataset._repr_html_`` and ``DataArray._repr_html_``) (introduced in 0.14.1) is now on by default. To disable, use ``xarray.set_options(display_style="text")``. diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 996d2e4c43e..21ed06ea85f 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -142,7 +142,14 @@ def _sum_of_weights( # we need to mask data values that are nan; else the weights are wrong mask = da.notnull() - sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) + # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True + # (and not 2); GH4074 + if self.weights.dtype == bool: + sum_of_weights = self._reduce( + mask, self.weights.astype(int), dim=dim, skipna=False + ) + else: + sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) # 0-weights are not valid valid_weights = sum_of_weights != 0.0 diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 24531215dfb..1bf685cc95d 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -59,6 +59,18 @@ def test_weighted_sum_of_weights_nan(weights, expected): 
assert_equal(expected, result) +def test_weighted_sum_of_weights_bool(): + # https://github.com/pydata/xarray/issues/4074 + + da = DataArray([1, 2]) + weights = DataArray([True, True]) + result = da.weighted(weights).sum_of_weights() + + expected = DataArray(2) + + assert_equal(expected, result) + + @pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize("factor", [0, 1, 3.14]) @pytest.mark.parametrize("skipna", (True, False)) @@ -158,6 +170,17 @@ def test_weighted_mean_nan(weights, expected, skipna): assert_equal(expected, result) +def test_weighted_mean_bool(): + # https://github.com/pydata/xarray/issues/4074 + da = DataArray([1, 1]) + weights = DataArray([True, True]) + expected = DataArray(1) + + result = da.weighted(weights).mean() + + assert_equal(expected, result) + + def expected_weighted(da, weights, dim, skipna, operation): """ Generate expected result using ``*`` and ``sum``. This is checked against