Skip to content

Commit

Permalink
Allow expand_dims() method to support inserting/broadcasting dimensio…
Browse files Browse the repository at this point in the history
…ns with size>1 (pydata#2757)

 * dataset.expand_dims() method take dict like object where values represent length of dimensions or coordinates of dimesnsions

 * dataarray.expand_dims() method take dict like object where values represent length of dimensions or coordinates of dimesnsions

 * Add alternative option to passing a dict to the dim argument, which is now an optional kwarg, passing in each new dimension as its own kwarg

 * Add expand_dims enhancement from issue 2710 to whats-new.rst

 * Fix test_dataarray.TestDataArray.test_expand_dims_with_greater_dim_size tests to pass in python 3.5 using ordered dicts instead of regular dicts. This was needed because python 3.5 and earlier did not maintain insertion order for dicts

 * Restrict core logic to use 'dim' as a dict--it will be converted into a dict on entry if it is a str or a sequence of str

 * Don't cast dim values (coords) as a list since IndexVariable/Variable will internally convert it into a numpy.ndarray. So just use IndexVariable((k,), v)

 * TypeErrors should be raised for invalid input types, rather than ValueErrors.

 * Force 'dim' to be OrderedDict for python 3.5
  • Loading branch information
pletchm committed Feb 20, 2019
1 parent 612d390 commit f832672
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 19 deletions.
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ Enhancements
be outside the :py:class:`pandas.Timestamp`-valid range (:issue:`2754`). By
`Spencer Clark <https://github.com/spencerkclark>`_.

- Allow ``expand_dims`` method to support inserting/broadcasting dimensions
with size > 1. (:issue:`2710`)
By `Martin Pletcher <https://github.com/pletchm>`_.


Bug fixes
~~~~~~~~~

Expand Down
40 changes: 37 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import sys
import warnings
from collections import OrderedDict

Expand Down Expand Up @@ -1138,7 +1139,7 @@ def swap_dims(self, dims_dict):
ds = self._to_temp_dataset().swap_dims(dims_dict)
return self._from_temp_dataset(ds)

def expand_dims(self, dim, axis=None):
def expand_dims(self, dim=None, axis=None, **dim_kwargs):
"""Return a new object with an additional axis (or axes) inserted at
the corresponding position in the array shape.
Expand All @@ -1147,21 +1148,54 @@ def expand_dims(self, dim, axis=None):
Parameters
----------
dim : str or sequence of str.
dim : str, sequence of str, dict, or None
Dimensions to include on the new variable.
dimensions are inserted with length 1.
If provided as str or sequence of str, then dimensions are inserted
with length 1. If provided as a dict, then the keys are the new
dimensions and the values are either integers (giving the length of
the new dimensions) or sequence/ndarray (giving the coordinates of
the new dimensions). **WARNING** for python 3.5, if ``dim`` is
dict-like, then it must be an ``OrderedDict`` to ensure that the
order in which the dims are given is maintained.
axis : integer, list (or tuple) of integers, or None
Axis position(s) where new axis is to be inserted (position(s) on
the result array). If a list (or tuple) of integers is passed,
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
**dim_kwargs : int or sequence/ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
coordinates. Note, this is an alternative to passing a dict to the
dim kwarg and will only be used if dim is None. **WARNING** in
python 3.5, the order in which the dim kwargs are given will not be
maintained.
Returns
-------
expanded : same type as caller
This object, but with an additional dimension(s).
"""
if isinstance(dim, int):
raise TypeError('dim should be str or sequence of strs or dict')
elif isinstance(dim, str):
dim = OrderedDict(((dim, 1),))
elif isinstance(dim, (list, tuple)):
if len(dim) != len(set(dim)):
raise ValueError('dims should not contain duplicate values.')
dim = OrderedDict(((d, 1) for d in dim))

# TODO: get rid of the below code block when python 3.5 is no longer
# supported.
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
not_ordereddict = dim is not None and not isinstance(dim, OrderedDict)
if not python36_plus and not_ordereddict:
raise TypeError("dim must be an OrderedDict for python <3.6")
# For python 3.5 dim, must be an OrderedDict. This allows kwargs to be
# used for python 3.5, but insertion order will not be maintained
dim_kwargs = OrderedDict(dim_kwargs)

dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims')
ds = self._to_temp_dataset().expand_dims(dim, axis)
return self._from_temp_dataset(ds)

Expand Down
72 changes: 57 additions & 15 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,7 +2329,7 @@ def swap_dims(self, dims_dict, inplace=None):
return self._replace_with_new_dims(variables, coord_names,
indexes=indexes, inplace=inplace)

def expand_dims(self, dim, axis=None):
def expand_dims(self, dim=None, axis=None, **dim_kwargs):
"""Return a new object with an additional axis (or axes) inserted at
the corresponding position in the array shape.
Expand All @@ -2338,26 +2338,52 @@ def expand_dims(self, dim, axis=None):
Parameters
----------
dim : str or sequence of str.
dim : str, sequence of str, dict, or None
Dimensions to include on the new variable.
dimensions are inserted with length 1.
If provided as str or sequence of str, then dimensions are inserted
with length 1. If provided as a dict, then the keys are the new
dimensions and the values are either integers (giving the length of
the new dimensions) or sequence/ndarray (giving the coordinates of
the new dimensions). **WARNING** for python 3.5, if ``dim`` is
dict-like, then it must be an ``OrderedDict`` to ensure that the
order in which the dims are given is maintained.
axis : integer, list (or tuple) of integers, or None
Axis position(s) where new axis is to be inserted (position(s) on
the result array). If a list (or tuple) of integers is passed,
multiple axes are inserted. In this case, dim arguments should be
the same length list. If axis=None is passed, all the axes will
be inserted to the start of the result array.
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
**dim_kwargs : int or sequence/ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
coordinates. Note, this is an alternative to passing a dict to the
dim kwarg and will only be used if dim is None. **WARNING** in
python 3.5, the order in which the dim kwargs are given will not be
maintained.
Returns
-------
expanded : same type as caller
This object, but with an additional dimension(s).
"""
if isinstance(dim, int):
raise ValueError('dim should be str or sequence of strs or dict')
raise TypeError('dim should be str or sequence of strs or dict')
elif isinstance(dim, str):
dim = OrderedDict(((dim, 1),))
elif isinstance(dim, (list, tuple)):
if len(dim) != len(set(dim)):
raise ValueError('dims should not contain duplicate values.')
dim = OrderedDict(((d, 1) for d in dim))

# TODO: get rid of the below code block when python 3.5 is no longer
# supported.
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
not_ordereddict = dim is not None and not isinstance(dim, OrderedDict)
if not python36_plus and not_ordereddict:
raise TypeError("dim must be an OrderedDict for python <3.6")

dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims')

if isinstance(dim, str):
dim = [dim]
if axis is not None and not isinstance(axis, (list, tuple)):
axis = [axis]

Expand All @@ -2376,10 +2402,24 @@ def expand_dims(self, dim, axis=None):
'{dim} already exists as coordinate or'
' variable name.'.format(dim=d))

if len(dim) != len(set(dim)):
raise ValueError('dims should not contain duplicate values.')

variables = OrderedDict()
# If dim is a dict, then ensure that the values are either integers
# or iterables.
for k, v in dim.items():
if hasattr(v, "__iter__"):
# If the value for the new dimension is an iterable, then
# save the coordinates to the variables dict, and set the
# value within the dim dict to the length of the iterable
# for later use.
variables[k] = xr.IndexVariable((k,), v)
self._coord_names.add(k)
dim[k] = len(list(v))
elif isinstance(v, int):
pass # Do nothing if the dimensions value is just an int
else:
raise TypeError('The value of new dimension {k} must be '
'an iterable or an int'.format(k=k))

for k, v in self._variables.items():
if k not in dim:
if k in self._coord_names: # Do not change coordinates
Expand All @@ -2400,11 +2440,13 @@ def expand_dims(self, dim, axis=None):
' values.')
# We need to sort them to make sure `axis` equals to the
# axis positions of the result array.
zip_axis_dim = sorted(zip(axis_pos, dim))
zip_axis_dim = sorted(zip(axis_pos, dim.items()))

all_dims = list(zip(v.dims, v.shape))
for d, c in zip_axis_dim:
all_dims.insert(d, c)
all_dims = OrderedDict(all_dims)

all_dims = list(v.dims)
for a, d in zip_axis_dim:
all_dims.insert(a, d)
variables[k] = v.set_dims(all_dims)
else:
# If dims includes a label of a non-dimension coordinate,
Expand Down
43 changes: 42 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,7 @@ def test_expand_dims_error(self):
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})

with raises_regex(ValueError, 'dim should be str or'):
with raises_regex(TypeError, 'dim should be str or'):
array.expand_dims(0)
with raises_regex(ValueError, 'lengths of dim and axis'):
# dims and axis argument should be the same length
Expand All @@ -1328,6 +1328,16 @@ def test_expand_dims_error(self):
array.expand_dims(dim=['y', 'z'], axis=[2, -4])
array.expand_dims(dim=['y', 'z'], axis=[2, 3])

array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})
with pytest.raises(TypeError):
array.expand_dims(OrderedDict((("new_dim", 3.2),)))

# Attempt to use both dim and kwargs
with pytest.raises(ValueError):
array.expand_dims(OrderedDict((("d", 4),)), e=4)

def test_expand_dims(self):
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3)},
Expand Down Expand Up @@ -1392,6 +1402,37 @@ def test_expand_dims_with_scalar_coordinate(self):
roundtripped = actual.squeeze(['z'], drop=False)
assert_identical(array, roundtripped)

def test_expand_dims_with_greater_dim_size(self):
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3), 'z': 1.0},
attrs={'key': 'entry'})
# For python 3.5 and earlier this has to be an ordered dict, to
# maintain insertion order.
actual = array.expand_dims(
OrderedDict((('y', 2), ('z', 1), ('dim_1', ['a', 'b', 'c']))))

expected_coords = OrderedDict((
('y', [0, 1]), ('z', [1.0]), ('dim_1', ['a', 'b', 'c']),
('x', np.linspace(0, 1, 3)), ('dim_0', range(4))))
expected = DataArray(array.values * np.ones([2, 1, 3, 3, 4]),
coords=expected_coords,
dims=list(expected_coords.keys()),
attrs={'key': 'entry'}
).drop(['y', 'dim_0'])
assert_identical(expected, actual)

# Test with kwargs instead of passing dict to dim arg.
other_way = array.expand_dims(dim_1=['a', 'b', 'c'])

other_way_expected = DataArray(
array.values * np.ones([3, 3, 4]),
coords={'dim_1': ['a', 'b', 'c'],
'x': np.linspace(0, 1, 3),
'dim_0': range(4), 'z': 1.0},
dims=['dim_1', 'x', 'dim_0'],
attrs={'key': 'entry'}).drop('dim_0')
assert_identical(other_way_expected, other_way)

def test_set_index(self):
indexes = [self.mindex.get_level_values(n) for n in self.mindex.names]
coords = {idx.name: ('x', idx) for idx in indexes}
Expand Down
51 changes: 51 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2013,6 +2013,19 @@ def test_expand_dims_error(self):
with raises_regex(ValueError, 'already exists'):
original.expand_dims(dim=['z'])

original = Dataset({'x': ('a', np.random.randn(3)),
'y': (['b', 'a'], np.random.randn(4, 3)),
'z': ('a', np.random.randn(3))},
coords={'a': np.linspace(0, 1, 3),
'b': np.linspace(0, 1, 4),
'c': np.linspace(0, 1, 5)},
attrs={'key': 'entry'})
with raises_regex(TypeError, 'value of new dimension'):
original.expand_dims(OrderedDict((("d", 3.2),)))

with raises_regex(ValueError, 'both keyword and positional'):
original.expand_dims(OrderedDict((("d", 4),)), e=4)

def test_expand_dims(self):
original = Dataset({'x': ('a', np.random.randn(3)),
'y': (['b', 'a'], np.random.randn(4, 3))},
Expand Down Expand Up @@ -2046,6 +2059,44 @@ def test_expand_dims(self):
roundtripped = actual.squeeze('z')
assert_identical(original, roundtripped)

# Test expanding one dimension to have size > 1 that doesn't have
# coordinates, and also expanding another dimension to have size > 1
# that DOES have coordinates.
actual = original.expand_dims(
OrderedDict((("d", 4), ("e", ["l", "m", "n"]))))

expected = Dataset(
{'x': xr.DataArray(original['x'].values * np.ones([4, 3, 3]),
coords=dict(d=range(4),
e=['l', 'm', 'n'],
a=np.linspace(0, 1, 3)),
dims=['d', 'e', 'a']).drop('d'),
'y': xr.DataArray(original['y'].values * np.ones([4, 3, 4, 3]),
coords=dict(d=range(4),
e=['l', 'm', 'n'],
b=np.linspace(0, 1, 4),
a=np.linspace(0, 1, 3)),
dims=['d', 'e', 'b', 'a']).drop('d')},
coords={'c': np.linspace(0, 1, 5)},
attrs={'key': 'entry'})
assert_identical(actual, expected)

# Test with kwargs instead of passing dict to dim arg.
other_way = original.expand_dims(e=["l", "m", "n"])
other_way_expected = Dataset(
{'x': xr.DataArray(original['x'].values * np.ones([3, 3]),
coords=dict(e=['l', 'm', 'n'],
a=np.linspace(0, 1, 3)),
dims=['e', 'a']),
'y': xr.DataArray(original['y'].values * np.ones([3, 4, 3]),
coords=dict(e=['l', 'm', 'n'],
b=np.linspace(0, 1, 4),
a=np.linspace(0, 1, 3)),
dims=['e', 'b', 'a'])},
coords={'c': np.linspace(0, 1, 5)},
attrs={'key': 'entry'})
assert_identical(other_way_expected, other_way)

def test_set_index(self):
expected = create_test_multiindex()
mindex = expected['x'].to_index()
Expand Down

0 comments on commit f832672

Please sign in to comment.