diff --git a/odl/space/pspace.py b/odl/space/pspace.py index 3ed957779b6..c872366abfc 100644 --- a/odl/space/pspace.py +++ b/odl/space/pspace.py @@ -255,25 +255,25 @@ def shape(self): >>> pspace2 = odl.ProductSpace(pspace, 3) >>> pspace2.shape (3, 2) - - If the space is a "pure" product space, shape recurses all the way - into the components: - >>> r2_2 = odl.ProductSpace(r2, 3) >>> r2_2.shape - (3, 2) + (3,) """ if len(self) == 0: return () - elif self.is_power_space: + + shape = [len(self)] + spaces = self.spaces + is_power_space = self.is_power_space + while is_power_space: try: - sub_shape = self[0].shape + is_power_space = spaces[0].is_power_space except AttributeError: - sub_shape = () - else: - sub_shape = () + break + spaces = spaces[0].spaces + shape.append(len(spaces)) - return (len(self),) + sub_shape + return tuple(shape) @property def size(self): @@ -292,27 +292,9 @@ def size(self): >>> pspace2.size 6 """ - return (0 if self.shape == () else - int(np.prod(self.shape, dtype='int64'))) - - @property - def dtype(self): - """The data type of this space. - - This is only well defined if all subspaces have the same dtype. - - Raises - ------ - AttributeError - If any of the subspaces does not implement `dtype` or if the dtype - of the subspaces does not match. - """ - dtypes = [space.dtype for space in self.spaces] - - if all(dtype == dtypes[0] for dtype in dtypes): - return dtypes[0] - else: - raise AttributeError("`dtype`'s of subspaces not equal") + return ( + 0 if self.shape == () else int(np.prod(self.shape, dtype='int64')) + ) # --- Analytic properties @@ -349,14 +331,14 @@ def complex_space(self): return ProductSpace(*[space.complex_space for space in self.spaces]) def astype(self, dtype): - """Return a copy of this space with new ``dtype``. + """Return a copy of this space with subspaces of given ``dtype``. Parameters ---------- - dtype : - Scalar data type of the returned space. Can be provided - in any way the `numpy.dtype` constructor understands, e.g. - as built-in type or as a string. Data types with non-trivial + dtype + Scalar data type of the constituents of the returned space. Can + be provided in any way the `numpy.dtype` constructor understands, + e.g. as built-in type or as a string. Data types with non-trivial shapes are not allowed. Returns @@ -369,22 +351,28 @@ def astype(self, dtype): raise ValueError('`None` is not a valid data type') dtype = np.dtype(dtype) - current_dtype = getattr(self, 'dtype', object) + current_dtypes = [ + getattr(space, 'dtype', None) for space in self.spaces + ] - if dtype == current_dtype: + if all(dt is not None and dt == dtype for dt in current_dtypes): return self else: - return ProductSpace(*[space.astype(dtype) - for space in self.spaces]) + return ProductSpace( + *[space.astype(dtype) for space in self.spaces] + ) # --- Element handling - def _flatten(self, spaces, inputs=None): - for n in self.shape[:-1]: + def _flatten(self, inputs=None): + spaces = self.spaces + size = 1 + for n in self.shape: + size *= n try: - spaces = sum((spaces[i].spaces for i in range(n)), ()) + spaces = sum((spaces[i].spaces for i in range(size)), ()) if inputs is not None: - inputs = sum((tuple(inputs[i]) for i in range(n)), ()) + inputs = sum((tuple(inputs[i]) for i in range(size)), ()) except AttributeError: break if inputs is None: @@ -441,12 +429,12 @@ def element(self, inp=None, cast=True): return inp if inp is None: - flat_spaces = self._flatten(self.spaces) + flat_spaces = self._flatten() flat_inp = [space.element() for space in flat_spaces] else: - flat_spaces, flat_inp = self._flatten(self.spaces, inp) + flat_spaces, flat_inp = self._flatten(inp) - if len(flat_inp) != self.size: + if len(flat_inp) != len(flat_spaces): raise ValueError( "flattened size {} of input {!r} does not match this space's " 'size {}'.format(len(flat_inp), inp, self.size) @@ -467,12 +455,27 @@ def element(self, inp=None, cast=True): # Note: the array must be created in advance, since otherwise NumPy # may still try to loop over the inputs. # See https://github.com/numpy/numpy/issues/12479 - # TODO(kohr-h): remove when above issue is resolved - ret = np.empty(self.size, dtype=object) + # TODO(kohr-h): maybe remove when above issue is resolved + ret = np.empty(len(flat_spaces), dtype=object) for i, xi in enumerate(flat_inp): ret[i] = xi return ret.reshape(self.shape) + def to_scalar_dtype(self, elem): + """Convert power space element to NumPy array with scalar dtype.""" + + def comp_list_map(func): + def nested(x): + return list(map(func, x)) + + return nested + + nested_list = list + for _ in self.shape[1:]: + nested_list = comp_list_map(nested_list) + + return np.array(nested_list(elem)) + def zero(self): """Create the zero element of the product space. @@ -532,7 +535,7 @@ def __contains__(self, other): # TODO: doctest if not isinstance(other, np.ndarray): return False - return all(p in spc for p, spc in zip(other, self.spaces)) + return all(oi in spc for oi, spc in zip(other, self.spaces)) # --- Space functions diff --git a/odl/test/space/pspace_test.py b/odl/test/space/pspace_test.py index 3118db2c89a..2fde7d22343 100644 --- a/odl/test/space/pspace_test.py +++ b/odl/test/space/pspace_test.py @@ -170,8 +170,8 @@ def test_pspace_basic_properties(): # Power space pspace = odl.ProductSpace(r3, 2) assert len(pspace) == 2 - assert pspace.shape == (2, 3) - assert pspace.size == 6 + assert pspace.shape == (2,) + assert pspace.size == 2 assert pspace.spaces[0] == pspace.spaces[1] == r3 assert pspace.is_power_space assert not pspace.is_weighted @@ -208,24 +208,24 @@ def test_pspace_equality(exponent): def test_pspace_element(): """Test element creation in product spaces.""" - H = odl.rn(2) - HxH = odl.ProductSpace(H, H) - elem = HxH.element([[1, 2], [3, 4]]) - assert elem in HxH + r2 = odl.rn(2) + pspace = odl.ProductSpace(r2, r2) + x = pspace.element([[1, 2], [3, 4]]) + assert x in pspace - # wrong length + # Wrong length with pytest.raises(ValueError): - HxH.element([[1, 2]]) + pspace.element([[1, 2]]) with pytest.raises(ValueError): - HxH.element([[1, 2], [3, 4], [5, 6]]) + pspace.element([[1, 2], [3, 4], [5, 6]]) - # wrong length of subspace element + # Wrong length of subspace element with pytest.raises(ValueError): - HxH.element([[1, 2, 3], [4, 5]]) + pspace.element([[1, 2, 3], [4, 5]]) with pytest.raises(ValueError): - HxH.element([[1, 2], [3, 4, 5]]) + pspace.element([[1, 2], [3, 4, 5]]) def test_pspace_lincomb(): @@ -338,7 +338,7 @@ def test_power_shape(): assert empty.size == empty2.size == 0 r2_3 = odl.ProductSpace(r2, 3) - _test_shape(r2_3, (3, 2)) + _test_shape(r2_3, (3,)) r2xr3 = odl.ProductSpace(r2, r3) _test_shape(r2xr3, (2,))