From 1ca6fed12271afd8ceaaa1e5000936a3430623d0 Mon Sep 17 00:00:00 2001 From: Holger Kohr Date: Sun, 19 Jan 2020 23:03:01 +0100 Subject: [PATCH] WIP: simplify index interpolation to rounding --- .../tomo/checks/check_axes_cone2d_vec_fp.py | 9 +- odl/tomo/backends/astra_cpu.py | 17 +- odl/tomo/backends/astra_setup.py | 55 ++-- odl/tomo/geometry/conebeam.py | 29 +- odl/tomo/geometry/geometry.py | 257 +++++++++++------- odl/tomo/geometry/parallel.py | 54 ++-- 6 files changed, 255 insertions(+), 166 deletions(-) diff --git a/examples/tomo/checks/check_axes_cone2d_vec_fp.py b/examples/tomo/checks/check_axes_cone2d_vec_fp.py index 947b74c174c..0be2dc41755 100644 --- a/examples/tomo/checks/check_axes_cone2d_vec_fp.py +++ b/examples/tomo/checks/check_axes_cone2d_vec_fp.py @@ -70,12 +70,13 @@ proj_data = ray_trafo(phantom) # Axis in this image is x. This corresponds to 0 degrees or index 0. -proj_data.show(indices=[0, None], - title='Projection at 0 degrees ~ Sum along y axis') +proj_data.show( + indices=[0, None], + title='Projection at 0 degrees ~ Sum along y axis' +) fig, ax = plt.subplots() ax.plot(sum_along_y) -ax.set_xlabel('x') -plt.title('Sum along y axis') +ax.set(xlabel="x", title='Sum along y axis') plt.show() # Check axes in geometry axis_sum_y = geometry.det_axis(0) diff --git a/odl/tomo/backends/astra_cpu.py b/odl/tomo/backends/astra_cpu.py index 8e6a7d0abf1..a5284d03b74 100644 --- a/odl/tomo/backends/astra_cpu.py +++ b/odl/tomo/backends/astra_cpu.py @@ -1,4 +1,4 @@ -# Copyright 2014-2019 The ODL contributors +# Copyright 2014-2020 The ODL contributors # # This file is part of ODL. # @@ -17,7 +17,8 @@ astra_algorithm, astra_data, astra_projection_geometry, astra_projector, astra_volume_geometry) from odl.tomo.geometry import ( - DivergentBeamGeometry, Geometry, ParallelBeamGeometry) + ConeVecGeometry, DivergentBeamGeometry, Geometry, ParallelBeamGeometry, + ParallelVecGeometry) from odl.util import writable_array try: @@ -49,20 +50,28 @@ def default_astra_proj_type(geom): - `ParallelBeamGeometry`: ``'linear'`` - `DivergentBeamGeometry`: ``'line_fanflat'`` + - `ParallelVecGeometry`: ``'linear'`` + - `ConeVecGeometry`: ``'line_fanflat'`` In 3D: - `ParallelBeamGeometry`: ``'linear3d'`` - `DivergentBeamGeometry`: ``'linearcone'`` + - `ParallelVecGeometry`: ``'linear3d'`` + - `ConeVecGeometry`: ``'linearcone'`` """ if isinstance(geom, ParallelBeamGeometry): return 'linear' if geom.ndim == 2 else 'linear3d' elif isinstance(geom, DivergentBeamGeometry): return 'line_fanflat' if geom.ndim == 2 else 'linearcone' + elif isinstance(geom, ParallelVecGeometry): + return 'linear' if geom.ndim == 2 else 'linear3d' + elif isinstance(geom, ConeVecGeometry): + return 'line_fanflat' if geom.ndim == 2 else 'linearcone' else: raise TypeError( - 'no default exists for {}, `astra_proj_type` must be given explicitly' - ''.format(type(geom)) + 'no default exists for {}, `astra_proj_type` must be given ' + 'explicitly'.format(type(geom)) ) diff --git a/odl/tomo/backends/astra_setup.py b/odl/tomo/backends/astra_setup.py index db26ab1d52e..b772eb98993 100644 --- a/odl/tomo/backends/astra_setup.py +++ b/odl/tomo/backends/astra_setup.py @@ -1,4 +1,4 @@ -# Copyright 2014-2019 The ODL contributors +# Copyright 2014-2020 The ODL contributors # # This file is part of ODL. # @@ -300,9 +300,11 @@ def astra_projection_geometry(geometry): raise ValueError('non-uniform detector sampling is not supported') # Parallel 2D - if (isinstance(geometry, ParallelBeamGeometry) and - isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) and - geometry.ndim == 2): + if ( + isinstance(geometry, ParallelBeamGeometry) + and isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) + and geometry.ndim == 2 + ): det_count = geometry.detector.size if astra_supports('par2d_vec_geometry'): vecs = parallel_2d_geom_to_astra_vecs(geometry, coords='ASTRA') @@ -327,9 +329,11 @@ def astra_projection_geometry(geometry): proj_geom = astra.create_proj_geom('parallel_vec', det_count, vecs) # Cone 2D (aka fan beam) - elif (isinstance(geometry, DivergentBeamGeometry) and - isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) and - geometry.ndim == 2): + elif ( + isinstance(geometry, DivergentBeamGeometry) + and isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) + and geometry.ndim == 2 + ): det_count = geometry.detector.size vecs = cone_2d_geom_to_astra_vecs(geometry, coords='ASTRA') proj_geom = astra.create_proj_geom('fanflat_vec', det_count, vecs) @@ -341,46 +345,55 @@ def astra_projection_geometry(geometry): proj_geom = astra.create_proj_geom('fanflat_vec', det_count, vecs) # Parallel 3D - elif (isinstance(geometry, ParallelBeamGeometry) and - isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) and - geometry.ndim == 3): + elif ( + isinstance(geometry, ParallelBeamGeometry) + and isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) + and geometry.ndim == 3 + ): # Swap detector axes (see `parallel_3d_geom_to_astra_vecs`) det_row_count = geometry.det_partition.shape[0] det_col_count = geometry.det_partition.shape[1] vecs = parallel_3d_geom_to_astra_vecs(geometry, coords='ASTRA') proj_geom = astra.create_proj_geom( - 'parallel3d_vec', det_row_count, det_col_count, vecs) + 'parallel3d_vec', det_row_count, det_col_count, vecs + ) # Parallel 3D vec elif isinstance(geometry, ParallelVecGeometry) and geometry.ndim == 3: det_row_count = geometry.det_partition.shape[1] det_col_count = geometry.det_partition.shape[0] vecs = vecs_odl_to_astra_coords(geometry.vectors) - proj_geom = astra.create_proj_geom('parallel3d_vec', det_row_count, - det_col_count, vecs) + proj_geom = astra.create_proj_geom( + 'parallel3d_vec', det_row_count, det_col_count, vecs + ) # Cone 3D - elif (isinstance(geometry, DivergentBeamGeometry) and - isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) and - geometry.ndim == 3): + elif ( + isinstance(geometry, DivergentBeamGeometry) + and isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) + and geometry.ndim == 3 + ): # Swap detector axes (see `conebeam_3d_geom_to_astra_vecs`) det_row_count = geometry.det_partition.shape[0] det_col_count = geometry.det_partition.shape[1] vecs = cone_3d_geom_to_astra_vecs(geometry, coords='ASTRA') proj_geom = astra.create_proj_geom( - 'cone_vec', det_row_count, det_col_count, vecs) + 'cone_vec', det_row_count, det_col_count, vecs + ) # Cone 3D vec elif isinstance(geometry, ConeVecGeometry) and geometry.ndim == 3: det_row_count = geometry.det_partition.shape[1] det_col_count = geometry.det_partition.shape[0] vecs = vecs_odl_to_astra_coords(geometry.vectors) - proj_geom = astra.create_proj_geom('cone_vec', det_row_count, - det_col_count, vecs) + proj_geom = astra.create_proj_geom( + 'cone_vec', det_row_count, det_col_count, vecs + ) else: - raise NotImplementedError('unknown ASTRA geometry type {!r}' - ''.format(geometry)) + raise NotImplementedError( + 'unknown ASTRA geometry type {!r}'.format(geometry) + ) if 'astra' not in geometry.implementation_cache: # Save computed value for later diff --git a/odl/tomo/geometry/conebeam.py b/odl/tomo/geometry/conebeam.py index 499755c4f95..2e0e0043ecd 100644 --- a/odl/tomo/geometry/conebeam.py +++ b/odl/tomo/geometry/conebeam.py @@ -1,4 +1,4 @@ -# Copyright 2014-2019 The ODL contributors +# Copyright 2014-2020 The ODL contributors # # This file is part of ODL. # @@ -32,7 +32,6 @@ class FanBeamGeometry(DivergentBeamGeometry): - """Fan beam (2d cone beam) geometry. The source moves on a circle with radius ``src_radius``, and the @@ -41,7 +40,7 @@ class FanBeamGeometry(DivergentBeamGeometry): radii can be chosen as 0, which corresponds to a stationary source or detector, respectively. - The motion parameter is the 1d rotation angle parameterizing source + The motion parameter is the 1d rotation angle that parametrizes source and detector positions simultaneously. In the standard configuration, the detector is perpendicular to the @@ -592,7 +591,6 @@ def __getitem__(self, indices): class ConeBeamGeometry(DivergentBeamGeometry, AxisOrientedGeometry): - """Cone beam geometry with circular/helical source curve. The source moves along a spiral oriented along a fixed ``axis``, with @@ -1034,7 +1032,7 @@ def det_axes(self, angle): Parameters ---------- - angles : float or `array-like` + angle : float or `array-like` Angle(s) in radians describing the counter-clockwise rotation of the detector around `axis`. @@ -1466,7 +1464,7 @@ def cone_beam_geometry(space, src_radius, det_radius, num_angles=None, # used here is (w/2)/(rs+rd) = rho/rs since both are equal to tan(alpha), # where alpha is the half fan angle. rs = float(src_radius) - if (rs <= rho): + if rs <= rho: raise ValueError('source too close to the object, resulting in ' 'infinite detector for full coverage') rd = float(det_radius) @@ -1508,6 +1506,10 @@ def cone_beam_geometry(space, src_radius, det_radius, num_angles=None, det_max_pt = [w / 2, h / 2] if det_shape is None: det_shape = [num_px_horiz, num_px_vert] + else: + raise ValueError( + '`space.ndim` must be 2 or 3, got {}'.format(space.ndim) + ) fan_angle = 2 * np.arctan(rho / rs) if short_scan: @@ -1636,7 +1638,7 @@ def helical_geometry(space, src_radius, det_radius, num_turns, # used here is (w/2)/(rs+rd) = rho/rs since both are equal to tan(alpha), # where alpha is the half fan angle. rs = float(src_radius) - if (rs <= rho): + if rs <= rho: raise ValueError('source too close to the object, resulting in ' 'infinite detector for full coverage') rd = float(det_radius) @@ -1684,7 +1686,6 @@ def helical_geometry(space, src_radius, det_radius, num_turns, class ConeVecGeometry(VecGeometry): - """Cone beam 2D or 3D geometry defined by a collection of vectors. This geometry gives maximal flexibility for representing locations @@ -1722,6 +1723,8 @@ class ConeVecGeometry(VecGeometry): linear paths. """ + # `rotation_matrix` not implemented; reason: missing + @property def _slice_src(self): """Slice for the source position part of `vectors`.""" @@ -1814,9 +1817,9 @@ def det_to_src(self, mparam, dparam, normalized=True): Parameters ---------- - mpar : `motion_params` element + mparam : `motion_params` element Motion parameter at which to evaluate. - dpar : `det_params` element + dparam : `det_params` element Detector parameter at which to evaluate. normalized : bool, optional If ``True``, return a normalized (unit) vector. @@ -1834,8 +1837,10 @@ def det_to_src(self, mparam, dparam, normalized=True): raise ValueError('`dparam` {} not in the valid range {}' ''.format(dparam, self.det_params)) - vec = (self.src_position(mparam) - - self.det_point_position(mparam, dparam)) + vec = ( + self.src_position(mparam) + - self.det_point_position(mparam, dparam) + ) if normalized: # axis = -1 allows this to be vectorized diff --git a/odl/tomo/geometry/geometry.py b/odl/tomo/geometry/geometry.py index e87cb0ec9f5..b3f5ffbc96c 100644 --- a/odl/tomo/geometry/geometry.py +++ b/odl/tomo/geometry/geometry.py @@ -1,4 +1,4 @@ -# Copyright 2014-2019 The ODL contributors +# Copyright 2014-2020 The ODL contributors # # This file is part of ODL. # @@ -30,7 +30,6 @@ class Geometry(object): - """Abstract geometry class. A geometry is described by @@ -419,7 +418,6 @@ def implementation_cache(self): class DivergentBeamGeometry(Geometry): - """Abstract divergent beam geometry class. A geometry characterized by the presence of a point-like ray source. @@ -462,6 +460,8 @@ def det_to_src(self, angle, dparam, normalized=True): Detector parameter(s) at which to evaluate. If ``det_params.ndim >= 2``, a sequence of that length must be provided. + normalized : bool, optional + If ``True``, return a normalized (unit) vector. Returns ------- @@ -552,8 +552,10 @@ def det_to_src(self, angle, dparam, normalized=True): dparam = tuple(np.array(p, dtype=float, copy=False, ndmin=1) for p in dparam) - det_to_src = (self.src_position(angle) - - self.det_point_position(angle, dparam)) + det_to_src = ( + self.src_position(angle) + - self.det_point_position(angle, dparam) + ) if normalized: det_to_src /= np.linalg.norm(det_to_src, axis=-1, keepdims=True) @@ -565,8 +567,11 @@ def det_to_src(self, angle, dparam, normalized=True): class AxisOrientedGeometry(object): + """Mixin class for 3d geometries oriented along an axis. - """Mixin class for 3d geometries oriented along an axis.""" + Makes use of `Geometry` attributes and should thus only be used for + `Geometry` subclasses. + """ def __init__(self, axis): """Initialize a new instance. @@ -616,8 +621,10 @@ def rotation_matrix(self, angle): """ squeeze_out = (np.shape(angle) == ()) angle = np.array(angle, dtype=float, copy=False, ndmin=1) - if (self.check_bounds and - not is_inside_bounds(angle, self.motion_params)): + if ( + self.check_bounds + and not is_inside_bounds(angle, self.motion_params) + ): raise ValueError('`angle` {} not in the valid range {}' ''.format(angle, self.motion_params)) @@ -629,7 +636,6 @@ def rotation_matrix(self, angle): class VecGeometry(Geometry): - """Abstract 2D or 3D geometry defined by a collection of vectors. This geometry gives maximal flexibility for representing locations @@ -671,6 +677,9 @@ class VecGeometry(Geometry): linear paths. """ + # `rotation_matrix` not implemented; reason: missing + # `det_to_src` not implemented; reason: depends on subclass + def __init__(self, det_shape, vectors): """Initialize a new instance. @@ -733,6 +742,8 @@ def __init__(self, det_shape, vectors): max_pt=(det_cell_sides * det_shape) / 2, shape=det_shape) detector = Flat2dDetector(det_part, axes=[det_u, det_v]) + else: + raise RuntimeError('invalid `ndim`') Geometry.__init__(self, ndim, mpart, detector) @@ -800,8 +811,9 @@ def det_refpoint(self, index): array([ 0., 1.]) >>> geom_2d.det_refpoint(1) array([-1., 0.]) - >>> geom_2d.det_refpoint(0.5) # mean value - array([-0.5, 0.5]) + >>> geom_2d.det_refpoint([0.4, 0.6]) # values at closest indices + array([[ 0., 1.], + [-1., 0.]]) In 3D, columns 3 to 5 (starting at 0) determine the detector reference point, here ``(0, 1, 0)`` and ``(-1, 0, 0)``: @@ -816,35 +828,33 @@ def det_refpoint(self, index): array([ 0., 1., 0.]) >>> geom_3d.det_refpoint(1) array([-1., 0., 0.]) - >>> geom_3d.det_refpoint(0.5) # mean value - array([-0.5, 0.5, 0. ]) + >>> geom_3d.det_refpoint([0.4, 0.6]) # values at closest indices + array([[ 0., 1., 0.], + [-1., 0., 0.]]) """ - if (self.check_bounds and - not is_inside_bounds(index, self.motion_params)): - raise ValueError('`index` {} not in the valid range {}' - ''.format(index, self.motion_params)) - - index = np.array(index, dtype=float, copy=False, ndmin=1) - int_part = index.astype(int) - frac_part = index - int_part - - vecs = np.empty((len(index), self.ndim)) - at_right_bdry = (int_part == self.motion_params.max_pt) - vecs[at_right_bdry, :] = self.vectors[int_part[at_right_bdry], - self._slice_det_center] - - not_at_right_bdry = ~at_right_bdry - if np.any(not_at_right_bdry): - pt_left = self.vectors[int_part[not_at_right_bdry], - self._slice_det_center] - pt_right = self.vectors[int_part[not_at_right_bdry] + 1, - self._slice_det_center] - vecs[not_at_right_bdry, :] = ( - pt_left + - frac_part[not_at_right_bdry, None] * (pt_right - pt_left) + if ( + self.check_bounds + and not is_inside_bounds(index, self.motion_params) + ): + raise ValueError( + '`index` {} not in the valid range {}' + ''.format(index, self.motion_params) ) - return vecs.squeeze() + index_int = np.round(index).astype(int) + if index_int.shape == (): + index_int = index_int[None] + squeeze_index = True + else: + squeeze_index = False + + det_refpts = self.vectors[index_int, self._slice_det_center] + if squeeze_index: + return det_refpts[0] + else: + return det_refpts + + # Overrides implementation in `Geometry` def det_point_position(self, index, dparam): """Return the detector point at ``(index, dparam)``. @@ -886,9 +896,16 @@ def det_point_position(self, index, dparam): >>> # d(1) + 2 * u(1) = (-1, 0) + 2 * (0, 1) >>> geom_2d.det_point_position(1, 2) array([-1., 2.]) - >>> # d(0.5) + 2 * u(0.5) = (-0.5, 0.5) + 2 * (0.5, 0.5) - >>> geom_2d.det_point_position(0.5, 2) - array([ 0.5, 1.5]) + >>> # d(0.4) + 2 * u(0.4) = d(0) + 2 * u(0) + >>> geom_2d.det_point_position(0.4, 2) + array([ 2., 1.]) + + Broadcasting of arguments: + + >>> idcs = np.array([0.4, 0.6])[:, None] + >>> dpar = np.array([2.0])[None, :] + >>> geom_2d.det_point_position(idcs, dpar) + Do the same in 3D, with reference points ``(0, 1, 0)`` and ``(-1, 0, 0)``, and horizontal ``u`` axis vectors ``(1, 0, 0)`` and @@ -911,24 +928,39 @@ def det_point_position(self, index, dparam): >>> # (-1, 0, 0) + 2 * (0, 1, 0) + 1 * (0, 0, 1) >>> geom_3d.det_point_position(1, [2, 1]) array([-1., 2., 1.]) - >>> # d(0.5) + 2 * u(0.5) = (-0.5, 0.5) + 2 * (0.5, 0.5) - >>> geom_3d.det_point_position(0.5, [2, 1]) - array([ 0.5, 1.5, 1. ]) - """ - # TODO: vectorize! + >>> # d(0.4) + 2 * u(0.4) + 1 * u(0.4) = d(0) + 2 * u(0) + 1 * v(0) + >>> geom_3d.det_point_position(0.4, [2, 1]) + array([ 2., 1., 1.]) + + Broadcasting of arguments: - if index not in self.motion_params: - raise ValueError('`index` must be contained in `motion_params` ' - '{}, got {}'.format(self.motion_params, index)) - if dparam not in self.det_params: - raise ValueError('`dparam` must be contained in `det_params` ' - '{}, got {}'.format(self.det_params, dparam)) + >>> idcs = np.array([0.4, 0.6])[:, None] + >>> dpar = np.array([2.0, 1.0])[None, :] + >>> geom_3d.det_point_position(idcs, dpar) + """ + if self.check_bounds: + if not is_inside_bounds(index, self.motion_params): + raise ValueError( + '`index` {} not in the valid range {}' + ''.format(index, self.motion_params) + ) + + if not is_inside_bounds(dparam, self.det_params): + raise ValueError( + '`dparam` {} not in the valid range {}' + ''.format(dparam, self.det_params) + ) + + # TODO: broadcast correctly if self.ndim == 2: det_shift = dparam * self.det_axis(index) elif self.ndim == 3: - det_shift = sum(di * ax - for di, ax in zip(dparam, self.det_axes(index))) + det_shift = sum( + di * ax for di, ax in zip(dparam, self.det_axes(index)) + ) + else: + raise RuntimeError('invalid `ndim`') return self.det_refpoint(index) + det_shift @@ -962,31 +994,38 @@ def det_axis(self, index): array([ 1., 0.]) >>> geom_2d.det_axis(1) array([ 0., 1.]) - >>> geom_2d.det_axis(0.5) # mean value - array([ 0.5, 0.5]) + >>> geom_2d.det_axis([0.4, 0.6]) # values at closest indices + array([[ 1., 0.], + [ 0., 1.]]) """ - # TODO: vectorize! - - if index not in self.motion_params: - raise ValueError('`index` must be contained in `motion_params` ' - '{}, got {}'.format(self.motion_params, index)) + if ( + self.check_bounds + and not is_inside_bounds(index, self.motion_params) + ): + raise ValueError( + '`index` {} not in the valid range {}' + ''.format(index, self.motion_params) + ) if self.ndim != 2: raise NotImplementedError( '`det_axis` only valid for 2D geometries, use `det_axes` ' - 'in 3D') + 'in 3D' + ) - index = float(index) - int_part = int(index) - frac_part = index - int_part - if int_part == self.motion_params.max_pt: - det_u = self.vectors[int_part, self._slice_det_u] + index_int = np.round(index).astype(int) + if index_int.shape == (): + index_int = index_int[None] + squeeze_index = True else: - det_u_left = self.vectors[int_part, self._slice_det_u] - det_u_right = self.vectors[int_part + 1, self._slice_det_u] - det_u = det_u_left + frac_part * (det_u_right - det_u_left) + squeeze_index = False - return det_u + vectors = self.vectors[index_int] + det_us = vectors[:, self._slice_det_u] + if squeeze_index: + return det_us[0] + else: + return det_us def det_axes(self, index): """Return the detector axes at ``index`` (for 2D or 3D geometries). @@ -994,13 +1033,14 @@ def det_axes(self, index): Parameters ---------- index : `motion_params` element - Index of the projection. Non-integer indices result in - interpolated vectors. + Index of the projection. Non-integer indices are rounded to + closest integer. Returns ------- axes : tuple of `numpy.ndarray`, shape ``(ndim,)`` - The detector axes at ``index``. + The detector axes at ``index``, 1 for ``ndim == 2`` and + 2 for ``ndim == 3``. Examples -------- @@ -1019,37 +1059,45 @@ def det_axes(self, index): (array([ 1., 0., 0.]), array([ 0., 0., 1.])) >>> geom_3d.det_axes(1) (array([ 0., 1., 0.]), array([ 0., 0., 1.])) - >>> geom_3d.det_axes(0.5) # mean value - (array([ 0.5, 0.5, 0. ]), array([ 0., 0., 1.])) + >>> axs = geom_3d.det_axes([0.4, 0.6]) # values at closest indices + >>> axs[0] # first axis + array([[ 1., 0., 0.], + [ 0., 1., 0.]]) + >>> axs[1] # second axis + array([[ 0., 0., 1.], + [ 0., 0., 1.]]) """ - # TODO: vectorize! - - if index not in self.motion_params: - raise ValueError('`index` must be contained in `motion_params` ' - '{}, got {}'.format(self.motion_params, index)) - - index = float(index) - int_part = int(index) - frac_part = index - int_part - if int_part == self.motion_params.max_pt: - det_u = self.vectors[int_part, self._slice_det_u] - if self.ndim == 2: - return (det_u,) - elif self.ndim == 3: - det_v = self.vectors[int_part, self._slice_det_v] - return (det_u, det_v) + if ( + self.check_bounds + and not is_inside_bounds(index, self.motion_params) + ): + raise ValueError( + '`index` {} not in the valid range {}' + ''.format(index, self.motion_params) + ) + + index_int = np.round(index).astype(int) + if index_int.shape == (): + index_int = index_int[None] + squeeze_index = True else: - det_u_left = self.vectors[int_part, self._slice_det_u] - det_u_right = self.vectors[int_part + 1, self._slice_det_u] - det_u = det_u_left + frac_part * (det_u_right - det_u_left) - - if self.ndim == 2: - return (det_u,) - elif self.ndim == 3: - det_v_left = self.vectors[int_part, self._slice_det_v] - det_v_right = self.vectors[int_part + 1, self._slice_det_v] - det_v = det_v_left + frac_part * (det_v_right - det_v_left) - return (det_u, det_v) + squeeze_index = False + + vectors = self.vectors[index_int] + if self.ndim == 2: + det_us = vectors[:, self._slice_det_u] + retval_lst = [det_us[0]] if squeeze_index else [det_us] + elif self.ndim == 3: + det_us = vectors[:, self._slice_det_u] + det_vs = vectors[:, self._slice_det_v] + if squeeze_index: + retval_lst = [det_us[0], det_vs[0]] + else: + retval_lst = [det_us, det_vs] + else: + raise RuntimeError('invalid `ndim`') + + return tuple(retval_lst) def __getitem__(self, indices): """Return ``self[indices]``. @@ -1124,8 +1172,9 @@ def __repr__(self): posargs = [self.det_partition.shape, self.vectors] with npy_printoptions(precision=REPR_PRECISION): inner_parts = signature_string_parts(posargs, []) - return repr_string(self.__class__.__name__, inner_parts, - allow_mixed_seps=False) + return repr_string( + self.__class__.__name__, inner_parts, allow_mixed_seps=False + ) if __name__ == '__main__': diff --git a/odl/tomo/geometry/parallel.py b/odl/tomo/geometry/parallel.py index 0c3f4c930cc..d6523359a26 100644 --- a/odl/tomo/geometry/parallel.py +++ b/odl/tomo/geometry/parallel.py @@ -1,4 +1,4 @@ -# Copyright 2014-2019 The ODL contributors +# Copyright 2014-2020 The ODL contributors # # This file is part of ODL. # @@ -30,7 +30,6 @@ class ParallelBeamGeometry(Geometry): - """Abstract parallel beam geometry in 2 or 3 dimensions. Parallel geometries are characterized by a virtual source at @@ -331,7 +330,6 @@ def det_to_src(self, angle, dparam): class Parallel2dGeometry(ParallelBeamGeometry): - """Parallel beam geometry in 2d. The motion parameter is the counter-clockwise rotation angle around @@ -635,8 +633,10 @@ def rotation_matrix(self, angle): """ squeeze_out = (np.shape(angle) == ()) angle = np.array(angle, dtype=float, copy=False, ndmin=1) - if (self.check_bounds and - not is_inside_bounds(angle, self.motion_params)): + if ( + self.check_bounds + and not is_inside_bounds(angle, self.motion_params) + ): raise ValueError('`angle` {} not in the valid range {}' ''.format(angle, self.motion_params)) @@ -669,7 +669,7 @@ def __repr__(self): return '{}(\n{}\n)'.format(self.__class__.__name__, indent(sig_str)) def __getitem__(self, indices): - """Return self[slc] + """Return self[slc]. This is defined by:: @@ -705,7 +705,6 @@ def __getitem__(self, indices): class Parallel3dEulerGeometry(ParallelBeamGeometry): - """Parallel beam geometry in 3d. The motion parameters are two or three Euler angles, and the detector @@ -1044,8 +1043,10 @@ def rotation_matrix(self, angles): angles_in = angles angles = tuple(np.array(angle, dtype=float, copy=False, ndmin=1) for angle in angles) - if (self.check_bounds and - not is_inside_bounds(angles, self.motion_params)): + if ( + self.check_bounds + and not is_inside_bounds(angles, self.motion_params) + ): raise ValueError('`angles` {} not in the valid range ' '{}'.format(angles_in, self.motion_params)) @@ -1079,7 +1080,6 @@ def __repr__(self): class Parallel3dAxisGeometry(ParallelBeamGeometry, AxisOrientedGeometry): - """Parallel beam geometry in 3d with single rotation axis. The motion parameter is the rotation angle around the specified @@ -1474,7 +1474,6 @@ def __getitem__(self, indices): class ParallelVecGeometry(VecGeometry): - """Parallel beam 2D or 3D geometry defined by a collection of vectors. This geometry gives maximal flexibility for representing locations @@ -1512,6 +1511,8 @@ class ParallelVecGeometry(VecGeometry): linear paths. """ + # `rotation_matrix` not implemented; reason: missing + @property def _slice_ray(self): """Slice for the ray direction part of `vectors`.""" @@ -1614,6 +1615,25 @@ def det_to_src(self, index, dparam): (4, 5, 2) """ squeeze_index = (np.shape(index) == ()) + + if self.check_bounds: + if not is_inside_bounds(index, self.motion_params): + raise ValueError( + '`index` {} not in the valid range {}' + ''.format(index, self.motion_params) + ) + + if not is_inside_bounds(dparam, self.det_params): + raise ValueError( + '`dparam` {} not in the valid range {}' + ''.format(dparam, self.det_params) + ) + + index_int = np.round(index).astype(int) + vectors = self.vectors[index_int] + ray_dirs = vectors[self._slice_ray] + det_centers = vectors[self._slice_det_center] + index_in = index index = np.array(index, dtype=float, copy=False, ndmin=1) @@ -1626,14 +1646,6 @@ def det_to_src(self, index, dparam): dparam = tuple(np.array(p, dtype=float, copy=False, ndmin=1) for p in dparam) - if self.check_bounds: - if not is_inside_bounds(index, self.motion_params): - raise ValueError('`index` {} not in the valid range {}' - ''.format(index_in, self.motion_params)) - - if not is_inside_bounds(dparam, self.det_params): - raise ValueError('`dparam` {} not in the valid range {}' - ''.format(dparam_in, self.det_params)) at_max_flat = (index == self.motion_params.max_pt).ravel() @@ -1669,8 +1681,8 @@ def det_to_src(self, index, dparam): print('ray_right shape:', ray_right.shape) print('index_frac_part shape:', index_frac_part.shape) ray_dir[~at_max_flat, ...] = ( - ray_left + - index_frac_part * (ray_right - ray_left)) + ray_left + index_frac_part * (ray_right - ray_left) + ) ray_dir *= -1 / np.linalg.norm(ray_dir, axis=-1, keepdims=True)