Skip to content

Commit

Permalink
Fix name loss when masking (#2749)
Browse files Browse the repository at this point in the history
* fix renaming

* formatting

* added tests

* shoyer's solution

* what's new
  • Loading branch information
yohai authored and shoyer committed Feb 11, 2019
1 parent 4cd56a9 commit 07cfc5a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
4 changes: 3 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ Bug fixes
- Fix ``open_rasterio`` creating a WKT CRS instead of PROJ.4 with
``rasterio`` 1.0.14+ (:issue:`2715`).
By `David Hoese <https://github.com/djhoese`_.

- Masking data arrays with :py:meth:`xarray.DataArray.where` now returns an
array with the name of the original masked array (:issue:`2748` and :issue:`2457`).
By `Yohai Bar-Sinai <https://github.com/yohai>`_.
.. _whats-new.0.11.3:

v0.11.3 (26 January 2019)
Expand Down
9 changes: 7 additions & 2 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
signature = kwargs.pop('signature')
join = kwargs.pop('join', 'inner')
exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET)
keep_attrs = kwargs.pop('keep_attrs', True)
if kwargs:
raise TypeError('apply_dataarray_ufunc() got unexpected keyword '
'arguments: %s' % list(kwargs))
Expand All @@ -207,7 +208,10 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
args = deep_align(args, join=join, copy=False, exclude=exclude_dims,
raise_on_invalid=False)

name = result_name(args)
if keep_attrs and hasattr(args[0], 'name'):
name = args[0].name
else:
name = result_name(args)
result_coords = build_output_coords(args, signature, exclude_dims)

data_vars = [getattr(a, 'variable', a) for a in args]
Expand Down Expand Up @@ -985,7 +989,8 @@ def earth_mover_distance(first_samples,
return apply_dataarray_ufunc(variables_ufunc, *args,
signature=signature,
join=join,
exclude_dims=exclude_dims)
exclude_dims=exclude_dims,
keep_attrs=keep_attrs)
elif any(isinstance(a, Variable) for a in args):
return variables_ufunc(*args)
else:
Expand Down
9 changes: 9 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3694,6 +3694,15 @@ def test_raise_no_warning_for_nan_in_binary_ops():
assert len(record) == 0


def test_name_in_masking():
name = 'RingoStarr'
da = xr.DataArray(range(10), coords=[('x', range(10))], name=name)
assert da.where(da > 5).name == name
assert da.where((da > 5).rename('YokoOno')).name == name
assert da.where(da > 5, drop=True).name == name
assert da.where((da > 5).rename('YokoOno'), drop=True).name == name


class TestIrisConversion(object):
@requires_iris
def test_to_and_from_iris(self):
Expand Down

0 comments on commit 07cfc5a

Please sign in to comment.