Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transpose_coords option to DataArray.transpose #2556

Merged
merged 20 commits into from
May 21, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ Breaking changes
Enhancements
~~~~~~~~~~~~

- :py:meth:`DataArray.transpose` now accepts a keyword argument
``transpose_coords`` which enables transposition of coordinates in the
same way as :py:meth:`Dataset.transpose`.
By `Peter Hausamann <http://github.com/phausamann>`_.

Bug fixes
~~~~~~~~~

Expand Down
30 changes: 28 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,14 +1357,16 @@ def unstack(self, dim=None):
ds = self._to_temp_dataset().unstack(dim)
return self._from_temp_dataset(ds)

def transpose(self, *dims):
def transpose(self, *dims, **kwargs):
"""Return a new DataArray object with transposed dimensions.

Parameters
----------
*dims : str, optional
By default, reverse the dimensions. Otherwise, reorder the
dimensions to this order.
transpose_coords : boolean, optional
If True, also transpose the coordinates of this DataArray.

Returns
-------
Expand All @@ -1381,8 +1383,32 @@ def transpose(self, *dims):
numpy.transpose
Dataset.transpose
"""
if dims:
if set(dims) ^ set(self.dims):
raise ValueError('arguments to transpose (%s) must be '
'permuted array dimensions (%s)'
% (dims, tuple(self.dims)))

transpose_coords = kwargs.pop('transpose_coords', None)
phausamann marked this conversation as resolved.
Show resolved Hide resolved
if kwargs:
raise ValueError(
'Invalid keyword arguments: %s' % tuple(k for k in kwargs))
variable = self.variable.transpose(*dims)
return self._replace(variable)
if transpose_coords:
coords = {}
for name, coord in iteritems(self.coords):
coord_dims = tuple(dim for dim in dims if dim in coord.dims)
coords[name] = coord.variable.transpose(*coord_dims)
return self._replace(variable, coords)
else:
if transpose_coords is None \
and any(self[c].ndim > 1 for c in self.coords):
warnings.warn('This DataArray contains multi-dimensional '
'coordinates. In the future, these coordinates '
'will be transposed as well unless you specify '
'transpose_coords=False.',
FutureWarning, stacklevel=2)
return self._replace(variable)

def drop(self, labels, dim=None):
"""Drop coordinates or index labels from this DataArray.
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def lookup_order(dimension):
return axis

new_order = sorted(stacked.dims, key=lookup_order)
return stacked.transpose(*new_order)
return stacked.transpose(*new_order, transpose_coords=True)
phausamann marked this conversation as resolved.
Show resolved Hide resolved

def apply(self, func, shortcut=False, **kwargs):
"""Apply a function over each array in the group and concatenate them
Expand Down
13 changes: 8 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ def _infer_line_data(darray, x, y, hue):
if huename in darray.dims:
otherindex = 1 if darray.dims.index(huename) == 0 else 0
otherdim = darray.dims[otherindex]
yplt = darray.transpose(otherdim, huename)
xplt = xplt.transpose(otherdim, huename)
yplt = darray.transpose(
otherdim, huename, transpose_coords=True)
xplt = xplt.transpose(
otherdim, huename, transpose_coords=True)
else:
raise ValueError('For 2D inputs, hue must be a dimension'
+ ' i.e. one of ' + repr(darray.dims))
Expand All @@ -244,13 +246,14 @@ def _infer_line_data(darray, x, y, hue):
if yplt.ndim > 1:
if huename in darray.dims:
otherindex = 1 if darray.dims.index(huename) == 0 else 0
xplt = darray.transpose(otherdim, huename)
xplt = darray.transpose(
otherdim, huename, transpose_coords=True)
else:
raise ValueError('For 2D inputs, hue must be a dimension'
+ ' i.e. one of ' + repr(darray.dims))

else:
xplt = darray.transpose(yname, huename)
xplt = darray.transpose(yname, huename, transpose_coords=True)

huelabel = label_from_attrs(darray[huename])
hueplt = darray[huename]
Expand Down Expand Up @@ -798,7 +801,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
yx_dims = (ylab, xlab)
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
if dims != darray.dims:
darray = darray.transpose(*dims)
darray = darray.transpose(*dims, transpose_coords=True)
elif darray[xlab].dims[-1] == darray.dims[0]:
darray = darray.transpose()

Expand Down
34 changes: 28 additions & 6 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,14 +1589,14 @@ def test_math_with_coords(self):
assert_identical(expected, actual)

actual = orig[0, :] + orig[:, 0]
assert_identical(expected.T, actual)
assert_identical(expected.transpose(transpose_coords=True), actual)

actual = orig - orig.T
actual = orig - orig.transpose(transpose_coords=True)
expected = DataArray(np.zeros((2, 3)), orig.coords)
assert_identical(expected, actual)

actual = orig.T - orig
assert_identical(expected.T, actual)
actual = orig.transpose(transpose_coords=True) - orig
assert_identical(expected.transpose(transpose_coords=True), actual)

alt = DataArray([1, 1], {'x': [-1, -2], 'c': 'foo', 'd': 555}, 'x')
actual = orig + alt
Expand Down Expand Up @@ -1709,8 +1709,30 @@ def test_stack_nonunique_consistency(self):
assert_identical(expected, actual)

def test_transpose(self):
assert_equal(self.dv.variable.transpose(),
self.dv.transpose().variable)
da = DataArray(np.random.randn(3, 4, 5), dims=('x', 'y', 'z'),
coords={'x': range(3), 'y': range(4), 'z': range(5),
'xy': (('x', 'y'), np.random.randn(3, 4))})

actual = da.transpose(transpose_coords=False)
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
coords=da.coords)
assert_equal(expected, actual)

actual = da.transpose('z', 'y', 'x', transpose_coords=True)
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
coords={'x': da.x.values, 'y': da.y.values,
'z': da.z.values,
'xy': (('y', 'x'), da.xy.values.T)})
assert_equal(expected, actual)

with pytest.raises(ValueError):
da.transpose('x', 'y')

with pytest.raises(ValueError):
da.transpose(invalid_kwarg=None)

with pytest.warns(FutureWarning):
da.transpose()

def test_squeeze(self):
assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)
Expand Down
12 changes: 9 additions & 3 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3799,14 +3799,20 @@ def test_dataset_math_errors(self):

def test_dataset_transpose(self):
ds = Dataset({'a': (('x', 'y'), np.random.randn(3, 4)),
'b': (('y', 'x'), np.random.randn(4, 3))})
'b': (('y', 'x'), np.random.randn(4, 3))},
coords={'x': range(3), 'y': range(4),
'xy': (('x', 'y'), np.random.randn(3, 4))})

actual = ds.transpose()
expected = ds.apply(lambda x: x.transpose())
expected = Dataset({'a': (('y', 'x'), ds.a.values.T),
'b': (('x', 'y'), ds.b.values.T)},
coords={'x': ds.x.values, 'y': ds.y.values,
'xy': (('y', 'x'), ds.xy.values.T)})
assert_identical(expected, actual)

actual = ds.transpose('x', 'y')
expected = ds.apply(lambda x: x.transpose('x', 'y'))
expected = ds.apply(
lambda x: x.transpose('x', 'y', transpose_coords=True))
assert_identical(expected, actual)

ds = create_test_data()
Expand Down
6 changes: 4 additions & 2 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def func(obj, dim, new_x):
'y': da['y'],
'x': ('z', xdest.values),
'x2': ('z', func(da['x2'], 'x', xdest))})
assert_allclose(actual, expected.transpose('z', 'y'))
assert_allclose(actual,
expected.transpose('z', 'y', transpose_coords=True))

# xdest is 2d
xdest = xr.DataArray(np.linspace(0.1, 0.9, 30).reshape(6, 5),
Expand All @@ -162,7 +163,8 @@ def func(obj, dim, new_x):
coords={'z': xdest['z'], 'w': xdest['w'], 'z2': xdest['z2'],
'y': da['y'], 'x': (('z', 'w'), xdest),
'x2': (('z', 'w'), func(da['x2'], 'x', xdest))})
assert_allclose(actual, expected.transpose('z', 'w', 'y'))
assert_allclose(actual,
expected.transpose('z', 'w', 'y', transpose_coords=True))


@pytest.mark.parametrize('case', [3, 4])
Expand Down