Skip to content

Commit

Permalink
ENH: add tuple indexing of ProductSpace, closes #965
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Sep 27, 2017
1 parent 695c24d commit adfc19e
Showing 1 changed file with 76 additions and 6 deletions.
82 changes: 76 additions & 6 deletions odl/space/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,21 +540,91 @@ def __hash__(self):
return hash((type(self), self.spaces, self.weighting))

def __getitem__(self, indices):
"""Return ``self[indices]``."""
"""Return ``self[indices]``.
Examples
--------
Integers are used to pick components, slices to pick ranges:
>>> r2, r3, r4 = odl.rn(2), odl.rn(3), odl.rn(4)
>>> pspace = odl.ProductSpace(r2, r3, r4)
>>> pspace[1]
rn(3)
>>> pspace[1:]
ProductSpace(rn(3), rn(4))
With lists, arbitrary components can be stacked together:
>>> pspace[[0, 2, 1, 2]]
ProductSpace(rn(2), rn(4), rn(3), rn(4))
Tuples, i.e. multi-indices, will recursively index higher-order
product spaces. However, remaining indices cannot be passed
down to component spaces that are not product spaces:
>>> pspace2 = odl.ProductSpace(pspace, 3) # 2nd order product space
>>> pspace2
ProductSpace(ProductSpace(rn(2), rn(3), rn(4)), 3)
>>> pspace2[0]
ProductSpace(rn(2), rn(3), rn(4))
>>> pspace2[1, 0]
rn(2)
>>> pspace2[:-1, 0]
ProductSpace(rn(2), 2)
"""
if isinstance(indices, Integral):
return self.spaces[indices]

elif isinstance(indices, slice):
return ProductSpace(*self.spaces[indices], field=self.field)

elif isinstance(indices, tuple):
if len(indices) > 1:
raise ValueError('too many indices: {}'.format(indices))
return ProductSpace(self.spaces[indices[0]], field=self.field)
# Use tuple indexing for recursive product spaces, i.e.,
# pspace[0, 0] == pspace[0][0]
if not indices:
return self
idx = indices[0]
if isinstance(idx, Integral):
# Single integer in tuple, picking that space and passing
# through the rest of the tuple. If the picked space
# is not a product space and there are still indices left,
# raise an error.
space = self.spaces[idx]
rest_indcs = indices[1:]
if not rest_indcs:
return space
elif isinstance(space, ProductSpace):
return space[rest_indcs]
else:
raise IndexError('too many indices for recursive '
'product space: remaining indices '
'{}'.format(rest_indcs))
elif isinstance(idx, slice):
# Doing the same as with single integer with all spaces
# in the slice, but wrapping the result into a ProductSpace.
spaces = self.spaces[idx]
rest_indcs = indices[1:]
if not rest_indcs:
return ProductSpace(*spaces)
elif all(isinstance(space, ProductSpace) for space in spaces):
return ProductSpace(
*(space[rest_indcs] for space in spaces),
field=self.field)
else:
raise IndexError('too many indices for recursive '
'product space: remaining indices '
'{}'.format(rest_indcs))
else:
raise TypeError('index tuple can only contain'
'integers or slices')

elif isinstance(indices, list):
return ProductSpace(*[self.spaces[i] for i in indices],
field=self.field)

else:
raise TypeError('`indices` must be integer, slice or list, got '
'{!r}'.format(indices))
raise TypeError('`indices` must be integer, slice, tuple or '
'or list, got {!r}'.format(indices))

def __str__(self):
"""Return ``str(self)``."""
Expand Down

0 comments on commit adfc19e

Please sign in to comment.