From 39f610fcbb80c783d1911a330f5b72cfca922549 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 18 May 2020 20:16:33 +0200 Subject: [PATCH 1/5] add tests --- xarray/tests/test_weighted.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 24531215dfb..1bf685cc95d 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -59,6 +59,18 @@ def test_weighted_sum_of_weights_nan(weights, expected): assert_equal(expected, result) +def test_weighted_sum_of_weights_bool(): + # https://github.com/pydata/xarray/issues/4074 + + da = DataArray([1, 2]) + weights = DataArray([True, True]) + result = da.weighted(weights).sum_of_weights() + + expected = DataArray(2) + + assert_equal(expected, result) + + @pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize("factor", [0, 1, 3.14]) @pytest.mark.parametrize("skipna", (True, False)) @@ -158,6 +170,17 @@ def test_weighted_mean_nan(weights, expected, skipna): assert_equal(expected, result) +def test_weighted_mean_bool(): + # https://github.com/pydata/xarray/issues/4074 + da = DataArray([1, 1]) + weights = DataArray([True, True]) + expected = DataArray(1) + + result = da.weighted(weights).mean() + + assert_equal(expected, result) + + def expected_weighted(da, weights, dim, skipna, operation): """ Generate expected result using ``*`` and ``sum``. This is checked against From 44d8284ba2f173b56d920233ff2a9c8a4a6fc118 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 18 May 2020 20:40:04 +0200 Subject: [PATCH 2/5] weights: bool -> int --- xarray/core/weighted.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 996d2e4c43e..9f534c7dd4c 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -142,7 +142,13 @@ def _sum_of_weights( # we need to mask data values that are nan; else the weights are wrong mask = da.notnull() - sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) + weights = self.weights + # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True + # (and not 2) GH4074 + if weights.dtype == bool: + weights = weights.astype(int) + + sum_of_weights = self._reduce(mask, weights, dim=dim, skipna=False) # 0-weights are not valid valid_weights = sum_of_weights != 0.0 From 4ca6fbe73edf6a223cafca961ef347c0bf8f3795 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 18 May 2020 20:40:23 +0200 Subject: [PATCH 3/5] whats new --- doc/whats-new.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cd30fab0160..9a62a4ff42c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -110,6 +110,8 @@ Bug fixes - Fix bug in time parsing failing to fall back to cftime. This was causing time variables with a time unit of `'msecs'` to fail to parse. (:pull:`3998`) By `Ryan May `_. +- Fix weighted mean when passing boolean weights (:issue:`4074`). + By `Mathias Hauser `_ Documentation ~~~~~~~~~~~~~ @@ -175,7 +177,7 @@ New Features - Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` and :py:meth:`Dataset.weighted` methods. See :ref:`comput.weighted`. (:issue:`422`, :pull:`2922`). - By `Mathias Hauser `_ + By `Mathias Hauser `_. - The new jupyter notebook repr (``Dataset._repr_html_`` and ``DataArray._repr_html_``) (introduced in 0.14.1) is now on by default. To disable, use ``xarray.set_options(display_style="text")``. From a09328ac0df3845590191854eedfa925b168ecd5 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 18 May 2020 20:43:03 +0200 Subject: [PATCH 4/5] Apply suggestions from code review --- doc/whats-new.rst | 2 +- xarray/core/weighted.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9a62a4ff42c..f0fede28ebb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -111,7 +111,7 @@ Bug fixes variables with a time unit of `'msecs'` to fail to parse. (:pull:`3998`) By `Ryan May `_. - Fix weighted mean when passing boolean weights (:issue:`4074`). - By `Mathias Hauser `_ + By `Mathias Hauser `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 9f534c7dd4c..19718f5f575 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -144,7 +144,7 @@ def _sum_of_weights( weights = self.weights # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True - # (and not 2) GH4074 + # (and not 2); GH4074 if weights.dtype == bool: weights = weights.astype(int) From f65e2e033099e640af81391f7862808ccd20de40 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 18 May 2020 21:12:40 +0200 Subject: [PATCH 5/5] avoid unecessary copy --- xarray/core/weighted.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 19718f5f575..21ed06ea85f 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -142,13 +142,14 @@ def _sum_of_weights( # we need to mask data values that are nan; else the weights are wrong mask = da.notnull() - weights = self.weights # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True # (and not 2); GH4074 - if weights.dtype == bool: - weights = weights.astype(int) - - sum_of_weights = self._reduce(mask, weights, dim=dim, skipna=False) + if self.weights.dtype == bool: + sum_of_weights = self._reduce( + mask, self.weights.astype(int), dim=dim, skipna=False + ) + else: + sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) # 0-weights are not valid valid_weights = sum_of_weights != 0.0