Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
Add transpose_coords option to DataArray.transpose (pydata#2556)
Browse files Browse the repository at this point in the history
* Add transpose_coords option to DataArray.transpose
Fixes pydata#1856

* Fix typo

* Fix bug in transpose
Fix python 2 compatibility

* Set default for transpose_coords to None
Update documentation

* Fix bug in coordinate tranpose
Update documentation

* Suppress FutureWarning in tests

* Add restore_coord_dims parameter to DataArrayGroupBy.apply

* Move restore_coord_dims parameter to GroupBy class

* Remove restore_coord_dims parameter from DataArrayResample.apply

* Update whats-new

* Update whats-new
  • Loading branch information
phausamann authored and shoyer committed May 21, 2019
1 parent 6658108 commit 0811141
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 33 deletions.
7 changes: 7 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ Enhancements
- Character arrays' character dimension name decoding and encoding handled by
``var.encoding['char_dim_name']`` (:issue:`2895`)
By `James McCreight <https://github.com/jmccreight>`_.
- :py:meth:`DataArray.transpose` now accepts a keyword argument
``transpose_coords`` which enables transposition of coordinates in the
same way as :py:meth:`Dataset.transpose`. :py:meth:`DataArray.groupby`
:py:meth:`DataArray.groupby_bins`, and :py:meth:`DataArray.resample` now
accept a keyword argument ``restore_coord_dims`` which keeps the order
of the dimensions of multi-dimensional coordinates intact (:issue:`1856`).
By `Peter Hausamann <http://github.com/phausamann>`_.
- Clean up Python 2 compatibility in code (:issue:`2950`)
By `Guido Imperiale <https://github.com/crusaderky>`_.
- Implement ``load_dataset()`` and ``load_dataarray()`` as alternatives to
Expand Down
27 changes: 21 additions & 6 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ def pipe(self, func: Union[Callable[..., T], Tuple[Callable[..., T], str]],
else:
return func(self, *args, **kwargs)

def groupby(self, group, squeeze: bool = True):
def groupby(self, group, squeeze: bool = True,
restore_coord_dims: Optional[bool] = None):
"""Returns a GroupBy object for performing grouped operations.
Parameters
Expand All @@ -453,6 +454,9 @@ def groupby(self, group, squeeze: bool = True):
If "group" is a dimension of any arrays in this dataset, `squeeze`
controls whether the subarrays have a dimension of length 1 along
that dimension or if the dimension is squeezed out.
restore_coord_dims : bool, optional
If True, also restore the dimension order of multi-dimensional
coordinates.
Returns
-------
Expand Down Expand Up @@ -485,11 +489,13 @@ def groupby(self, group, squeeze: bool = True):
core.groupby.DataArrayGroupBy
core.groupby.DatasetGroupBy
""" # noqa
return self._groupby_cls(self, group, squeeze=squeeze)
return self._groupby_cls(self, group, squeeze=squeeze,
restore_coord_dims=restore_coord_dims)

def groupby_bins(self, group, bins, right: bool = True, labels=None,
precision: int = 3, include_lowest: bool = False,
squeeze: bool = True):
squeeze: bool = True,
restore_coord_dims: Optional[bool] = None):
"""Returns a GroupBy object for performing grouped operations.
Rather than using all unique values of `group`, the values are discretized
Expand Down Expand Up @@ -522,6 +528,9 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None,
If "group" is a dimension of any arrays in this dataset, `squeeze`
controls whether the subarrays have a dimension of length 1 along
that dimension or if the dimension is squeezed out.
restore_coord_dims : bool, optional
If True, also restore the dimension order of multi-dimensional
coordinates.
Returns
-------
Expand All @@ -536,9 +545,11 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None,
.. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
""" # noqa
return self._groupby_cls(self, group, squeeze=squeeze, bins=bins,
restore_coord_dims=restore_coord_dims,
cut_kwargs={'right': right, 'labels': labels,
'precision': precision,
'include_lowest': include_lowest})
'include_lowest':
include_lowest})

def rolling(self, dim: Optional[Mapping[Hashable, int]] = None,
min_periods: Optional[int] = None, center: bool = False,
Expand Down Expand Up @@ -669,7 +680,7 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
skipna=None, closed: Optional[str] = None,
label: Optional[str] = None,
base: int = 0, keep_attrs: Optional[bool] = None,
loffset=None,
loffset=None, restore_coord_dims: Optional[bool] = None,
**indexer_kwargs: str):
"""Returns a Resample object for performing resampling operations.
Expand Down Expand Up @@ -697,6 +708,9 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
If True, the object's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
object will be returned without attributes.
restore_coord_dims : bool, optional
If True, also restore the dimension order of multi-dimensional
coordinates.
**indexer_kwargs : {dim: freq}
The keyword arguments form of ``indexer``.
One of indexer or indexer_kwargs must be provided.
Expand Down Expand Up @@ -786,7 +800,8 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
dims=dim_coord.dims, name=RESAMPLE_DIM)
resampler = self._resample_cls(self, group=group, dim=dim_name,
grouper=grouper,
resample_dim=RESAMPLE_DIM)
resample_dim=RESAMPLE_DIM,
restore_coord_dims=restore_coord_dims)

return resampler

Expand Down
26 changes: 24 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,14 +1405,16 @@ def unstack(self, dim=None):
ds = self._to_temp_dataset().unstack(dim)
return self._from_temp_dataset(ds)

def transpose(self, *dims) -> 'DataArray':
def transpose(self, *dims, transpose_coords=None) -> 'DataArray':
"""Return a new DataArray object with transposed dimensions.
Parameters
----------
*dims : str, optional
By default, reverse the dimensions. Otherwise, reorder the
dimensions to this order.
transpose_coords : boolean, optional
If True, also transpose the coordinates of this DataArray.
Returns
-------
Expand All @@ -1430,8 +1432,28 @@ def transpose(self, *dims) -> 'DataArray':
numpy.transpose
Dataset.transpose
"""
if dims:
if set(dims) ^ set(self.dims):
raise ValueError('arguments to transpose (%s) must be '
'permuted array dimensions (%s)'
% (dims, tuple(self.dims)))

variable = self.variable.transpose(*dims)
return self._replace(variable)
if transpose_coords:
coords = {}
for name, coord in self.coords.items():
coord_dims = tuple(dim for dim in dims if dim in coord.dims)
coords[name] = coord.variable.transpose(*coord_dims)
return self._replace(variable, coords)
else:
if transpose_coords is None \
and any(self[c].ndim > 1 for c in self.coords):
warnings.warn('This DataArray contains multi-dimensional '
'coordinates. In the future, these coordinates '
'will be transposed as well unless you specify '
'transpose_coords=False.',
FutureWarning, stacklevel=2)
return self._replace(variable)

@property
def T(self) -> 'DataArray':
Expand Down
25 changes: 20 additions & 5 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class GroupBy(SupportsArithmetic):
"""

def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
cut_kwargs={}):
restore_coord_dims=None, cut_kwargs={}):
"""Create a GroupBy object
Parameters
Expand All @@ -215,6 +215,9 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
bins : array-like, optional
If `bins` is specified, the groups will be discretized into the
specified bins by `pandas.cut`.
restore_coord_dims : bool, optional
If True, also restore the dimension order of multi-dimensional
coordinates.
cut_kwargs : dict, optional
Extra keyword arguments to pass to `pandas.cut`
Expand Down Expand Up @@ -279,6 +282,16 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
safe_cast_to_index(group), sort=(bins is None))
unique_coord = IndexVariable(group.name, unique_values)

if isinstance(obj, DataArray) \
and restore_coord_dims is None \
and any(obj[c].ndim > 1 for c in obj.coords):
warnings.warn('This DataArray contains multi-dimensional '
'coordinates. In the future, the dimension order '
'of these coordinates will be restored as well '
'unless you specify restore_coord_dims=False.',
FutureWarning, stacklevel=2)
restore_coord_dims = False

# specification for the groupby operation
self._obj = obj
self._group = group
Expand All @@ -288,6 +301,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
self._stacked_dim = stacked_dim
self._inserted_dims = inserted_dims
self._full_index = full_index
self._restore_coord_dims = restore_coord_dims

# cached attributes
self._groups = None
Expand Down Expand Up @@ -508,7 +522,8 @@ def lookup_order(dimension):
return axis

new_order = sorted(stacked.dims, key=lookup_order)
return stacked.transpose(*new_order)
return stacked.transpose(
*new_order, transpose_coords=self._restore_coord_dims)

def apply(self, func, shortcut=False, args=(), **kwargs):
"""Apply a function over each array in the group and concatenate them
Expand Down Expand Up @@ -558,7 +573,7 @@ def apply(self, func, shortcut=False, args=(), **kwargs):
for arr in grouped)
return self._combine(applied, shortcut=shortcut)

def _combine(self, applied, shortcut=False):
def _combine(self, applied, restore_coord_dims=False, shortcut=False):
"""Recombine the applied objects like the original."""
applied_example, applied = peek_at(applied)
coord, dim, positions = self._infer_concat_args(applied_example)
Expand All @@ -580,8 +595,8 @@ def _combine(self, applied, shortcut=False):
combined = self._maybe_unstack(combined)
return combined

def reduce(self, func, dim=None, axis=None,
keep_attrs=None, shortcut=True, **kwargs):
def reduce(self, func, dim=None, axis=None, keep_attrs=None,
shortcut=True, **kwargs):
"""Reduce the items in this group by applying `func` along some
dimension(s).
Expand Down
14 changes: 9 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def _infer_line_data(darray, x, y, hue):
if huename in darray.dims:
otherindex = 1 if darray.dims.index(huename) == 0 else 0
otherdim = darray.dims[otherindex]
yplt = darray.transpose(otherdim, huename)
xplt = xplt.transpose(otherdim, huename)
yplt = darray.transpose(
otherdim, huename, transpose_coords=False)
xplt = xplt.transpose(
otherdim, huename, transpose_coords=False)
else:
raise ValueError('For 2D inputs, hue must be a dimension'
+ ' i.e. one of ' + repr(darray.dims))
Expand All @@ -79,7 +81,9 @@ def _infer_line_data(darray, x, y, hue):
if yplt.ndim > 1:
if huename in darray.dims:
otherindex = 1 if darray.dims.index(huename) == 0 else 0
xplt = darray.transpose(otherdim, huename)
otherdim = darray.dims[otherindex]
xplt = darray.transpose(
otherdim, huename, transpose_coords=False)
else:
raise ValueError('For 2D inputs, hue must be a dimension'
+ ' i.e. one of ' + repr(darray.dims))
Expand Down Expand Up @@ -614,9 +618,9 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
yx_dims = (ylab, xlab)
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
if dims != darray.dims:
darray = darray.transpose(*dims)
darray = darray.transpose(*dims, transpose_coords=True)
elif darray[xlab].dims[-1] == darray.dims[0]:
darray = darray.transpose()
darray = darray.transpose(transpose_coords=True)

# Pass the data as a masked ndarray too
zval = darray.to_masked_array(copy=False)
Expand Down
57 changes: 48 additions & 9 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,14 +1681,14 @@ def test_math_with_coords(self):
assert_identical(expected, actual)

actual = orig[0, :] + orig[:, 0]
assert_identical(expected.T, actual)
assert_identical(expected.transpose(transpose_coords=True), actual)

actual = orig - orig.T
actual = orig - orig.transpose(transpose_coords=True)
expected = DataArray(np.zeros((2, 3)), orig.coords)
assert_identical(expected, actual)

actual = orig.T - orig
assert_identical(expected.T, actual)
actual = orig.transpose(transpose_coords=True) - orig
assert_identical(expected.transpose(transpose_coords=True), actual)

alt = DataArray([1, 1], {'x': [-1, -2], 'c': 'foo', 'd': 555}, 'x')
actual = orig + alt
Expand Down Expand Up @@ -1801,8 +1801,27 @@ def test_stack_nonunique_consistency(self):
assert_identical(expected, actual)

def test_transpose(self):
assert_equal(self.dv.variable.transpose(),
self.dv.transpose().variable)
da = DataArray(np.random.randn(3, 4, 5), dims=('x', 'y', 'z'),
coords={'x': range(3), 'y': range(4), 'z': range(5),
'xy': (('x', 'y'), np.random.randn(3, 4))})

actual = da.transpose(transpose_coords=False)
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
coords=da.coords)
assert_equal(expected, actual)

actual = da.transpose('z', 'y', 'x', transpose_coords=True)
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
coords={'x': da.x.values, 'y': da.y.values,
'z': da.z.values,
'xy': (('y', 'x'), da.xy.values.T)})
assert_equal(expected, actual)

with pytest.raises(ValueError):
da.transpose('x', 'y')

with pytest.warns(FutureWarning):
da.transpose()

def test_squeeze(self):
assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)
Expand Down Expand Up @@ -2258,6 +2277,23 @@ def test_groupby_restore_dim_order(self):
result = array.groupby(by).apply(lambda x: x.squeeze())
assert result.dims == expected_dims

def test_groupby_restore_coord_dims(self):
array = DataArray(np.random.randn(5, 3),
coords={'a': ('x', range(5)), 'b': ('y', range(3)),
'c': (('x', 'y'), np.random.randn(5, 3))},
dims=['x', 'y'])

for by, expected_dims in [('x', ('x', 'y')),
('y', ('x', 'y')),
('a', ('a', 'y')),
('b', ('x', 'b'))]:
result = array.groupby(by, restore_coord_dims=True).apply(
lambda x: x.squeeze())['c']
assert result.dims == expected_dims

with pytest.warns(FutureWarning):
array.groupby('x').apply(lambda x: x.squeeze())

def test_groupby_first_and_last(self):
array = DataArray([1, 2, 3, 4, 5], dims='x')
by = DataArray(['a'] * 2 + ['b'] * 3, dims='x', name='ab')
Expand Down Expand Up @@ -2445,15 +2481,18 @@ def test_resample_drop_nondim_coords(self):
array = ds['data']

# Re-sample
actual = array.resample(time="12H").mean('time')
actual = array.resample(
time="12H", restore_coord_dims=True).mean('time')
assert 'tc' not in actual.coords

# Up-sample - filling
actual = array.resample(time="1H").ffill()
actual = array.resample(
time="1H", restore_coord_dims=True).ffill()
assert 'tc' not in actual.coords

# Up-sample - interpolation
actual = array.resample(time="1H").interpolate('linear')
actual = array.resample(
time="1H", restore_coord_dims=True).interpolate('linear')
assert 'tc' not in actual.coords

def test_resample_keep_attrs(self):
Expand Down
12 changes: 9 additions & 3 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4062,14 +4062,20 @@ def test_dataset_math_errors(self):

def test_dataset_transpose(self):
ds = Dataset({'a': (('x', 'y'), np.random.randn(3, 4)),
'b': (('y', 'x'), np.random.randn(4, 3))})
'b': (('y', 'x'), np.random.randn(4, 3))},
coords={'x': range(3), 'y': range(4),
'xy': (('x', 'y'), np.random.randn(3, 4))})

actual = ds.transpose()
expected = ds.apply(lambda x: x.transpose())
expected = Dataset({'a': (('y', 'x'), ds.a.values.T),
'b': (('x', 'y'), ds.b.values.T)},
coords={'x': ds.x.values, 'y': ds.y.values,
'xy': (('y', 'x'), ds.xy.values.T)})
assert_identical(expected, actual)

actual = ds.transpose('x', 'y')
expected = ds.apply(lambda x: x.transpose('x', 'y'))
expected = ds.apply(
lambda x: x.transpose('x', 'y', transpose_coords=True))
assert_identical(expected, actual)

ds = create_test_data()
Expand Down
6 changes: 4 additions & 2 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def func(obj, dim, new_x):
'y': da['y'],
'x': ('z', xdest.values),
'x2': ('z', func(da['x2'], 'x', xdest))})
assert_allclose(actual, expected.transpose('z', 'y'))
assert_allclose(actual,
expected.transpose('z', 'y', transpose_coords=True))

# xdest is 2d
xdest = xr.DataArray(np.linspace(0.1, 0.9, 30).reshape(6, 5),
Expand All @@ -160,7 +161,8 @@ def func(obj, dim, new_x):
coords={'z': xdest['z'], 'w': xdest['w'], 'z2': xdest['z2'],
'y': da['y'], 'x': (('z', 'w'), xdest),
'x2': (('z', 'w'), func(da['x2'], 'x', xdest))})
assert_allclose(actual, expected.transpose('z', 'w', 'y'))
assert_allclose(actual,
expected.transpose('z', 'w', 'y', transpose_coords=True))


@pytest.mark.parametrize('case', [3, 4])
Expand Down
Loading

0 comments on commit 0811141

Please sign in to comment.