From 54084d1fa31aa58b0a62e77987e0c046275aadab Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Mon, 26 Aug 2024 17:28:48 -0600 Subject: [PATCH] Work on comments --- CHANGES.rst | 2 +- examples/scripts/ct_large_projection.py | 4 +- examples/scripts/ct_multi_tv_admm.py | 6 +- .../scripts/ct_projector_comparison_2d.py | 4 +- .../scripts/ct_projector_comparison_3d.py | 8 +- examples/scripts/ct_tv_admm.py | 4 +- scico/linop/__init__.py | 3 +- scico/linop/xray/__init__.py | 4 +- scico/linop/xray/_xray.py | 310 ++++++++++-------- scico/linop/xray/astra.py | 16 +- scico/test/linop/xray/test_astra.py | 29 +- scico/test/linop/xray/test_xray.py | 67 +++- 12 files changed, 280 insertions(+), 177 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index a32c38544..e50d09129 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -32,7 +32,7 @@ Version 0.0.5 (2023-12-18) • New operators ``operator.DiagonalStack`` and ``operator.VerticalStack``. • Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes - to ``XRayTransform``. + to ``sform``. • Rename ``AbelProjector`` to ``AbelTransform``. • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. • Rename some ``__init__`` parameters of ``linop.DiagonalStack`` and diff --git a/examples/scripts/ct_large_projection.py b/examples/scripts/ct_large_projection.py index d7db963f6..9281b68e6 100644 --- a/examples/scripts/ct_large_projection.py +++ b/examples/scripts/ct_large_projection.py @@ -18,7 +18,7 @@ import jax from scico.examples import create_block_phantom -from scico.linop import Parallel3dProjector, XRayTransform +from scico.linop import Parallel3dProjector N = 1000 num_views = 10 @@ -36,7 +36,7 @@ ) -H = XRayTransform(Parallel3dProjector(in_shape, matrices, det_shape)) +H = Parallel3dProjector(in_shape, matrices, det_shape) proj = H @ x jax.block_until_ready(proj) diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py index 87fc86865..169b9ebae 100644 --- a/examples/scripts/ct_multi_tv_admm.py +++ b/examples/scripts/ct_multi_tv_admm.py @@ -27,7 +27,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot -from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir +from scico.linop.xray import Parallel2dProjector, astra, svmbir from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -54,9 +54,7 @@ "svmbir": svmbir.XRayTransform( x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing ), # svmbir - "scico": XRayTransform( - Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing) - ), # scico + "scico": Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing), # scico } diff --git a/examples/scripts/ct_projector_comparison_2d.py b/examples/scripts/ct_projector_comparison_2d.py index 8171d229b..193526268 100644 --- a/examples/scripts/ct_projector_comparison_2d.py +++ b/examples/scripts/ct_projector_comparison_2d.py @@ -22,7 +22,7 @@ import scico.linop.xray.astra as astra from scico import plot -from scico.linop import Parallel2dProjector, XRayTransform +from scico.linop import Parallel2dProjector from scico.util import Timer """ @@ -46,7 +46,7 @@ projectors = {} timer.start("scico_init") -projectors["scico"] = XRayTransform(Parallel2dProjector((N, N), angles)) +projectors["scico"] = Parallel2dProjector((N, N), angles) timer.stop("scico_init") timer.start("astra_init") diff --git a/examples/scripts/ct_projector_comparison_3d.py b/examples/scripts/ct_projector_comparison_3d.py index d4752a333..6fd36401d 100644 --- a/examples/scripts/ct_projector_comparison_3d.py +++ b/examples/scripts/ct_projector_comparison_3d.py @@ -64,9 +64,6 @@ y_scico = H_scico @ x jax.block_until_ready(y_scico) -with ContextTimer(timer_scico, "first_fwd"): - y_scico = H_scico @ x - with ContextTimer(timer_scico, "avg_fwd"): for _ in range(num_repeats): y_scico = H_scico @ x @@ -99,10 +96,7 @@ with ContextTimer(timer_astra, "first_fwd"): y_astra_from_scico = H_astra_from_scico @ x - jax.block_until_ready(y_scico) - -with ContextTimer(timer_astra, "first_fwd"): - y_astra_from_scico = H_astra_from_scico @ x + jax.block_until_ready(y_astra_from_scico) with ContextTimer(timer_astra, "avg_fwd"): for _ in range(num_repeats): diff --git a/examples/scripts/ct_tv_admm.py b/examples/scripts/ct_tv_admm.py index 4b77eeb07..da80e180c 100644 --- a/examples/scripts/ct_tv_admm.py +++ b/examples/scripts/ct_tv_admm.py @@ -29,7 +29,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot -from scico.linop.xray import Parallel2dProjector, XRayTransform +from scico.linop.xray import Parallel2dProjector from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -46,7 +46,7 @@ """ n_projection = 45 # number of projections angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles -A = XRayTransform(Parallel2dProjector((N, N), angles)) # CT projection operator +A = Parallel2dProjector((N, N), angles) # CT projection operator y = A @ x_gt # sinogram diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index b0149930d..8c465c81a 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -25,7 +25,7 @@ from ._matrix import MatrixOperator from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes from ._util import jacobian, operator_norm, power_iteration, valid_adjoint -from .xray import Parallel2dProjector, Parallel3dProjector, XRayTransform +from .xray import Parallel2dProjector, Parallel3dProjector __all__ = [ "CircularConvolve", @@ -51,7 +51,6 @@ "Sum", "Transpose", "LinearOperator", - "XRayTransform", "Parallel2dProjector", "Parallel3dProjector", "ComposedLinearOperator", diff --git a/scico/linop/xray/__init__.py b/scico/linop/xray/__init__.py index db6afe38c..f00a0a5e3 100644 --- a/scico/linop/xray/__init__.py +++ b/scico/linop/xray/__init__.py @@ -45,12 +45,10 @@ """ -import sys -from ._xray import Parallel2dProjector, Parallel3dProjector, XRayTransform +from ._xray import Parallel2dProjector, Parallel3dProjector __all__ = [ "Parallel2dProjector", "Parallel3dProjector", - "XRayTransform", ] diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index e252ecb06..e58ffd7f8 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -25,30 +25,7 @@ from .._linop import LinearOperator -class XRayTransform(LinearOperator): - """X-ray transform linear operator. - - Wrap an X-ray projector object in a SCICO :class:`LinearOperator`. - """ - - def __init__(self, projector): - r""" - Args: - projector: instance of an X-ray projector object to wrap, - currently the only option is - :class:`Parallel2dProjector` - """ - self.projector = projector - - super().__init__( - input_shape=projector.input_shape, - output_shape=projector.output_shape, - eval_fn=projector.project, - adj_fn=projector.back_project, - ) - - -class Parallel2dProjector: +class Parallel2dProjector(LinearOperator): """Parallel ray, single axis, 2D X-ray projector. This implementation approximates the projection of each rectangular @@ -131,118 +108,147 @@ def __init__( self.y0 = y0 self.dy = 1.0 + super().__init__( + input_shape=self.input_shape, + output_shape=self.output_shape, + eval_fn=self.project, + adj_fn=self.back_project, + ) + def project(self, im): """Compute X-ray projection.""" - return _project(im, self.x0, self.dx, self.y0, self.ny, self.angles) + return Parallel2dProjector._project(im, self.x0, self.dx, self.y0, self.ny, self.angles) def back_project(self, y): """Compute X-ray back projection""" - return _back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) - - -@partial(jax.jit, static_argnames=["ny"]) -def _project(im, x0, dx, y0, ny, angles): - r""" - Args: - im: Input array, (M, N). - x0: (x, y) position of the corner of the pixel im[0,0]. - dx: Pixel side length in x- and y-direction. Units are such - that the detector bins have length 1.0. - y0: Location of the edge of the first detector bin. - ny: Number of detector bins. - angles: (num_angles,) array of angles in radians. Pixels are - projected onto unit vectors pointing in these directions. - """ - nx = im.shape - inds, weights = _calc_weights(x0, dx, nx, angles, y0) - # Handle out of bounds indices. In the .at call, inds >= y0 are - # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. - inds = jnp.where(inds >= 0, inds, ny) - - y = ( - jnp.zeros((len(angles), ny)) - .at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] - .add(im * weights) - ) - - y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights)) - - return y - - -@partial(jax.jit, static_argnames=["nx"]) -def _back_project(y, x0, dx, nx, y0, angles): - r""" - Args: - y: Input projection, (num_angles, N). - x0: (x, y) position of the corner of the pixel im[0,0]. - dx: Pixel side length in x- and y-direction. Units are such - that the detector bins have length 1.0. - nx: Shape of back projection. - y0: Location of the edge of the first detector bin. - angles: (num_angles,) array of angles in radians. Pixels are - projected onto units vectors pointing in these directions. - """ - ny = y.shape[1] - inds, weights = _calc_weights(x0, dx, nx, angles, y0) - # Handle out of bounds indices. In the .at call, inds >= y0 are - # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. - inds = jnp.where(inds >= 0, inds, ny) + return Parallel2dProjector._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) - # the idea: [y[0, inds[0]], y[1, inds[1]], ...] - HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0) - HTy = HTy + jnp.sum( - y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0 - ) + @staticmethod + @partial(jax.jit, static_argnames=["ny"]) + def _project(im, x0, dx, y0, ny, angles): + r""" + Args: + im: Input array, (M, N). + x0: (x, y) position of the corner of the pixel im[0,0]. + dx: Pixel side length in x- and y-direction. Units are such + that the detector bins have length 1.0. + y0: Location of the edge of the first detector bin. + ny: Number of detector bins. + angles: (num_angles,) array of angles in radians. Pixels are + projected onto unit vectors pointing in these directions. + """ + nx = im.shape + inds, weights = Parallel2dProjector._calc_weights(x0, dx, nx, angles, y0) + # Handle out of bounds indices. In the .at call, inds >= y0 are + # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. + inds = jnp.where(inds >= 0, inds, ny) + + y = ( + jnp.zeros((len(angles), ny)) + .at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] + .add(im * weights) + ) - return HTy + y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights)) + return y -@partial(jax.jit, static_argnames=["nx"]) -@partial(jax.vmap, in_axes=(None, None, None, 0, None)) -def _calc_weights(x0, dx, nx, angle, y0): - """ + @staticmethod + @partial(jax.jit, static_argnames=["nx"]) + def _back_project(y, x0, dx, nx, y0, angles): + r""" + Args: + y: Input projection, (num_angles, N). + x0: (x, y) position of the corner of the pixel im[0,0]. + dx: Pixel side length in x- and y-direction. Units are such + that the detector bins have length 1.0. + nx: Shape of back projection. + y0: Location of the edge of the first detector bin. + angles: (num_angles,) array of angles in radians. Pixels are + projected onto units vectors pointing in these directions. + """ + ny = y.shape[1] + inds, weights = Parallel2dProjector._calc_weights(x0, dx, nx, angles, y0) + # Handle out of bounds indices. In the .at call, inds >= y0 are + # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. + inds = jnp.where(inds >= 0, inds, ny) + + # the idea: [y[0, inds[0]], y[1, inds[1]], ...] + HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0) + HTy = HTy + jnp.sum( + y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0 + ) - Args: - x0: Location of the corner of the pixel im[0,0]. - dx: Pixel side length in x- and y-direction. Units are such - that the detector bins have length 1.0. - nx: Input image shape. - angle: (num_angles,) array of angles in radians. Pixels are - projected onto units vectors pointing in these directions. - (This argument is `vmap`ed.) - y0: Location of the edge of the first detector bin. - """ - u = [jnp.cos(angle), jnp.sin(angle)] - Px0 = x0[0] * u[0] + x0[1] * u[1] - y0 - Pdx = [dx[0] * u[0], dx[1] * u[1]] - Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]])) + return HTy - Px = ( - Pxmin - + Pdx[0] * jnp.arange(nx[0]).reshape(-1, 1) - + Pdx[1] * jnp.arange(nx[1]).reshape(1, -1) - ) + @staticmethod + @partial(jax.jit, static_argnames=["nx"]) + @partial(jax.vmap, in_axes=(None, None, None, 0, None)) + def _calc_weights(x0, dx, nx, angle, y0): + """ - # detector bin inds - inds = jnp.floor(Px).astype(int) + Args: + x0: Location of the corner of the pixel im[0,0]. + dx: Pixel side length in x- and y-direction. Units are such + that the detector bins have length 1.0. + nx: Input image shape. + angle: (num_angles,) array of angles in radians. Pixels are + projected onto units vectors pointing in these directions. + (This argument is `vmap`ed.) + y0: Location of the edge of the first detector bin. + """ + u = [jnp.cos(angle), jnp.sin(angle)] + Px0 = x0[0] * u[0] + x0[1] * u[1] - y0 + Pdx = [dx[0] * u[0], dx[1] * u[1]] + Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]])) + + Px = ( + Pxmin + + Pdx[0] * jnp.arange(nx[0]).reshape(-1, 1) + + Pdx[1] * jnp.arange(nx[1]).reshape(1, -1) + ) + + # detector bin inds + inds = jnp.floor(Px).astype(int) + + # weights + Pdx = jnp.array(u) * jnp.array(dx) + diag1 = jnp.abs(Pdx[0] + Pdx[1]) + diag2 = jnp.abs(Pdx[0] - Pdx[1]) + w = jnp.max(jnp.array([diag1, diag2])) + f = jnp.min(jnp.array([diag1, diag2])) + + width = (w + f) / 2 + distance_to_next = 1 - (Px - inds) # always in (0, 1] + weights = jnp.minimum(distance_to_next, width) / width + + return inds, weights + + +class Parallel3dProjector(LinearOperator): + r"""General-purpose, 3D, parallel ray X-ray projector. - # weights - Pdx = jnp.array(u) * jnp.array(dx) - diag1 = jnp.abs(Pdx[0] + Pdx[1]) - diag2 = jnp.abs(Pdx[0] - Pdx[1]) - w = jnp.max(jnp.array([diag1, diag2])) - f = jnp.min(jnp.array([diag1, diag2])) + For each view, the projection geometry is specified by an array + with shape (2, 4) that specifies a :math:`2 \times 3` projection + matrix and a :math:`2 \times 1` offset vector. Denoting the matrix + by :math:`\mathbf{M}` and the offset by :math:`\mathbf{t}`, a voxel at array + index `(i, j, k)` has its center projected to the detector coordinates - width = (w + f) / 2 - distance_to_next = 1 - (Px - inds) # always in (0, 1] - weights = jnp.minimum(distance_to_next, width) / width + .. math:: + \mathbf{M} \begin{bmatrix} + i + \frac{1}{2} \\ j + \frac{1}{2} \\ k + \frac{1}{2} + \end{bmatrix} + \mathbf{t} \,. - return inds, weights + The detector pixel at index `(i, j)` covers detector coordinates + :math:`[i+1) \times [j+1)`. + :meth:`Parallel3dProjector.matrices_from_euler_angles` can help to + make these geometry arrays. -class Parallel3dProjector: - """General-purpose, 3D, parallel ray X-ray projector.""" + + + + """ def __init__( self, @@ -253,7 +259,7 @@ def __init__( r""" Args: input_shape: Shape of input image. - matrices: (num_angles, 2, 4) array of homogeneous projection matrices. + matrices: (num_views, 2, 4) array of homogeneous projection matrices. det_shape: Shape of detector. """ @@ -262,10 +268,21 @@ def __init__( self.det_shape = det_shape self.output_shape = (len(matrices), *det_shape) + super().__init__( + input_shape=self.input_shape, + output_shape=self.output_shape, + eval_fn=self.project, + adj_fn=self.back_project, + ) + def project(self, im): """Compute X-ray projection.""" return Parallel3dProjector._project(im, self.matrices, self.det_shape) + def back_project(self, proj): + """Compute X-ray back projection""" + return Parallel3dProjector._back_project(proj, self.matrices, self.input_shape) + @staticmethod def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike: r""" @@ -312,10 +329,6 @@ def _project_single( proj = proj.at[ul_ind[0] + 1, ul_ind[1] + 1].add(lr_weight * im, mode="drop") return proj - def back_project(self, proj): - """Compute X-ray back projection""" - return Parallel3dProjector._back_project(proj, self.matrices, self.input_shape) - @staticmethod def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike: r""" @@ -385,10 +398,18 @@ def _calc_weights(input_shape, matrix, output_shape, slice_offset: int = 0): @staticmethod def matrices_from_euler_angles( - input_shape: Shape, output_shape: Shape, seq: str, angles: ArrayLike, degrees: bool = False + input_shape: Shape, + output_shape: Shape, + seq: str, + angles: ArrayLike, + degrees: bool = False, + voxel_spacing: ArrayLike = None, + det_spacing: ArrayLike = None, ): """ - Create a set of projection matrices from Euler angles. + Create a set of projection matrices from Euler angles. The + input voxels will undergo the specified rotation and then be + projected onto the global xy-plane. Args: input_shape: Shape of input image. @@ -396,22 +417,39 @@ def matrices_from_euler_angles( str: Sequence of axes for rotation. Up to 3 characters belonging to the set {'X', 'Y', 'Z'} for intrinsic rotations, or {'x', 'y', 'z'} for extrinsic rotations. Extrinsic and intrinsic rotations cannot be mixed in one function call. - angles: (num_angles, N), N = 1, 2, or 3 Euler angles. + angles: (num_views, N), N = 1, 2, or 3 Euler angles. degrees: If ``True``, angles are in degrees, otherwise radians. Default: ``True``, radians. + voxel_spacing: (3,) array giving the spacing of image + voxels. Default: `[1.0, 1.0, 1.0]`. Experimental. + det_spacing: (2,) array giving the spacing of detector + pixels. Default: `[1.0, 1.0]`. Experimental. + Returns: - (num_angles, 2, 4) array of homogeneous projection matrices. + (num_views, 2, 4) array of homogeneous projection matrices. """ + if voxel_spacing is None: + voxel_spacing = np.ones(3) + + if det_spacing is None: + det_spacing = np.ones(2) + # make projection matrix: form a rotation matrix and chop off the last row - matrices = jnp.stack( - [Rotation.from_euler(seq, angles_i, degrees=degrees).as_matrix() for angles_i in angles] - ) - matrices = matrices[:, :2, :] + matrices = Rotation.from_euler(seq, angles, degrees=degrees).as_matrix() + matrices = matrices[:, :2, :] # (num_views, 2, 3) + + # handle scaling + M_voxel = np.diag(voxel_spacing) # (3, 3) + M_det = np.diag(1 / np.array(det_spacing)) # (2, 2) + + # idea: M_det * M * M_voxel, but with a leading batch dimension + matrices = np.einsum("vmn,nn->vmn", matrices, M_voxel) + matrices = np.einsum("mm,vmn->vmn", M_det, matrices) - # add translation - x0 = jnp.array(input_shape) / 2 - t = -jnp.tensordot(matrices, x0, axes=[2, 0]) + jnp.array(output_shape) / 2 - matrices = jnp.concatenate((matrices, t[..., np.newaxis]), axis=2) + # add translation to line up the centers + x0 = np.array(input_shape) / 2 + t = -np.einsum("vmn,n->vm", matrices, x0) + np.array(output_shape) / 2 + matrices = np.concatenate((matrices, t[..., np.newaxis]), axis=2) return matrices diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index 4d5692645..ab75f27dd 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -83,13 +83,21 @@ def convert_from_scico_geometry( # ray is perpendicular to projection axes ray = np.cross(matrices[:, 0, :3], matrices[:, 1, :3]) # detector center comes from lifting the center index to 3D - y_center = np.array(det_shape) / 2 + y_center = (np.array(det_shape) - 1) / 2 x_center = ( - np.einsum("...mn,n->...m", matrices[..., :3], np.array(in_shape) / 2) + matrices[..., 3] + np.einsum("...mn,n->...m", matrices[..., :3], (np.array(in_shape) - 1) / 2) + + matrices[..., 3] ) d = np.einsum("...mn,...m->...n", matrices[..., :3], y_center - x_center) # (V, 2, 3) x (V, 2) - u = -matrices[:, 1, :3] - v = -matrices[:, 0, :3] + u = matrices[:, 1, :3] + v = matrices[:, 0, :3] + + # handle different axis conventions + ray = ray[:, [2, 1, 0]] + d = d[:, [2, 1, 0]] + u = u[:, [2, 1, 0]] + v = v[:, [2, 1, 0]] + vectors = np.concatenate((ray, d, u, v), axis=1) # (v, 12) return vectors diff --git a/scico/test/linop/xray/test_astra.py b/scico/test/linop/xray/test_astra.py index 72d6ec7d2..e575fac06 100644 --- a/scico/test/linop/xray/test_astra.py +++ b/scico/test/linop/xray/test_astra.py @@ -104,7 +104,9 @@ def test_grad(testobj): A = testobj.A x = testobj.x g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2 - np.testing.assert_allclose(scico.grad(g)(x), 2 * A.adj(A(x)), rtol=get_tol()) + np.testing.assert_allclose( + scico.grad(g)(x), 2 * A.adj(A(x)), atol=get_tol() * x.max(), rtol=np.inf + ) def test_adjoint_grad(testobj): @@ -253,6 +255,31 @@ def test_project_coords(test_geometry): np.testing.assert_array_equal(x_proj_gt, x_proj) +def test_convert_to_scico_geometry(test_geometry): + """ + Basic regression test, `test_project_coords` tests the logic. + """ + vol_geom, proj_geom = test_geometry + matrices_truth = scico.linop.xray.astra.convert_to_scico_geometry(vol_geom, proj_geom) + truth = np.array([[[0.0, 1.0, 0.0, -2.0], [0.0, 0.0, 1.0, -1.0]]]) + np.testing.assert_allclose(matrices_truth, truth) + + +def test_convert_from_scico_geometry(test_geometry): + """ + Basic regression test, `test_project_coords` tests the logic. + """ + in_shape = (30, 31, 32) + matrices = np.array([[[0.0, 1.0, 0.0, -2.0], [0.0, 0.0, 1.0, -1.0]]]) + det_shape = (31, 32) + vectors = scico.linop.xray.astra.convert_from_scico_geometry(in_shape, matrices, det_shape) + + _, proj_geom_truth = test_geometry + # skip testing element 5, as it is detector center along the ray and doesn't matter + np.testing.assert_allclose(vectors[0, :5], proj_geom_truth["Vectors"][0, :5]) + np.testing.assert_allclose(vectors[0, 6:], proj_geom_truth["Vectors"][0, 6:]) + + def test_ensure_writeable(): assert isinstance(_ensure_writeable(np.ones((2, 1))), np.ndarray) assert isinstance(_ensure_writeable(snp.ones((2, 1))), np.ndarray) diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index 6d9c2ba39..5adbd84ea 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -1,9 +1,11 @@ +import numpy as np + import jax.numpy as jnp import pytest import scico -from scico.linop import Parallel2dProjector, XRayTransform +from scico.linop import Parallel2dProjector, Parallel3dProjector @pytest.mark.filterwarnings("error") @@ -11,22 +13,18 @@ def test_init(): input_shape = (3, 3) # no warning with default settings, even at 45 degrees - H = XRayTransform(Parallel2dProjector(input_shape, jnp.array([jnp.pi / 4]))) + H = Parallel2dProjector(input_shape, jnp.array([jnp.pi / 4])) # no warning if we project orthogonally with oversized pixels - H = XRayTransform(Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1, 1]))) + H = Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1, 1])) # warning if the projection angle changes with pytest.warns(UserWarning): - H = XRayTransform( - Parallel2dProjector(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1])) - ) + H = Parallel2dProjector(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1])) # warning if the pixels get any larger with pytest.warns(UserWarning): - H = XRayTransform( - Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1])) - ) + H = Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1])) def test_apply(): @@ -37,13 +35,13 @@ def test_apply(): angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) # general projection - H = XRayTransform(Parallel2dProjector(x.shape, angles)) + H = Parallel2dProjector(x.shape, angles) y = H @ x assert y.shape[0] == (num_angles) # fixed det_count det_count = 14 - H = XRayTransform(Parallel2dProjector(x.shape, angles, det_count=det_count)) + H = Parallel2dProjector(x.shape, angles, det_count=det_count) y = H @ x assert y.shape[1] == det_count @@ -56,7 +54,7 @@ def test_apply_adjoint(): angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) # general projection - H = XRayTransform(Parallel2dProjector(x.shape, angles)) + H = Parallel2dProjector(x.shape, angles) y = H @ x assert y.shape[0] == (num_angles) @@ -68,6 +66,49 @@ def test_apply_adjoint(): # fixed det_length det_count = 14 - H = XRayTransform(Parallel2dProjector(x.shape, angles, det_count=det_count)) + H = Parallel2dProjector(x.shape, angles, det_count=det_count) y = H @ x assert y.shape[1] == det_count + + +def test_3d_scaling(): + x = jnp.zeros((4, 4, 1)) + x = x.at[1:3, 1:3, 0].set(1.0) + + input_shape = x.shape + output_shape = x.shape[:2] + + # default spacing + M = Parallel3dProjector.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0]) + H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape) + # fmt: off + truth = jnp.array( + [[[0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0]]] + ) # fmt: on + np.testing.assert_allclose(H @ x, truth) + + # bigger voxels in the x (first index) direction + M = Parallel3dProjector.matrices_from_euler_angles( + input_shape, output_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0] + ) + H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape) + # fmt: off + truth = jnp.array( + [[[0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ]]] + ) # fmt: on + np.testing.assert_allclose(H @ x, truth) + + # bigger detector pixels in the x (first index) direction + M = Parallel3dProjector.matrices_from_euler_angles( + input_shape, output_shape, "X", [0.0], det_spacing=[2.0, 1.0] + ) + H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape) + # fmt: off + truth = None # fmt: on # TODO: Check this case more closely. + # np.testing.assert_allclose(H @ x, truth)