Skip to content

Commit

Permalink
initial fix for #304
Browse files Browse the repository at this point in the history
  • Loading branch information
pochedls committed Aug 18, 2022
1 parent afd7a82 commit 428675f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def test_weighted_monthly_averages_with_masked_data(self):
expected = expected.drop_dims("time")
expected["ts"] = xr.DataArray(
name="ts",
data=np.array([[[2.0]], [[0.0]], [[1.0]], [[1.0]], [[2.0]]]),
data=np.array([[[2.0]], [[np.nan]], [[1.0]], [[1.0]], [[2.0]]]),
coords={
"lat": expected.lat,
"lon": expected.lon,
Expand Down
24 changes: 21 additions & 3 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,10 +919,28 @@ def _group_average(self, data_var: xr.DataArray) -> xr.DataArray:

if self._weighted:
self._weights = self._get_weights()
# grab metadata
dv_attrs = dv.attrs
dv_name = dv.name
# grab metadata (populated after calling ._group_data())
lt_attrs = self._labeled_time.attrs
lt_encoding = self._labeled_time.encoding
# weight the data variable
dv *= self._weights
dv = self._group_data(dv).sum()
# cast the weights to match the dv.shape / dims
weights, x = xr.broadcast(self._weights, dv)
# ensure missing data receives no weight
weights = xr.where(np.isnan(dv), 0.0, weights)
# perform weighted average
dv = self._group_data(dv).sum() / self._group_data(weights).sum()
# add dv attributes
dv.attrs = dv_attrs
dv.name = dv_name
else:
dv = self._group_data(dv).mean()
# grab metadata (populated after calling ._group_data())
lt_attrs = self._labeled_time.attrs
lt_encoding = self._labeled_time.encoding

# After grouping and aggregating the data variable values, the
# original time dimension is replaced with the grouped time dimension.
Expand All @@ -936,8 +954,8 @@ def _group_average(self, data_var: xr.DataArray) -> xr.DataArray:
# attributes are removed. Xarray's `keep_attrs=True` option only keeps
# attributes for data variables and not their coordinates, so the
# coordinate attributes have to be restored manually.
dv[self._dim].attrs = self._labeled_time.attrs
dv[self._dim].encoding = self._labeled_time.encoding
dv[self._dim].attrs = lt_attrs
dv[self._dim].encoding = lt_encoding

dv = self._add_operation_attrs(dv)

Expand Down

0 comments on commit 428675f

Please sign in to comment.