Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transpose_coords option to DataArray.transpose #2556

Merged
merged 20 commits into from
May 21, 2019
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ Breaking changes
Enhancements
~~~~~~~~~~~~

- :py:meth:`DataArray.transpose` now accepts a keyword argument
``transpose_coords`` which enables transposition of coordinates in the
same way as :py:meth:`Dataset.transpose`.
By `Peter Hausamann <http://github.com/phausamann>`_.

Bug fixes
~~~~~~~~~

Expand Down
27 changes: 25 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,14 +1357,16 @@ def unstack(self, dim=None):
ds = self._to_temp_dataset().unstack(dim)
return self._from_temp_dataset(ds)

def transpose(self, *dims):
def transpose(self, *dims, **kwargs):
"""Return a new DataArray object with transposed dimensions.

Parameters
----------
*dims : str, optional
By default, reverse the dimensions. Otherwise, reorder the
dimensions to this order.
transpose_coords : boolean, optional
If True, also transpose the coordinates of this DataArray.

Returns
-------
Expand All @@ -1381,8 +1383,29 @@ def transpose(self, *dims):
numpy.transpose
Dataset.transpose
"""
if dims:
if set(dims) ^ set(self.dims):
raise ValueError('arguments to transpose (%s) must be '
'permuted array dimensions (%s)'
% (dims, tuple(self.dims)))

transpose_coords = kwargs.pop('transpose_coords', None)
phausamann marked this conversation as resolved.
Show resolved Hide resolved
variable = self.variable.transpose(*dims)
return self._replace(variable)
if transpose_coords:
coords = {}
for name, coord in iteritems(self.coords):
coord_dims = tuple(dim for dim in dims if dim in coord.dims)
coords[name] = coord.variable.transpose(*coord_dims)
return self._replace(variable, coords)
else:
if transpose_coords is None \
and any(self[c].ndim > 1 for c in self.coords):
warnings.warn('This DataArray contains multi-dimensional '
'coordinates. In the future, these coordinates '
'will be transposed as well unless you specify '
'transpose_coords=False.',
FutureWarning, stacklevel=2)
return self._replace(variable)

def drop(self, labels, dim=None):
"""Drop coordinates or index labels from this DataArray.
Expand Down
20 changes: 18 additions & 2 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,8 +1709,24 @@ def test_stack_nonunique_consistency(self):
assert_identical(expected, actual)

def test_transpose(self):
assert_equal(self.dv.variable.transpose(),
self.dv.transpose().variable)
da = DataArray(np.random.randn(3, 4, 5), dims=('x', 'y', 'z'),
coords={'x': range(3), 'y': range(4), 'z': range(5),
'xy': (('x', 'y'), np.random.randn(3, 4))})

actual = da.transpose(transpose_coords=False)
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
coords=da.coords)
assert_equal(expected, actual)

actual = da.transpose('z', 'y', 'x', transpose_coords=True)
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
coords={'x': da.x.values, 'y': da.y.values,
'z': da.z.values,
'xy': (('y', 'x'), da.xy.values.T)})
assert_equal(expected, actual)

with pytest.raises(ValueError):
da.transpose('x', 'y')

def test_squeeze(self):
assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)
Expand Down
9 changes: 7 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3799,10 +3799,15 @@ def test_dataset_math_errors(self):

def test_dataset_transpose(self):
ds = Dataset({'a': (('x', 'y'), np.random.randn(3, 4)),
'b': (('y', 'x'), np.random.randn(4, 3))})
'b': (('y', 'x'), np.random.randn(4, 3))},
coords={'x': range(3), 'y': range(4),
'xy': (('x', 'y'), np.random.randn(3, 4))})

actual = ds.transpose()
expected = ds.apply(lambda x: x.transpose())
expected = Dataset({'a': (('y', 'x'), ds.a.values.T),
'b': (('x', 'y'), ds.b.values.T)},
coords={'x': ds.x.values, 'y': ds.y.values,
'xy': (('y', 'x'), ds.xy.values.T)})
assert_identical(expected, actual)

actual = ds.transpose('x', 'y')
Expand Down