Skip to content

Commit

Permalink
WIP: simplify index interpolation to rounding
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Jan 19, 2020
1 parent 40dadda commit 1ca6fed
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 166 deletions.
9 changes: 5 additions & 4 deletions examples/tomo/checks/check_axes_cone2d_vec_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@
proj_data = ray_trafo(phantom)

# Axis in this image is x. This corresponds to 0 degrees or index 0.
proj_data.show(indices=[0, None],
title='Projection at 0 degrees ~ Sum along y axis')
proj_data.show(
indices=[0, None],
title='Projection at 0 degrees ~ Sum along y axis'
)
fig, ax = plt.subplots()
ax.plot(sum_along_y)
ax.set_xlabel('x')
plt.title('Sum along y axis')
ax.set(xlabel="x", title='Sum along y axis')
plt.show()
# Check axes in geometry
axis_sum_y = geometry.det_axis(0)
Expand Down
17 changes: 13 additions & 4 deletions odl/tomo/backends/astra_cpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2014-2019 The ODL contributors
# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
Expand All @@ -17,7 +17,8 @@
astra_algorithm, astra_data, astra_projection_geometry, astra_projector,
astra_volume_geometry)
from odl.tomo.geometry import (
DivergentBeamGeometry, Geometry, ParallelBeamGeometry)
ConeVecGeometry, DivergentBeamGeometry, Geometry, ParallelBeamGeometry,
ParallelVecGeometry)
from odl.util import writable_array

try:
Expand Down Expand Up @@ -49,20 +50,28 @@ def default_astra_proj_type(geom):
- `ParallelBeamGeometry`: ``'linear'``
- `DivergentBeamGeometry`: ``'line_fanflat'``
- `ParallelVecGeometry`: ``'linear'``
- `ConeVecGeometry`: ``'line_fanflat'``
In 3D:
- `ParallelBeamGeometry`: ``'linear3d'``
- `DivergentBeamGeometry`: ``'linearcone'``
- `ParallelVecGeometry`: ``'linear3d'``
- `ConeVecGeometry`: ``'linearcone'``
"""
if isinstance(geom, ParallelBeamGeometry):
return 'linear' if geom.ndim == 2 else 'linear3d'
elif isinstance(geom, DivergentBeamGeometry):
return 'line_fanflat' if geom.ndim == 2 else 'linearcone'
elif isinstance(geom, ParallelVecGeometry):
return 'linear' if geom.ndim == 2 else 'linear3d'
elif isinstance(geom, ConeVecGeometry):
return 'line_fanflat' if geom.ndim == 2 else 'linearcone'
else:
raise TypeError(
'no default exists for {}, `astra_proj_type` must be given explicitly'
''.format(type(geom))
'no default exists for {}, `astra_proj_type` must be given '
'explicitly'.format(type(geom))
)


Expand Down
55 changes: 34 additions & 21 deletions odl/tomo/backends/astra_setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2014-2019 The ODL contributors
# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
Expand Down Expand Up @@ -300,9 +300,11 @@ def astra_projection_geometry(geometry):
raise ValueError('non-uniform detector sampling is not supported')

# Parallel 2D
if (isinstance(geometry, ParallelBeamGeometry) and
isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) and
geometry.ndim == 2):
if (
isinstance(geometry, ParallelBeamGeometry)
and isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector))
and geometry.ndim == 2
):
det_count = geometry.detector.size
if astra_supports('par2d_vec_geometry'):
vecs = parallel_2d_geom_to_astra_vecs(geometry, coords='ASTRA')
Expand All @@ -327,9 +329,11 @@ def astra_projection_geometry(geometry):
proj_geom = astra.create_proj_geom('parallel_vec', det_count, vecs)

# Cone 2D (aka fan beam)
elif (isinstance(geometry, DivergentBeamGeometry) and
isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) and
geometry.ndim == 2):
elif (
isinstance(geometry, DivergentBeamGeometry)
and isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector))
and geometry.ndim == 2
):
det_count = geometry.detector.size
vecs = cone_2d_geom_to_astra_vecs(geometry, coords='ASTRA')
proj_geom = astra.create_proj_geom('fanflat_vec', det_count, vecs)
Expand All @@ -341,46 +345,55 @@ def astra_projection_geometry(geometry):
proj_geom = astra.create_proj_geom('fanflat_vec', det_count, vecs)

# Parallel 3D
elif (isinstance(geometry, ParallelBeamGeometry) and
isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) and
geometry.ndim == 3):
elif (
isinstance(geometry, ParallelBeamGeometry)
and isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector))
and geometry.ndim == 3
):
# Swap detector axes (see `parallel_3d_geom_to_astra_vecs`)
det_row_count = geometry.det_partition.shape[0]
det_col_count = geometry.det_partition.shape[1]
vecs = parallel_3d_geom_to_astra_vecs(geometry, coords='ASTRA')
proj_geom = astra.create_proj_geom(
'parallel3d_vec', det_row_count, det_col_count, vecs)
'parallel3d_vec', det_row_count, det_col_count, vecs
)

# Parallel 3D vec
elif isinstance(geometry, ParallelVecGeometry) and geometry.ndim == 3:
det_row_count = geometry.det_partition.shape[1]
det_col_count = geometry.det_partition.shape[0]
vecs = vecs_odl_to_astra_coords(geometry.vectors)
proj_geom = astra.create_proj_geom('parallel3d_vec', det_row_count,
det_col_count, vecs)
proj_geom = astra.create_proj_geom(
'parallel3d_vec', det_row_count, det_col_count, vecs
)

# Cone 3D
elif (isinstance(geometry, DivergentBeamGeometry) and
isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector)) and
geometry.ndim == 3):
elif (
isinstance(geometry, DivergentBeamGeometry)
and isinstance(geometry.detector, (Flat1dDetector, Flat2dDetector))
and geometry.ndim == 3
):
# Swap detector axes (see `conebeam_3d_geom_to_astra_vecs`)
det_row_count = geometry.det_partition.shape[0]
det_col_count = geometry.det_partition.shape[1]
vecs = cone_3d_geom_to_astra_vecs(geometry, coords='ASTRA')
proj_geom = astra.create_proj_geom(
'cone_vec', det_row_count, det_col_count, vecs)
'cone_vec', det_row_count, det_col_count, vecs
)

# Cone 3D vec
elif isinstance(geometry, ConeVecGeometry) and geometry.ndim == 3:
det_row_count = geometry.det_partition.shape[1]
det_col_count = geometry.det_partition.shape[0]
vecs = vecs_odl_to_astra_coords(geometry.vectors)
proj_geom = astra.create_proj_geom('cone_vec', det_row_count,
det_col_count, vecs)
proj_geom = astra.create_proj_geom(
'cone_vec', det_row_count, det_col_count, vecs
)

else:
raise NotImplementedError('unknown ASTRA geometry type {!r}'
''.format(geometry))
raise NotImplementedError(
'unknown ASTRA geometry type {!r}'.format(geometry)
)

if 'astra' not in geometry.implementation_cache:
# Save computed value for later
Expand Down
29 changes: 17 additions & 12 deletions odl/tomo/geometry/conebeam.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2014-2019 The ODL contributors
# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
Expand Down Expand Up @@ -32,7 +32,6 @@


class FanBeamGeometry(DivergentBeamGeometry):

"""Fan beam (2d cone beam) geometry.
The source moves on a circle with radius ``src_radius``, and the
Expand All @@ -41,7 +40,7 @@ class FanBeamGeometry(DivergentBeamGeometry):
radii can be chosen as 0, which corresponds to a stationary source
or detector, respectively.
The motion parameter is the 1d rotation angle parameterizing source
The motion parameter is the 1d rotation angle that parametrizes source
and detector positions simultaneously.
In the standard configuration, the detector is perpendicular to the
Expand Down Expand Up @@ -592,7 +591,6 @@ def __getitem__(self, indices):


class ConeBeamGeometry(DivergentBeamGeometry, AxisOrientedGeometry):

"""Cone beam geometry with circular/helical source curve.
The source moves along a spiral oriented along a fixed ``axis``, with
Expand Down Expand Up @@ -1034,7 +1032,7 @@ def det_axes(self, angle):
Parameters
----------
angles : float or `array-like`
angle : float or `array-like`
Angle(s) in radians describing the counter-clockwise rotation
of the detector around `axis`.
Expand Down Expand Up @@ -1466,7 +1464,7 @@ def cone_beam_geometry(space, src_radius, det_radius, num_angles=None,
# used here is (w/2)/(rs+rd) = rho/rs since both are equal to tan(alpha),
# where alpha is the half fan angle.
rs = float(src_radius)
if (rs <= rho):
if rs <= rho:
raise ValueError('source too close to the object, resulting in '
'infinite detector for full coverage')
rd = float(det_radius)
Expand Down Expand Up @@ -1508,6 +1506,10 @@ def cone_beam_geometry(space, src_radius, det_radius, num_angles=None,
det_max_pt = [w / 2, h / 2]
if det_shape is None:
det_shape = [num_px_horiz, num_px_vert]
else:
raise ValueError(
'`space.ndim` must be 2 or 3, got {}'.format(space.ndim)
)

fan_angle = 2 * np.arctan(rho / rs)
if short_scan:
Expand Down Expand Up @@ -1636,7 +1638,7 @@ def helical_geometry(space, src_radius, det_radius, num_turns,
# used here is (w/2)/(rs+rd) = rho/rs since both are equal to tan(alpha),
# where alpha is the half fan angle.
rs = float(src_radius)
if (rs <= rho):
if rs <= rho:
raise ValueError('source too close to the object, resulting in '
'infinite detector for full coverage')
rd = float(det_radius)
Expand Down Expand Up @@ -1684,7 +1686,6 @@ def helical_geometry(space, src_radius, det_radius, num_turns,


class ConeVecGeometry(VecGeometry):

"""Cone beam 2D or 3D geometry defined by a collection of vectors.
This geometry gives maximal flexibility for representing locations
Expand Down Expand Up @@ -1722,6 +1723,8 @@ class ConeVecGeometry(VecGeometry):
linear paths.
"""

# `rotation_matrix` not implemented; reason: missing

@property
def _slice_src(self):
"""Slice for the source position part of `vectors`."""
Expand Down Expand Up @@ -1814,9 +1817,9 @@ def det_to_src(self, mparam, dparam, normalized=True):
Parameters
----------
mpar : `motion_params` element
mparam : `motion_params` element
Motion parameter at which to evaluate.
dpar : `det_params` element
dparam : `det_params` element
Detector parameter at which to evaluate.
normalized : bool, optional
If ``True``, return a normalized (unit) vector.
Expand All @@ -1834,8 +1837,10 @@ def det_to_src(self, mparam, dparam, normalized=True):
raise ValueError('`dparam` {} not in the valid range {}'
''.format(dparam, self.det_params))

vec = (self.src_position(mparam) -
self.det_point_position(mparam, dparam))
vec = (
self.src_position(mparam)
- self.det_point_position(mparam, dparam)
)

if normalized:
# axis = -1 allows this to be vectorized
Expand Down
Loading

0 comments on commit 1ca6fed

Please sign in to comment.