From ddb294a41265a690d1a009a8e1ad96462ded2f06 Mon Sep 17 00:00:00 2001 From: Holger Kohr Date: Sat, 18 Mar 2017 01:27:38 +0100 Subject: [PATCH] ENH: add tuple indexing of ProductSpace, closes #965 --- odl/space/pspace.py | 82 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 76 insertions(+), 6 deletions(-) diff --git a/odl/space/pspace.py b/odl/space/pspace.py index 0ce65e6a015..1a258c74d6e 100644 --- a/odl/space/pspace.py +++ b/odl/space/pspace.py @@ -540,21 +540,91 @@ def __hash__(self): return hash((type(self), self.spaces, self.weighting)) def __getitem__(self, indices): - """Return ``self[indices]``.""" + """Return ``self[indices]``. + + Examples + -------- + Integers are used to pick components, slices to pick ranges: + + >>> r2, r3, r4 = odl.rn(2), odl.rn(3), odl.rn(4) + >>> pspace = odl.ProductSpace(r2, r3, r4) + >>> pspace[1] + rn(3) + >>> pspace[1:] + ProductSpace(rn(3), rn(4)) + + With lists, arbitrary components can be stacked together: + + >>> pspace[[0, 2, 1, 2]] + ProductSpace(rn(2), rn(4), rn(3), rn(4)) + + Tuples, i.e. multi-indices, will recursively index higher-order + product spaces. However, remaining indices cannot be passed + down to component spaces that are not product spaces: + + >>> pspace2 = odl.ProductSpace(pspace, 3) # 2nd order product space + >>> pspace2 + ProductSpace(ProductSpace(rn(2), rn(3), rn(4)), 3) + >>> pspace2[0] + ProductSpace(rn(2), rn(3), rn(4)) + >>> pspace2[1, 0] + rn(2) + >>> pspace2[:-1, 0] + ProductSpace(rn(2), 2) + """ if isinstance(indices, Integral): return self.spaces[indices] + elif isinstance(indices, slice): return ProductSpace(*self.spaces[indices], field=self.field) + elif isinstance(indices, tuple): - if len(indices) > 1: - raise ValueError('too many indices: {}'.format(indices)) - return ProductSpace(self.spaces[indices[0]], field=self.field) + # Use tuple indexing for recursive product spaces, i.e., + # pspace[0, 0] == pspace[0][0] + if not indices: + return self + idx = indices[0] + if isinstance(idx, Integral): + # Single integer in tuple, picking that space and passing + # through the rest of the tuple. If the picked space + # is not a product space and there are still indices left, + # raise an error. + space = self.spaces[idx] + rest_indcs = indices[1:] + if not rest_indcs: + return space + elif isinstance(space, ProductSpace): + return space[rest_indcs] + else: + raise IndexError('too many indices for recursive ' + 'product space: remaining indices ' + '{}'.format(rest_indcs)) + elif isinstance(idx, slice): + # Doing the same as with single integer with all spaces + # in the slice, but wrapping the result into a ProductSpace. + spaces = self.spaces[idx] + rest_indcs = indices[1:] + if not rest_indcs: + return ProductSpace(*spaces) + elif all(isinstance(space, ProductSpace) for space in spaces): + return ProductSpace( + *(space[rest_indcs] for space in spaces), + field=self.field) + else: + raise IndexError('too many indices for recursive ' + 'product space: remaining indices ' + '{}'.format(rest_indcs)) + else: + raise TypeError('index tuple can only contain' + 'integers or slices') + elif isinstance(indices, list): return ProductSpace(*[self.spaces[i] for i in indices], field=self.field) + else: - raise TypeError('`indices` must be integer, slice or list, got ' - '{!r}'.format(indices)) + raise TypeError('`indices` must be integer, slice, tuple or ' + 'or list, got {!r}'.format(indices)) def __str__(self): """Return ``str(self)``."""