From 27e2aecd9e69593e3640b53588e659850eb6fc6a Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 2 Nov 2023 14:53:03 -0600 Subject: [PATCH] Various improvements related to the new X-ray transform implemementation. (#461) * Rename TomographicProjector to XRayTransform * Remove markup from exception messages and other minor edits * Rename modules * Rename files and rename TomographicProjector to XRayTransform * Rename AbelProjector to AbelTransform * Docs edits * Update change summary * Overlooked rename changes * Replace Radon transform label * Update submodule * Renaming of some CT projectors (#453) * Add 2d projector and code to time it * Clean up * Add back projection * Add test * Add timing results to example * Start to add new example * Update data * Address mypy * Address isort * Try to fix tables in notebook * Update submodule * Rename XRayProject to XRayTransform * Minor edits --------- Co-authored-by: Michael McCann Co-authored-by: Michael-T-McCann Co-authored-by: Michael McCann * Restructure X-ray transform modules * Update submodule * Adjust angles to be equivalent between scico and astra projections * Clean up output * New example script * Minor edits * Shorten class name * Clarify Parallel2dProjector angles * Docs edits * Remove problematic jit * New example script * Docs improvement * Rename parameter * Docs improvement * Typo fix * Add noise * Add example script * Change noise level * Update notebooks * Docs fix * Update submodule * Add warning to api docs * Add overlooked change summary entry * Remove unintentionally added file * Remove unintentionally added files * Remove unintentionally added files * Address review comment * Remove unintentionally added file * Fix docs typo * Update submodule * Update submodule --------- Co-authored-by: Michael McCann Co-authored-by: Michael-T-McCann Co-authored-by: Michael McCann --- CHANGES.rst | 5 + data | 2 +- docs/source/examples.rst | 3 +- docs/source/inverse.rst | 6 +- docs/source/notes.rst | 22 ++- examples/scripts/README.rst | 7 +- examples/scripts/ct_abel_tv_admm.py | 4 +- examples/scripts/ct_abel_tv_admm_tune.py | 6 +- examples/scripts/ct_astra_3d_tv_admm.py | 12 +- examples/scripts/ct_astra_modl_train_foam2.py | 6 +- examples/scripts/ct_astra_noreg_pcg.py | 8 +- examples/scripts/ct_astra_odp_train_foam2.py | 6 +- examples/scripts/ct_astra_tv_admm.py | 10 +- examples/scripts/ct_astra_weighted_tv_admm.py | 14 +- .../ct_fan_svmbir_ppp_bm3d_admm_prox.py | 6 +- examples/scripts/ct_multi_cs_tv_admm.py | 162 ++++++++++++++++++ examples/scripts/ct_multi_tv_admm.py | 156 +++++++++++++++++ examples/scripts/ct_projector_comparison.py | 150 ++++++++-------- .../scripts/ct_svmbir_ppp_bm3d_admm_cg.py | 4 +- .../scripts/ct_svmbir_ppp_bm3d_admm_prox.py | 6 +- examples/scripts/ct_svmbir_tv_multi.py | 6 +- examples/scripts/ct_tv_admm.py | 139 +++++++++++++++ examples/scripts/index.rst | 3 +- scico/flax/examples/data_generation.py | 4 +- scico/linop/__init__.py | 8 +- scico/linop/abel.py | 12 +- scico/linop/xray/__init__.py | 26 +++ scico/linop/{ => xray}/_xray.py | 49 +++--- scico/linop/{radon_astra.py => xray/astra.py} | 48 +++--- .../linop/{radon_svmbir.py => xray/svmbir.py} | 30 ++-- scico/operator/_operator.py | 2 +- scico/optimize/_admmaux.py | 1 - scico/test/flax/test_inv.py | 4 +- scico/test/linop/test_abel.py | 12 +- scico/test/linop/test_xray.py | 26 --- .../test_astra.py} | 16 +- .../test_svmbir.py} | 6 +- scico/test/linop/xray/test_xray.py | 26 +++ 38 files changed, 761 insertions(+), 252 deletions(-) create mode 100644 examples/scripts/ct_multi_cs_tv_admm.py create mode 100644 examples/scripts/ct_multi_tv_admm.py create mode 100644 examples/scripts/ct_tv_admm.py create mode 100644 scico/linop/xray/__init__.py rename scico/linop/{ => xray}/_xray.py (66%) rename scico/linop/{radon_astra.py => xray/astra.py} (85%) rename scico/linop/{radon_svmbir.py => xray/svmbir.py} (95%) delete mode 100644 scico/test/linop/test_xray.py rename scico/test/linop/{test_radon_astra.py => xray/test_astra.py} (87%) rename scico/test/linop/{test_radon_svmbir.py => xray/test_svmbir.py} (98%) create mode 100644 scico/test/linop/xray/test_xray.py diff --git a/CHANGES.rst b/CHANGES.rst index a413f53cb..7dc048a0a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,11 @@ SCICO Release Notes Version 0.0.5 (unreleased) ---------------------------- +• New integrated Radon/X-ray transform ``linop.XRayTransform``. +• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and + ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes + to ``XRayTransform``. +• Rename ``AbelProjector`` to ``AbelTransform``. • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. • Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19. diff --git a/data b/data index b63329c3b..23b76fd4f 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit b63329c3b1b89fbebc4cb3ec892badee0b989e40 +Subproject commit 23b76fd4fa092c186689af1dba9d058d6dc433bc diff --git a/docs/source/examples.rst b/docs/source/examples.rst index c6cc70e9a..291eeb5a6 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -36,7 +36,8 @@ Computed Tomography examples/ct_astra_odp_train_foam2 examples/ct_astra_unet_train_foam2 examples/ct_projector_comparison - + examples/ct_multi_cs_tv_admm + examples/ct_multi_tv_admm Deconvolution ^^^^^^^^^^^^^ diff --git a/docs/source/inverse.rst b/docs/source/inverse.rst index 1542f4f53..696657412 100644 --- a/docs/source/inverse.rst +++ b/docs/source/inverse.rst @@ -47,15 +47,15 @@ SCICO provides the :class:`.Operator` and :class:`.LinearOperator` classes, which may be subclassed by users, in order to implement the forward operator, :math:`A`. It also has several built-in operators, most of which are linear, e.g., finite convolutions, discrete Fourier -transforms, optical propagators, Abel transforms, and Radon -transforms. For example, +transforms, optical propagators, Abel transforms, and X-ray transforms +(the same as Radon transforms in 2D). For example, .. code:: python input_shape = (512, 512) angles = np.linspace(0, 2 * np.pi, 180, endpoint=False) channels = 512 - A = scico.linop.radon_svmbir.ParallelBeamProjector(input_shape, angles, channels) + A = scico.linop.xray.svmbir.XRayTransform(input_shape, angles, channels) defines a tomographic projection operator. diff --git a/docs/source/notes.rst b/docs/source/notes.rst index 08a63be0a..2986f5adc 100644 --- a/docs/source/notes.rst +++ b/docs/source/notes.rst @@ -111,13 +111,25 @@ via interfaces to the `bm3d `__ and when the full benefits of JAX-based code are required. -Tomographic Projectors ----------------------- - -The :class:`.radon_svmbir.TomographicProjector` class is implemented +Tomographic Projectors/Radon Transforms +--------------------------------------- + +Note that the tomographic projections that are frequently referred +to as Radon transforms are referred to as X-ray transforms in SCICO. +While the Radon transform is far more well-known than the X-ray +transform, which is the same as the Radon transform for projections +in two dimensions, these two transform differ in higher numbers of +dimensions, and it is the X-ray transform that is the appropriate +mathematical model for beam attenuation based imaging in three or +more dimensions. + +SCICO includes three different implementations of X-ray transforms. +Of these, :class:`.linop.XRayTransform` is an integral component of +SCICO, while the other two depend on external packages. +The :class:`.xray.svmbir.XRayTransform` class is implemented via an interface to the `svmbir `__ package. The -:class:`.radon_astra.TomographicProjector` class is implemented via an +:class:`.xray.astra.XRayTransform` class is implemented via an interface to the `ASTRA toolbox `__. This toolbox does provide some GPU acceleration support, but efficiency is expected to be lower than diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index 3910e9671..1cdb0e8e1 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -36,8 +36,11 @@ Computed Tomography `ct_astra_unet_train_foam2.py `_ CT Training and Reconstructions with UNet `ct_projector_comparison.py `_ - X-ray Projector Comparison - + X-ray Transform Comparison + `ct_multi_cs_tv_admm.py `_ + TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram) + `ct_multi_tv_admm.py `_ + TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) Deconvolution ^^^^^^^^^^^^^ diff --git a/examples/scripts/ct_abel_tv_admm.py b/examples/scripts/ct_abel_tv_admm.py index 97ca30169..2adc141ce 100644 --- a/examples/scripts/ct_abel_tv_admm.py +++ b/examples/scripts/ct_abel_tv_admm.py @@ -26,7 +26,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_circular_phantom -from scico.linop.abel import AbelProjector +from scico.linop.abel import AbelTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -40,7 +40,7 @@ """ Set up the forward operator and create a test measurement. """ -A = AbelProjector(x_gt.shape) +A = AbelTransform(x_gt.shape) y = A @ x_gt np.random.seed(12345) y = y + np.random.normal(size=y.shape).astype(np.float32) diff --git a/examples/scripts/ct_abel_tv_admm_tune.py b/examples/scripts/ct_abel_tv_admm_tune.py index ab7ffd18f..c60ade412 100644 --- a/examples/scripts/ct_abel_tv_admm_tune.py +++ b/examples/scripts/ct_abel_tv_admm_tune.py @@ -38,7 +38,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_circular_phantom -from scico.linop.abel import AbelProjector +from scico.linop.abel import AbelTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.ray import tune @@ -52,7 +52,7 @@ """ Set up the forward operator and create a test measurement. """ -A = AbelProjector(x_gt.shape) +A = AbelTransform(x_gt.shape) y = A @ x_gt np.random.seed(12345) y = y + np.random.normal(size=y.shape).astype(np.float32) @@ -84,7 +84,7 @@ def setup(self, config, x_gt, x0, y): # Get arrays passed by tune call. self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y) # Set up problem to be solved. - self.A = AbelProjector(self.x_gt.shape) + self.A = AbelTransform(self.x_gt.shape) self.f = loss.SquaredL2Loss(y=self.y, A=self.A) self.C = linop.FiniteDifference(input_shape=self.x_gt.shape) self.reset_config(config) diff --git a/examples/scripts/ct_astra_3d_tv_admm.py b/examples/scripts/ct_astra_3d_tv_admm.py index 3abb9ae89..bb64ea61b 100644 --- a/examples/scripts/ct_astra_3d_tv_admm.py +++ b/examples/scripts/ct_astra_3d_tv_admm.py @@ -15,9 +15,9 @@ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ -where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, $C$ is -a 3D finite difference operator, and $\mathbf{x}$ is the desired -image. +where $A$ is the X-ray transform (the CT forward projection operator), +$\mathbf{y}$ is the sinogram, $C$ is a 3D finite difference operator, +and $\mathbf{x}$ is the desired image. """ @@ -28,7 +28,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_tangle_phantom -from scico.linop.radon_astra import TomographicProjector +from scico.linop.xray.astra import XRayTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -43,9 +43,7 @@ n_projection = 10 # number of projections angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = TomographicProjector( - tangle.shape, [1.0, 1.0], [Nz, max(Nx, Ny)], angles -) # Radon transform operator +A = XRayTransform(tangle.shape, [1.0, 1.0], [Nz, max(Nx, Ny)], angles) # CT projection operator y = A @ tangle # sinogram diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_astra_modl_train_foam2.py index 4214888ab..66a137e9c 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_astra_modl_train_foam2.py @@ -54,7 +54,7 @@ from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal -from scico.linop.radon_astra import TomographicProjector +from scico.linop.xray.astra import XRayTransform """ Prepare parallel processing. Set an arbitrary processor count (only @@ -81,12 +81,12 @@ Build CT projection operator. """ angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = TomographicProjector( +A = XRayTransform( input_shape=(N, N), detector_spacing=1, det_count=N, angles=angles, -) # Radon transform operator +) # CT projection operator A = (1.0 / N) * A # normalized diff --git a/examples/scripts/ct_astra_noreg_pcg.py b/examples/scripts/ct_astra_noreg_pcg.py index fc5dd6f08..9e78f59fd 100644 --- a/examples/scripts/ct_astra_noreg_pcg.py +++ b/examples/scripts/ct_astra_noreg_pcg.py @@ -15,8 +15,8 @@ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 \;,$$ -where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, and -$\mathbf{x}$ is the reconstructed image. +where $A$ is the X-ray transform (the CT forward projection operator), +$\mathbf{y}$ is the sinogram, and $\mathbf{x}$ is the reconstructed image. """ from time import time @@ -29,7 +29,7 @@ from scico import loss, plot from scico.linop import CircularConvolve -from scico.linop.radon_astra import TomographicProjector +from scico.linop.xray.astra import XRayTransform from scico.solver import cg """ @@ -45,7 +45,7 @@ """ n_projection = N # matches the phantom size so this is not few-view CT angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = 1 / N * TomographicProjector(x_gt.shape, 1, N, angles) # Radon transform operator +A = 1 / N * XRayTransform(x_gt.shape, 1, N, angles) # CT projection operator y = A @ x_gt # sinogram diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_astra_odp_train_foam2.py index 77b241dd2..bb9b1a54b 100644 --- a/examples/scripts/ct_astra_odp_train_foam2.py +++ b/examples/scripts/ct_astra_odp_train_foam2.py @@ -58,7 +58,7 @@ from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal -from scico.linop.radon_astra import TomographicProjector +from scico.linop.xray.astra import XRayTransform """ Prepare parallel processing. Set an arbitrary processor count (only @@ -85,12 +85,12 @@ Build CT projection operator. """ angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = TomographicProjector( +A = XRayTransform( input_shape=(N, N), detector_spacing=1, det_count=N, angles=angles, -) # Radon transform operator +) # CT projection operator A = (1.0 / N) * A # normalized diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py index 1f12f7ab3..69520f872 100644 --- a/examples/scripts/ct_astra_tv_admm.py +++ b/examples/scripts/ct_astra_tv_admm.py @@ -14,9 +14,9 @@ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ -where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, $C$ is -a 2D finite difference operator, and $\mathbf{x}$ is the desired -image. +where $A$ is the X-ray transform (the CT forward projection operator), +$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and +$\mathbf{x}$ is the desired image. """ import numpy as np @@ -26,7 +26,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot -from scico.linop.radon_astra import TomographicProjector +from scico.linop.xray.astra import XRayTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -44,7 +44,7 @@ """ n_projection = 45 # number of projections angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = TomographicProjector(x_gt.shape, 1, N, angles) # Radon transform operator +A = XRayTransform(x_gt.shape, 1, N, angles) # CT projection operator y = A @ x_gt # sinogram diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py index 3f14b828d..b3dc439c2 100644 --- a/examples/scripts/ct_astra_weighted_tv_admm.py +++ b/examples/scripts/ct_astra_weighted_tv_admm.py @@ -14,11 +14,11 @@ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ -where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, the norm -weighting $W$ is chosen so that the weighted norm is an approximation to -the Poisson negative log likelihood :cite:`sauer-1993-local`, $C$ is -a 2D finite difference operator, and $\mathbf{x}$ is the desired -image. +where $A$ is the X-ray transform (the CT forward projection), +$\mathbf{y}$ is the sinogram, the norm weighting $W$ is chosen so that +the weighted norm is an approximation to the Poisson negative log +likelihood :cite:`sauer-1993-local`, $C$ is a 2D finite difference +operator, and $\mathbf{x}$ is the desired image. """ import numpy as np @@ -27,7 +27,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot -from scico.linop.radon_astra import TomographicProjector +from scico.linop.xray.astra import XRayTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -51,7 +51,7 @@ 𝛼 = 1e-2 # attenuation coefficient angles = np.linspace(0, 2 * np.pi, n_projection) # evenly spaced projection angles -A = TomographicProjector(x_gt.shape, 1.0, N, angles) # Radon transform operator +A = XRayTransform(x_gt.shape, 1.0, N, angles) # CT projection operator y_c = A @ x_gt # sinogram diff --git a/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py index 1e334ada5..80299a1ae 100644 --- a/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py @@ -35,7 +35,7 @@ from scico import metric, plot from scico.functional import BM3D from scico.linop import Diagonal, Identity -from scico.linop.radon_svmbir import SVMBIRExtendedLoss, TomographicProjector +from scico.linop.xray.svmbir import SVMBIRExtendedLoss, XRayTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -65,7 +65,7 @@ dist_source_detector = 1500.0 magnification = 1.2 -A_fan = TomographicProjector( +A_fan = XRayTransform( x_gt.shape, angles, num_channels, @@ -73,7 +73,7 @@ dist_source_detector=dist_source_detector, magnification=magnification, ) -A_parallel = TomographicProjector( +A_parallel = XRayTransform( x_gt.shape, angles, num_channels, diff --git a/examples/scripts/ct_multi_cs_tv_admm.py b/examples/scripts/ct_multi_cs_tv_admm.py new file mode 100644 index 000000000..f7fcfcc10 --- /dev/null +++ b/examples/scripts/ct_multi_cs_tv_admm.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram) +=================================================================================== + +This example demonstrates solution of a sparse-view CT reconstruction +problem with isotropic total variation (TV) regularization + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} + \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ + +where $A$ is the X-ray transform (the CT forward projection operator), +$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and +$\mathbf{x}$ is the desired image. The solution is computed and compared +for all three 2D CT projectors available in scico, using a sinogram +computed with the svmbir projector. +""" + +import numpy as np + +import jax + +from xdesign import Foam, discrete_phantom + +import scico.numpy as snp +from scico import functional, linop, loss, metric, plot +from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.util import device_info + +""" +Create a ground truth image. +""" +N = 512 # phantom size +np.random.seed(1234) +x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) +x_gt = jax.device_put(x_gt) + + +""" +Define CT geometry and construct array of (approximately) equivalent projectors. +""" +n_projection = 45 # number of projections +angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +projectors = { + "astra": astra.XRayTransform(x_gt.shape, 1, N, angles - np.pi / 2.0), # astra + "svmbir": svmbir.XRayTransform(x_gt.shape, 2 * np.pi - angles, N), # svmbir + "scico": XRayTransform(Parallel2dProjector((N, N), angles, det_count=N)), # scico +} + + +""" +Compute common sinogram using svmbir projector. +""" +A = projectors["svmbir"] +noise = np.random.normal(size=(n_projection, N)).astype(np.float32) +y = A @ x_gt + 2.0 * noise + + +""" +Solve the same problem using the different projectors. +""" +print(f"Solving on {device_info()}") +x_rec, hist = {}, {} +for p in ("astra", "svmbir", "scico"): + print(f"\nSolving with {p} projector") + + # Set up ADMM solver object. + λ = 2e0 # L1 norm regularization parameter + ρ = 5e0 # ADMM penalty parameter + maxiter = 25 # number of ADMM iterations + cg_tol = 1e-4 # CG relative tolerance + cg_maxiter = 25 # maximum CG iterations per ADMM iteration + + # The append=0 option makes the results of horizontal and vertical + # finite differences the same shape, which is required for the L21Norm, + # which is used so that g(Cx) corresponds to isotropic TV. + C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) + g = λ * functional.L21Norm() + A = projectors[p] + f = loss.SquaredL2Loss(y=y, A=A) + x0 = snp.clip(A.T(y), 0, 1.0) + + # Set up the solver. + solver = ADMM( + f=f, + g_list=[g], + C_list=[C], + rho_list=[ρ], + x0=x0, + maxiter=maxiter, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), + itstat_options={"display": True, "period": 5}, + ) + + # Run the solver. + solver.solve() + hist[p] = solver.itstat_object.history(transpose=True) + x_rec[p] = snp.clip(solver.x, 0, 1.0) + + +""" +Display sinogram. +""" +fig, ax = plot.subplots(nrows=1, ncols=1, figsize=(15, 3)) +plot.imview(y, title="sinogram", fig=fig, ax=ax) +fig.show() + + +""" +Plot convergence statistics. +""" +fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(12, 5)) +plot.plot( + np.vstack([hist[p].Objective for p in projectors.keys()]).T, + title="Objective function", + xlbl="Iteration", + ylbl="Functional value", + lgnd=projectors.keys(), + fig=fig, + ax=ax[0], +) +plot.plot( + np.vstack([hist[p].Prml_Rsdl for p in projectors.keys()]).T, + ptyp="semilogy", + title="Primal Residual", + xlbl="Iteration", + fig=fig, + ax=ax[1], +) +plot.plot( + np.vstack([hist[p].Dual_Rsdl for p in projectors.keys()]).T, + ptyp="semilogy", + title="Dual Residual", + xlbl="Iteration", + fig=fig, + ax=ax[2], +) +fig.show() + + +""" +Show the recovered images. +""" +fig, ax = plot.subplots(nrows=1, ncols=4, figsize=(15, 5)) +plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0]) +for n, p in enumerate(projectors.keys()): + plot.imview( + x_rec[p], + title="%s SNR: %.2f (dB)" % (p, metric.snr(x_gt, x_rec[p])), + fig=fig, + ax=ax[n + 1], + ) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py new file mode 100644 index 000000000..8ed284c8d --- /dev/null +++ b/examples/scripts/ct_multi_tv_admm.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) +================================================================== + +This example demonstrates solution of a sparse-view CT reconstruction +problem with isotropic total variation (TV) regularization + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} + \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ + +where $A$ is the X-ray transform (the CT forward projection operator), +$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and +$\mathbf{x}$ is the desired image. The solution is computed and compared +for all three 2D CT projectors available in scico. +""" + +import numpy as np + +import jax + +from xdesign import Foam, discrete_phantom + +import scico.numpy as snp +from scico import functional, linop, loss, metric, plot +from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.util import device_info + +""" +Create a ground truth image. +""" +N = 512 # phantom size +np.random.seed(1234) +x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) +x_gt = jax.device_put(x_gt) + + +""" +Define CT geometry and construct array of (approximately) equivalent projectors. +""" +n_projection = 45 # number of projections +angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +projectors = { + "astra": astra.XRayTransform(x_gt.shape, 1, N, angles - np.pi / 2.0), # astra + "svmbir": svmbir.XRayTransform(x_gt.shape, 2 * np.pi - angles, N), # svmbir + "scico": XRayTransform(Parallel2dProjector((N, N), angles, det_count=N)), # scico +} + + +""" +Solve the same problem using the different projectors. +""" +print(f"Solving on {device_info()}") +y, x_rec, hist = {}, {}, {} +noise = np.random.normal(size=(n_projection, N)).astype(np.float32) +for p in ("astra", "svmbir", "scico"): + print(f"\nSolving with {p} projector") + A = projectors[p] + y[p] = A @ x_gt + 2.0 * noise # sinogram + + # Set up ADMM solver object. + λ = 2e0 # L1 norm regularization parameter + ρ = 5e0 # ADMM penalty parameter + maxiter = 25 # number of ADMM iterations + cg_tol = 1e-4 # CG relative tolerance + cg_maxiter = 25 # maximum CG iterations per ADMM iteration + + # The append=0 option makes the results of horizontal and vertical + # finite differences the same shape, which is required for the L21Norm, + # which is used so that g(Cx) corresponds to isotropic TV. + C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) + g = λ * functional.L21Norm() + f = loss.SquaredL2Loss(y=y[p], A=A) + x0 = snp.clip(A.T(y[p]), 0, 1.0) + + # Set up the solver. + solver = ADMM( + f=f, + g_list=[g], + C_list=[C], + rho_list=[ρ], + x0=x0, + maxiter=maxiter, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), + itstat_options={"display": True, "period": 5}, + ) + + # Run the solver. + solver.solve() + hist[p] = solver.itstat_object.history(transpose=True) + x_rec[p] = snp.clip(solver.x, 0, 1.0) + + +""" +Compare sinograms. +""" +fig, ax = plot.subplots(nrows=3, ncols=1, figsize=(15, 10)) +for idx, name in enumerate(projectors.keys()): + plot.imview(y[name], title=f"{name} sinogram", cbar=None, fig=fig, ax=ax[idx]) +fig.show() + + +""" +Plot convergence statistics. +""" +fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(12, 5)) +plot.plot( + np.vstack([hist[p].Objective for p in projectors.keys()]).T, + title="Objective function", + xlbl="Iteration", + ylbl="Functional value", + lgnd=projectors.keys(), + fig=fig, + ax=ax[0], +) +plot.plot( + np.vstack([hist[p].Prml_Rsdl for p in projectors.keys()]).T, + ptyp="semilogy", + title="Primal Residual", + xlbl="Iteration", + fig=fig, + ax=ax[1], +) +plot.plot( + np.vstack([hist[p].Dual_Rsdl for p in projectors.keys()]).T, + ptyp="semilogy", + title="Dual Residual", + xlbl="Iteration", + fig=fig, + ax=ax[2], +) +fig.show() + + +""" +Show the recovered images. +""" +fig, ax = plot.subplots(nrows=1, ncols=4, figsize=(15, 5)) +plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0]) +for n, p in enumerate(projectors.keys()): + plot.imview( + x_rec[p], + title="%s SNR: %.2f (dB)" % (p, metric.snr(x_gt, x_rec[p])), + fig=fig, + ax=ax[n + 1], + ) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/ct_projector_comparison.py b/examples/scripts/ct_projector_comparison.py index 94b8d1d2c..58a31d4cd 100644 --- a/examples/scripts/ct_projector_comparison.py +++ b/examples/scripts/ct_projector_comparison.py @@ -6,11 +6,11 @@ r""" -X-ray Projector Comparison +X-ray Transform Comparison ========================== -This example compares SCICO's native X-ray projection algorithm to that -of the ASTRA Toolbox. +This example compares SCICO's native X-ray transform algorithm +to that of the ASTRA toolbox. """ import numpy as np @@ -20,9 +20,9 @@ from xdesign import Foam, discrete_phantom +import scico.linop.xray.astra as astra from scico import plot -from scico.linop import ParallelFixedAxis2dProjector, XRayProject -from scico.linop.radon_astra import TomographicProjector +from scico.linop import Parallel2dProjector, XRayTransform from scico.util import Timer """ @@ -30,7 +30,9 @@ """ N = 512 + det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) + x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) x_gt = jnp.array(x_gt) @@ -46,12 +48,12 @@ projectors = {} timer.start("scico_init") -projectors["scico"] = XRayProject(ParallelFixedAxis2dProjector((N, N), angles)) +projectors["scico"] = XRayTransform(Parallel2dProjector((N, N), angles)) timer.stop("scico_init") timer.start("astra_init") -projectors["astra"] = TomographicProjector( - (N, N), detector_spacing=1.0, det_count=det_count, angles=angles +projectors["astra"] = astra.XRayTransform( + (N, N), detector_spacing=1.0, det_count=det_count, angles=angles - jnp.pi / 2.0 ) timer.stop("astra_init") @@ -59,9 +61,10 @@ """ Time first projector application, which might include JIT overhead. """ + ys = {} for name, H in projectors.items(): - timer_label = f"{name}_first_proj" + timer_label = f"{name}_first_fwd" timer.start(timer_label) ys[name] = H @ x_gt jax.block_until_ready(ys[name]) @@ -71,9 +74,10 @@ """ Compute average time for a projector application. """ + num_repeats = 3 for name, H in projectors.items(): - timer_label = f"{name}_avg_proj" + timer_label = f"{name}_avg_fwd" timer.start(timer_label) for _ in range(num_repeats): ys[name] = H @ x_gt @@ -82,62 +86,16 @@ timer.td[timer_label] /= num_repeats -""" -Display timing results. - -On our server, the SCICO projection is more than twice -as fast as ASTRA when both are run on the GPU, and about -10% slower when both are run the CPU. - -On our server, using the GPU: -``` -Label Accum. Current -------------------------------------------- -astra_avg_proj 4.62e-02 s Stopped -astra_first_proj 6.92e-02 s Stopped -astra_init 1.36e-03 s Stopped -scico_avg_proj 1.61e-02 s Stopped -scico_first_proj 2.95e-02 s Stopped -scico_init 1.37e+01 s Stopped -``` - -Using the CPU: -``` -Label Accum. Current -------------------------------------------- -astra_avg_proj 9.11e-01 s Stopped -astra_first_proj 9.16e-01 s Stopped -astra_init 1.06e-03 s Stopped -scico_avg_proj 1.03e+00 s Stopped -scico_first_proj 1.04e+00 s Stopped -scico_init 1.00e+01 s Stopped -``` -""" - -print(timer) - - -""" -Show projections. -""" -fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) -plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0]) -plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1]) -fig.show() - - """ Time first back projection, which might include JIT overhead. """ -timer = Timer() - y = np.zeros(H.output_shape, dtype=np.float32) y[num_angles // 3, det_count // 2] = 1.0 y = jnp.array(y) HTys = {} for name, H in projectors.items(): - timer_label = f"{name}_first_BP" + timer_label = f"{name}_first_back" timer.start(timer_label) HTys[name] = H.T @ y jax.block_until_ready(ys[name]) @@ -149,7 +107,7 @@ """ num_repeats = 3 for name, H in projectors.items(): - timer_label = f"{name}_avg_BP" + timer_label = f"{name}_avg_back" timer.start(timer_label) for _ in range(num_repeats): HTys[name] = H.T @ y @@ -159,41 +117,79 @@ """ -Display back projection timing results. +Display timing results. -On our server, the SCICO back projection is slow the first time it is -run, probably due to JIT overhead. After the first run, it is an order of -magnitude faster than ASTRA when both are run on the GPU, and about three -times faster when both are run on the CPU. +On our server, the SCICO projection is more than twice as fast as ASTRA +when both are run on the GPU, and about 10% slower when both are run the +CPU. The SCICO back projection is slow the first time it is run, probably +due to JIT overhead. After the first run, it is an order of magnitude +faster than ASTRA when both are run on the GPU, and about three times +faster when both are run on the CPU. On our server, using the GPU: ``` -Label Accum. Current ------------------------------------------ -astra_avg_BP 3.71e-02 s Stopped -astra_first_BP 4.20e-02 s Stopped -scico_avg_BP 1.05e-03 s Stopped -scico_first_BP 7.63e+00 s Stopped +init astra 1.36e-03 s +init scico 1.37e+01 s + +first fwd astra 6.92e-02 s +first fwd scico 2.95e-02 s + +first back astra 4.20e-02 s +first back scico 7.63e+00 s + +avg fwd astra 4.62e-02 s +avg fwd scico 1.61e-02 s + +avg back astra 3.71e-02 s +avg back scico 1.05e-03 s ``` Using the CPU: ``` -Label Accum. Current ------------------------------------------ -astra_avg_BP 9.34e-01 s Stopped -astra_first_BP 9.39e-01 s Stopped -scico_avg_BP 2.62e-01 s Stopped -scico_first_BP 1.00e+01 s Stopped +init astra 1.06e-03 s +init scico 1.00e+01 s + +first fwd astra 9.16e-01 s +first fwd scico 1.04e+00 s + +first back astra 9.39e-01 s +first back scico 1.00e+01 s + +avg fwd astra 9.11e-01 s +avg fwd scico 1.03e+00 s + +avg back astra 9.34e-01 s +avg back scico 2.62e-01 s ``` """ -print(timer) +print(f"init astra {timer.td['astra_init']:.2e} s") +print(f"init scico {timer.td['scico_init']:.2e} s") +print("") +for tstr in ("first", "avg"): + for dstr in ("fwd", "back"): + for pstr in ("astra", "scico"): + print( + f"{tstr:5s} {dstr:4s} {pstr} {timer.td[pstr + '_' + tstr + '_' + dstr]:.2e} s" + ) + print() + + +""" +Show projections. +""" + +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6)) +plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0]) +plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1]) +fig.show() """ Show back projections of a single detector element, i.e., a line. """ -fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) + +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6)) plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0]) plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1]) for ax_i in ax: diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py index d4b2e6050..390925d11 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py @@ -31,7 +31,7 @@ from scico import metric, plot from scico.functional import BM3D, NonNegativeIndicator from scico.linop import Diagonal, Identity -from scico.linop.radon_svmbir import SVMBIRSquaredL2Loss, TomographicProjector +from scico.linop.xray.svmbir import SVMBIRSquaredL2Loss, XRayTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -53,7 +53,7 @@ num_angles = int(N / 2) num_channels = N angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32) -A = TomographicProjector(x_gt.shape, angles, num_channels) +A = XRayTransform(x_gt.shape, angles, num_channels) sino = A @ x_gt diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py index 787709b86..a6e663a09 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py @@ -41,10 +41,10 @@ from scico import metric, plot from scico.functional import BM3D, NonNegativeIndicator from scico.linop import Diagonal, Identity -from scico.linop.radon_svmbir import ( +from scico.linop.xray.svmbir import ( SVMBIRExtendedLoss, SVMBIRSquaredL2Loss, - TomographicProjector, + XRayTransform, ) from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -67,7 +67,7 @@ num_angles = int(N / 2) num_channels = N angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32) -A = TomographicProjector(x_gt.shape, angles, num_channels) +A = XRayTransform(x_gt.shape, angles, num_channels) sino = A @ x_gt diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py index 06d99696d..8592b44ff 100644 --- a/examples/scripts/ct_svmbir_tv_multi.py +++ b/examples/scripts/ct_svmbir_tv_multi.py @@ -14,7 +14,7 @@ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ -where $A$ is the Radon transform (implemented using the SVMBIR +where $A$ is the X-ray transform (implemented using the SVMBIR :cite:`svmbir-2020` tomographic projection), $\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the desired image. @@ -29,7 +29,7 @@ import scico.numpy as snp from scico import functional, linop, metric, plot from scico.linop import Diagonal -from scico.linop.radon_svmbir import SVMBIRSquaredL2Loss, TomographicProjector +from scico.linop.xray.svmbir import SVMBIRSquaredL2Loss, XRayTransform from scico.optimize import PDHG, LinearizedADMM from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -52,7 +52,7 @@ num_angles = int(N / 2) num_channels = N angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32) -A = TomographicProjector(x_gt.shape, angles, num_channels) +A = XRayTransform(x_gt.shape, angles, num_channels) sino = A @ x_gt diff --git a/examples/scripts/ct_tv_admm.py b/examples/scripts/ct_tv_admm.py new file mode 100644 index 000000000..6aa3474a7 --- /dev/null +++ b/examples/scripts/ct_tv_admm.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +TV-Regularized Sparse-View CT Reconstruction (Integrated Projector) +=================================================================== + +This example demonstrates solution of a sparse-view CT reconstruction +problem with isotropic total variation (TV) regularization + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} + \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ + +where $A$ is the X-ray transform (the CT forward projection operator), +$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and +$\mathbf{x}$ is the desired image. This example uses the CT projector +integrated into scico, while the companion +[example script](ct_astra_tv_admm.rst) uses the projector provided by +the astra package. +""" + +import numpy as np + +import jax + +from mpl_toolkits.axes_grid1 import make_axes_locatable +from xdesign import Foam, discrete_phantom + +import scico.numpy as snp +from scico import functional, linop, loss, metric, plot +from scico.linop.xray import Parallel2dProjector, XRayTransform +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.util import device_info + +""" +Create a ground truth image. +""" +N = 512 # phantom size +np.random.seed(1234) +x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) +x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU + + +""" +Configure CT projection operator and generate synthetic measurements. +""" +n_projection = 45 # number of projections +angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles +A = XRayTransform(Parallel2dProjector((N, N), angles)) # CT projection operator +y = A @ x_gt # sinogram + + +""" +Set up ADMM solver object. +""" +λ = 2e0 # L1 norm regularization parameter +ρ = 5e0 # ADMM penalty parameter +maxiter = 25 # number of ADMM iterations +cg_tol = 1e-4 # CG relative tolerance +cg_maxiter = 25 # maximum CG iterations per ADMM iteration + +# The append=0 option makes the results of horizontal and vertical +# finite differences the same shape, which is required for the L21Norm, +# which is used so that g(Cx) corresponds to isotropic TV. +C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) +g = λ * functional.L21Norm() + +f = loss.SquaredL2Loss(y=y, A=A) + +x0 = snp.clip(A.T(y), 0, 1.0) + +solver = ADMM( + f=f, + g_list=[g], + C_list=[C], + rho_list=[ρ], + x0=x0, + maxiter=maxiter, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), + itstat_options={"display": True, "period": 5}, +) + + +""" +Run the solver. +""" +print(f"Solving on {device_info()}\n") +solver.solve() +hist = solver.itstat_object.history(transpose=True) +x_reconstruction = snp.clip(solver.x, 0, 1.0) + + +""" +Show the recovered image. +""" + +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(15, 5)) +plot.imview(x_gt, title="Ground truth", cbar=None, fig=fig, ax=ax[0]) +plot.imview( + x_reconstruction, + title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f" + % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)), + fig=fig, + ax=ax[1], +) +divider = make_axes_locatable(ax[1]) +cax = divider.append_axes("right", size="5%", pad=0.2) +fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units") +fig.show() + + +""" +Plot convergence statistics. +""" +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) +plot.plot( + hist.Objective, + title="Objective function", + xlbl="Iteration", + ylbl="Functional value", + fig=fig, + ax=ax[0], +) +plot.plot( + snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, + ptyp="semilogy", + title="Residuals", + xlbl="Iteration", + lgnd=("Primal", "Dual"), + fig=fig, + ax=ax[1], +) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index f03f8fa28..584711de2 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -23,7 +23,8 @@ Computed Tomography - ct_astra_odp_train_foam2.py - ct_astra_unet_train_foam2.py - ct_projector_comparison.py - + - ct_multi_cs_tv_admm.py + - ct_multi_tv_admm.py Deconvolution ^^^^^^^^^^^^^ diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index b846830c5..5d706aa28 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -48,7 +48,7 @@ have_astra = True if have_astra: - from scico.linop.radon_astra import TomographicProjector + from scico.linop.xray.astra import XRayTransform # Arbitrary process count: only applies if GPU is not available. @@ -210,7 +210,7 @@ def generate_ct_data( angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles gt_sh = (size, size) detector_spacing = 1 - A = TomographicProjector(gt_sh, detector_spacing, size, angles) # Radon transform operator + A = XRayTransform(gt_sh, detector_spacing, size, angles) # Radon transform operator # Compute sinograms in parallel. a_map = lambda v: jnp.atleast_3d(A @ v.squeeze()) diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index 0c14de950..598a26aa2 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021-2022 by SCICO Developers +# Copyright (C) 2021-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -19,7 +19,7 @@ from ._matrix import MatrixOperator from ._stack import DiagonalStack, VerticalStack from ._util import jacobian, operator_norm, power_iteration, valid_adjoint -from ._xray import ParallelFixedAxis2dProjector, XRayProject +from .xray import Parallel2dProjector, XRayTransform __all__ = [ "CircularConvolve", @@ -39,8 +39,8 @@ "Sum", "Transpose", "LinearOperator", - "XRayProject", - "ParallelFixedAxis2dProjector", + "XRayTransform", + "Parallel2dProjector", "ComposedLinearOperator", "linop_from_function", "operator_norm", diff --git a/scico/linop/abel.py b/scico/linop/abel.py index 9a94b3735..6aa2846ca 100644 --- a/scico/linop/abel.py +++ b/scico/linop/abel.py @@ -27,12 +27,12 @@ from scipy.linalg import solve_triangular -class AbelProjector(LinearOperator): - r"""Abel transform projector based on `PyAbel `_. +class AbelTransform(LinearOperator): + r"""Abel transform based on `PyAbel `_. - Perform Abel transform (parallel beam tomographic projection of - cylindrically symmetric objects) for a 2D image. The input 2D image - is assumed to be centered and left-right symmetric. + Perform Abel transform (parallel beam projection of cylindrically + symmetric objects) for a 2D image. The input 2D image is assumed to + be centered and left-right symmetric. """ def __init__(self, img_shape: Shape): @@ -78,7 +78,7 @@ def inverse(self, y: jax.Array) -> jax.Array: def _pyabel_transform( x: jax.Array, direction: str, proj_mat_quad: jax.Array, symmetry_axis: Optional[list] = None ) -> jax.Array: - """Perform Abel transformations (forward, inverse and transposed). + """Apply Abel transforms (forward, inverse and transposed). This function contains code copied from `PyAbel `_. """ diff --git a/scico/linop/xray/__init__.py b/scico/linop/xray/__init__.py new file mode 100644 index 000000000..49d5752c5 --- /dev/null +++ b/scico/linop/xray/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2023 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""X-ray transform classes. + +The tomographic projections that are frequently referred to as Radon +transforms are referred to as X-ray transforms in SCICO. While the Radon +transform is far more well-known than the X-ray transform, which is the +same as the Radon transform for projections in two dimensions, these two +transform differ in higher numbers of dimensions, and it is the X-ray +transform that is the appropriate mathematical model for beam attenuation +based imaging in three or more dimensions. +""" + +import sys + +from ._xray import Parallel2dProjector, XRayTransform + +__all__ = [ + "XRayTransform", + "Parallel2dProjector", +] diff --git a/scico/linop/_xray.py b/scico/linop/xray/_xray.py similarity index 66% rename from scico/linop/_xray.py rename to scico/linop/xray/_xray.py index 40c649cf4..1f4069401 100644 --- a/scico/linop/_xray.py +++ b/scico/linop/xray/_xray.py @@ -1,14 +1,13 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. +"""X-ray transform classes.""" + -""" -X-ray projector classes. -""" from functools import partial from typing import Optional @@ -20,14 +19,18 @@ from scico.typing import Shape -from ._linop import LinearOperator +from .._linop import LinearOperator -class XRayProject(LinearOperator): - """X-ray projection operator. +class XRayTransform(LinearOperator): + """X-ray transform operator. - Wraps an X-ray projector object in a SCICO - :class:`LinearOperator`. + Wrap an X-ray projector object in a SCICO :class:`LinearOperator`. + **Warning:** Note that the only X-ray projector object currently + supported, :class:`.Parallel2dProjector`, is not a very accurate + approximation of the integral transform representing real projection + imaging, and may therefore not be suitable for real imaging + applications. """ def __init__(self, projector): @@ -35,7 +38,7 @@ def __init__(self, projector): Args: projector: instance of an X-ray projector object to wrap, currently the only option is - :class:`ParallelFixedAxis2dProjector` + :class:`Parallel2dProjector` """ self._eval = projector.project @@ -45,22 +48,26 @@ def __init__(self, projector): ) -class ParallelFixedAxis2dProjector: +class Parallel2dProjector: """Parallel ray, single axis, 2D X-ray projector.""" def __init__( self, im_shape: Shape, angles: ArrayLike, - det_length: Optional[int] = None, + det_count: Optional[int] = None, dither: bool = True, ): r""" Args: im_shape: Shape of input array. - angles: (num_angles,) array of angles in radians. - det_length: Length of detector, in ``None``, defaults to the - length of diagonal of `im_shape`. + angles: (num_angles,) array of angles in radians. Viewing an + (M, N) array as a matrix with M rows and N columns, an + angle of 0 corresponds to summing rows, an angle of pi/2 + corresponds to summing columns, and an angle of pi/4 + corresponds to summing along antidiagonals. + det_count: Number of elements in detector. If ``None``, + defaults to the size of the diagonal of `im_shape`. dither: If ``True`` randomly shift pixel locations to reduce projection artifacts caused by aliasing. """ @@ -71,11 +78,11 @@ def __init__( x0 = -(im_shape - 1) / 2 - if det_length is None: - det_length = int(np.ceil(np.linalg.norm(im_shape))) - self.det_shape = (det_length,) + if det_count is None: + det_count = int(np.ceil(np.linalg.norm(im_shape))) + self.det_shape = (det_count,) - y0 = -det_length / 2 + y0 = -det_count / 2 @jax.vmap def compute_inds(angle: float) -> ArrayLike: @@ -106,7 +113,7 @@ def compute_inds(angle: float) -> ArrayLike: # map negative inds to y_size, which is out of bounds and will be ignored # otherwise they index from the end like x[-1] - inds = jnp.where(inds < 0, det_length, inds) + inds = jnp.where(inds < 0, det_count, inds) return inds @@ -115,7 +122,7 @@ def compute_inds(angle: float) -> ArrayLike: @partial(jax.vmap, in_axes=(None, 0)) def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike: """Compute the projection at a single angle.""" - return jnp.zeros(det_length).at[inds].add(im) + return jnp.zeros(det_count).at[inds].add(im) @jax.jit def project(im: ArrayLike) -> ArrayLike: diff --git a/scico/linop/radon_astra.py b/scico/linop/xray/astra.py similarity index 85% rename from scico/linop/radon_astra.py rename to scico/linop/xray/astra.py index 6a5a337b9..b877891ad 100644 --- a/scico/linop/radon_astra.py +++ b/scico/linop/xray/astra.py @@ -5,9 +5,9 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Radon transform LinearOperator wrapping the ASTRA toolbox. +"""X-ray transform LinearOperator wrapping the ASTRA toolbox. -Radon transform :class:`.LinearOperator` wrapping the parallel beam +X-ray transform :class:`.LinearOperator` wrapping the parallel beam projections in the `ASTRA toolbox `_. This package provides both C and CUDA implementations of core @@ -37,11 +37,11 @@ from scico.typing import Shape -from ._linop import LinearOperator +from .._linop import LinearOperator -class TomographicProjector(LinearOperator): - r"""Parallel beam Radon transform based on the ASTRA toolbox. +class XRayTransform(LinearOperator): + r"""Parallel beam X-ray transform based on the ASTRA toolbox. Perform tomographic projection (also called X-ray projection) of an image or volume at specified angles, using the @@ -61,24 +61,28 @@ def __init__( Args: input_shape: Shape of the input array. Determines whether 2D or 3D algorithm is used. - detector_spacing: Spacing between detector elements. See - https://www.astra-toolbox.com/docs/geom2d.html#projection-geometries + detector_spacing: Spacing between detector elements. See the + astra documentation for more information for + `2d `__ or - https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries - for more information. - det_count: Number of detector elements. See - https://www.astra-toolbox.com/docs/geom2d.html#projection-geometries + `3d `__ + geometries. + det_count: Number of detector elements. See the astra + documentation for more information for + `2d `__ or - https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries - for more information. + `3d `__ + geometries. angles: Array of projection angles in radians. volume_geometry: Specification of the shape of the discretized reconstruction volume. Must either ``None``, in which case it is inferred from `input_shape`, or - follow the astra syntax described in - https://www.astra-toolbox.com/docs/geom2d.html#volume-geometries + follow the astra syntax described in the astra + documentation for + `2d `__ or - https://www.astra-toolbox.com/docs/geom3d.html#d-geometries. + `3d `__ + geometries. device: Specifies device for projection operation. One of ["auto", "gpu", "cpu"]. If "auto", a GPU is used if available, otherwise, the CPU is used. @@ -87,7 +91,7 @@ def __init__( self.num_dims = len(input_shape) if self.num_dims not in [2, 3]: raise ValueError( - f"Only 2D and 3D projections are supported, but `input_shape` is {input_shape}." + f"Only 2D and 3D projections are supported, but input_shape is {input_shape}." ) output_shape: Shape @@ -96,7 +100,7 @@ def __init__( elif self.num_dims == 3: assert isinstance(det_count, (list, tuple)) if len(det_count) != 2: - raise ValueError("Expected `det_count` to have 2 elements") + raise ValueError("Expected det_count to have 2 elements") output_shape = (det_count[0], len(angles), det_count[1]) # Set up all the ASTRA config @@ -112,7 +116,7 @@ def __init__( assert isinstance(detector_spacing, (list, tuple)) assert isinstance(det_count, (list, tuple)) if len(detector_spacing) != 2: - raise ValueError("Expected `detector_spacing` to have 2 elements") + raise ValueError("Expected detector_spacing to have 2 elements") self.proj_geom = astra.create_proj_geom( "parallel3d", detector_spacing[0], @@ -132,7 +136,7 @@ def __init__( self.vol_geom: dict = astra.create_vol_geom(*input_shape, *volume_geometry) else: raise ValueError( - "`volume_geometry` must be a tuple of len 4 (2D) or 6 (3D)." + "volume_geometry must be a tuple of len 4 (2D) or 6 (3D)." "Please see the astra documentation for details." ) else: @@ -152,7 +156,7 @@ def __init__( raise ValueError(f"Invalid device specified; got {device}.") if self.num_dims == 3 and self.device == "cpu": - raise ValueError("No CPU algorithm exists for 3D tomography.") + raise ValueError("No CPU algorithm for 3D projection.") if self.num_dims == 3: # not needed for astra's 3D algorithm @@ -227,7 +231,7 @@ def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array: """ if self.num_dims == 3: - raise NotImplementedError("3D FBP is not implemented") + raise NotImplementedError("3D FBP is not implemented.") # Just use the CPU FBP alg for now; hitting memory issues with GPU one. def f(sino): diff --git a/scico/linop/radon_svmbir.py b/scico/linop/xray/svmbir.py similarity index 95% rename from scico/linop/radon_svmbir.py rename to scico/linop/xray/svmbir.py index 6d81b0fb7..8e757da84 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/xray/svmbir.py @@ -5,9 +5,9 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Tomographic projector LinearOperator wrapping the svmbir package. +"""X-ray transform LinearOperator wrapping the svmbir package. -Tomographic projector :class:`.LinearOperator` wrapping the +X-ray transform :class:`.LinearOperator` wrapping the `svmbir `_ package. Since this package is an interface to compiled C code, JAX features such as automatic differentiation and support for GPU devices are not available. @@ -24,8 +24,8 @@ from scico.loss import Loss, SquaredL2Loss from scico.typing import Shape -from ._diag import Diagonal, Identity -from ._linop import LinearOperator +from .._diag import Diagonal, Identity +from .._linop import LinearOperator try: import svmbir @@ -33,8 +33,8 @@ raise ImportError("Could not import svmbir; please install it.") -class TomographicProjector(LinearOperator): - r"""Tomographic projector based on svmbir. +class XRayTransform(LinearOperator): + r"""X-ray transform based on svmbir. Perform tomographic projection of an image at specified angles, using the `svmbir `_ package. The @@ -42,7 +42,7 @@ class TomographicProjector(LinearOperator): (pixels outside this region are ignored when performing the projection) is active. This region of validity is also respected by :meth:`.SVMBIRSquaredL2Loss.prox` when :class:`.SVMBIRSquaredL2Loss` - is initialized with a :class:`TomographicProjector` with this option + is initialized with a :class:`XRayTransform` with this option enabled. A brief description of the supported scanner geometries can be found @@ -316,7 +316,7 @@ class SVMBIRExtendedLoss(Loss): \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right) \;, - where :math:`A` is a :class:`.TomographicProjector`, + where :math:`A` is a :class:`.XRayTransform`, :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set to :class:`scico.linop.Identity`. @@ -325,12 +325,12 @@ class SVMBIRExtendedLoss(Loss): :math:`\ell_2` loss as follows. When `positivity=True`, the prox projects onto the non-negative orthant and the loss is infinite if any element of the input is negative. When the `is_masked` option - of the associated :class:`.TomographicProjector` is ``True``, the + of the associated :class:`.XRayTransform` is ``True``, the reconstruction is computed over a masked region of the image as - described in class :class:`.TomographicProjector`. + described in class :class:`.XRayTransform`. """ - A: TomographicProjector + A: XRayTransform W: Union[Identity, Diagonal] def __init__( @@ -358,8 +358,8 @@ def __init__( """ super().__init__(*args, scale=scale, **kwargs) # type: ignore - if not isinstance(self.A, TomographicProjector): - raise ValueError("LinearOperator A must be a radon_svmbir.TomographicProjector.") + if not isinstance(self.A, XRayTransform): + raise ValueError("LinearOperator A must be a radon_svmbir.XRayTransform.") self.has_prox = True @@ -445,7 +445,7 @@ class SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss): \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right) \;, - where :math:`A` is a :class:`.TomographicProjector`, :math:`\alpha` + where :math:`A` is a :class:`.XRayTransform`, :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set to :class:`scico.linop.Identity`. @@ -473,5 +473,5 @@ def __init__( if self.A.is_masked: raise ValueError( - "Parameter is_masked must be False for the TomographicProjector in SVMBIRSquaredL2Loss." + "Parameter is_masked must be False for the XRayTransform in SVMBIRSquaredL2Loss." ) diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py index baef6ec91..eabaf356b 100644 --- a/scico/operator/_operator.py +++ b/scico/operator/_operator.py @@ -383,7 +383,7 @@ def concat_args(args): # concat_args(args) = snp.blockarray([args, val]) if argnum = 1 if isinstance(args, (jnp.ndarray, np.ndarray)): - # In the case that the original operator takes a blockkarray with two + # In the case that the original operator takes a blockarray with two # blocks, wrap in a list so we can use the same indexing as >2 block case args = [args] diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 5e7808989..8cf4eac3d 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -220,7 +220,6 @@ def internal_init(self, admm: soa.ADMM): # hessian = A.T @ W @ A; W may be identity lhs_op += admm.f.hessian - lhs_op.jit() self.lhs_op = lhs_op def compute_rhs(self) -> Union[Array, BlockArray]: diff --git a/scico/test/flax/test_inv.py b/scico/test/flax/test_inv.py index b8fb3462c..2326fc430 100644 --- a/scico/test/flax/test_inv.py +++ b/scico/test/flax/test_inv.py @@ -16,7 +16,7 @@ from scico.linop import CircularConvolve, Identity if have_astra: - from scico.linop.radon_astra import TomographicProjector + from scico.linop.xray.astra import XRayTransform os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" @@ -107,7 +107,7 @@ def setup_method(self, method): self.nproj = 60 # number of projections angles = np.linspace(0, np.pi, self.nproj) # evenly spaced projection angles - self.opCT = TomographicProjector( + self.opCT = XRayTransform( input_shape=(self.N, self.N), detector_spacing=1, det_count=self.N, diff --git a/scico/test/linop/test_abel.py b/scico/test/linop/test_abel.py index ca5323af3..a5024ba7a 100644 --- a/scico/test/linop/test_abel.py +++ b/scico/test/linop/test_abel.py @@ -5,7 +5,7 @@ import pytest import scico.numpy as snp -from scico.linop.abel import AbelProjector +from scico.linop.abel import AbelTransform from scico.test.linop.test_linop import adjoint_test BIG_INPUT = (128, 128) @@ -23,7 +23,7 @@ def make_im(Nx, Ny): @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_inverse(Nx, Ny): im = make_im(Nx, Ny) - A = AbelProjector(im.shape) + A = AbelTransform(im.shape) Ax = A @ im im_hat = A.inverse(Ax) @@ -34,14 +34,14 @@ def test_inverse(Nx, Ny): @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_adjoint(Nx, Ny): im = make_im(Nx, Ny) - A = AbelProjector(im.shape) + A = AbelTransform(im.shape) adjoint_test(A) @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_ATA(Nx, Ny): x = make_im(Nx, Ny) - A = AbelProjector(x.shape) + A = AbelTransform(x.shape) Ax = A(x) ATAx = A.adj(Ax) np.testing.assert_allclose(np.sum(x * ATAx), np.linalg.norm(Ax) ** 2, rtol=5e-5) @@ -52,7 +52,7 @@ def test_grad(Nx, Ny): # ensure that we can take grad on a function using our projector # grad || A(x) ||_2^2 == 2 A.T @ A x x = make_im(Nx, Ny) - A = AbelProjector(x.shape) + A = AbelTransform(x.shape) g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2 np.testing.assert_allclose(jax.grad(g)(x), 2 * A.adj(A(x)), rtol=5e-5) @@ -60,7 +60,7 @@ def test_grad(Nx, Ny): @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_adjoint_grad(Nx, Ny): x = make_im(Nx, Ny) - A = AbelProjector(x.shape) + A = AbelTransform(x.shape) Ax = A @ x f = lambda y: jax.numpy.linalg.norm(A.T(y)) ** 2 np.testing.assert_allclose(jax.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=5e-5) diff --git a/scico/test/linop/test_xray.py b/scico/test/linop/test_xray.py deleted file mode 100644 index bb827988f..000000000 --- a/scico/test/linop/test_xray.py +++ /dev/null @@ -1,26 +0,0 @@ -import jax.numpy as jnp - -from scico.linop import ParallelFixedAxis2dProjector, XRayProject - - -def test_apply(): - im_shape = (12, 13) - num_angles = 10 - x = jnp.ones(im_shape) - - angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) - - # general projection - H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles)) - y = H @ x - assert y.shape[0] == (num_angles) - - # fixed det_length - det_length = 14 - H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, det_length=det_length)) - y = H @ x - assert y.shape[1] == det_length - - # dither off - H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, dither=False)) - y = H @ x diff --git a/scico/test/linop/test_radon_astra.py b/scico/test/linop/xray/test_astra.py similarity index 87% rename from scico/test/linop/test_radon_astra.py rename to scico/test/linop/xray/test_astra.py index 42b4e0f57..67699da76 100644 --- a/scico/test/linop/test_radon_astra.py +++ b/scico/test/linop/xray/test_astra.py @@ -8,10 +8,10 @@ import scico.numpy as snp from scico.linop import DiagonalStack from scico.test.linop.test_linop import adjoint_test -from scico.test.linop.test_radon_svmbir import make_im +from scico.test.linop.xray.test_svmbir import make_im try: - from scico.linop.radon_astra import TomographicProjector + from scico.linop.xray.astra import XRayTransform except ModuleNotFoundError as e: if e.name == "astra": pytest.skip("astra not installed", allow_module_level=True) @@ -41,7 +41,7 @@ def get_tol_random_input(): return rtol -class TomographicProjectorTest: +class XRayTransformTest: def __init__(self, volume_geometry): N_proj = 180 # number of projection angles N_det = 384 @@ -51,7 +51,7 @@ def __init__(self, volume_geometry): np.random.seed(1234) self.x = np.random.randn(N, N).astype(np.float32) self.y = np.random.randn(N_proj, N_det).astype(np.float32) - self.A = TomographicProjector( + self.A = XRayTransform( input_shape=(N, N), volume_geometry=volume_geometry, detector_spacing=detector_spacing, @@ -62,7 +62,7 @@ def __init__(self, volume_geometry): @pytest.fixture(params=[None, [-N / 2, N / 2, -N / 2, N / 2]]) def testobj(request): - yield TomographicProjectorTest(request.param) + yield XRayTransformTest(request.param) def test_ATA_call(testobj): @@ -125,7 +125,7 @@ def test_adjoint_typical_input(testobj): def test_jit_in_DiagonalStack(): """See https://github.com/lanl/scico/issues/331""" N = 10 - H = DiagonalStack([TomographicProjector((N, N), 1.0, N, snp.linspace(0, snp.pi, N))]) + H = DiagonalStack([XRayTransform((N, N), 1.0, N, snp.linspace(0, snp.pi, N))]) H.T @ snp.zeros(H.output_shape, dtype=snp.float32) @@ -133,13 +133,13 @@ def test_jit_in_DiagonalStack(): def test_3D_on_CPU(): x = snp.zeros((4, 5, 6)) with pytest.raises(ValueError): - A = TomographicProjector(x.shape, [1.0, 1.0], [6, 6], snp.linspace(0, snp.pi, 10)) + A = XRayTransform(x.shape, [1.0, 1.0], [6, 6], snp.linspace(0, snp.pi, 10)) @pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="checking GPU behavior") def test_3D_on_GPU(): x = snp.zeros((4, 5, 6)) - A = TomographicProjector(x.shape, [1.0, 1.0], [6, 6], snp.linspace(0, snp.pi, 10)) + A = XRayTransform(x.shape, [1.0, 1.0], [6, 6], snp.linspace(0, snp.pi, 10)) assert A.num_dims == 3 y = A @ x diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/xray/test_svmbir.py similarity index 98% rename from scico/test/linop/test_radon_svmbir.py rename to scico/test/linop/xray/test_svmbir.py index a41629d6a..9674269c0 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/xray/test_svmbir.py @@ -14,10 +14,10 @@ try: import svmbir - from scico.linop.radon_svmbir import ( + from scico.linop.xray.svmbir import ( SVMBIRExtendedLoss, SVMBIRSquaredL2Loss, - TomographicProjector, + XRayTransform, ) except ImportError as e: pytest.skip("svmbir not installed", allow_module_level=True) @@ -90,7 +90,7 @@ def make_A( ): angles = make_angles(num_angles) - A = TomographicProjector( + A = XRayTransform( im.shape, angles, num_channels, diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py new file mode 100644 index 000000000..288da22b0 --- /dev/null +++ b/scico/test/linop/xray/test_xray.py @@ -0,0 +1,26 @@ +import jax.numpy as jnp + +from scico.linop import Parallel2dProjector, XRayTransform + + +def test_apply(): + im_shape = (12, 13) + num_angles = 10 + x = jnp.ones(im_shape) + + angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) + + # general projection + H = XRayTransform(Parallel2dProjector(x.shape, angles)) + y = H @ x + assert y.shape[0] == (num_angles) + + # fixed det_count + det_count = 14 + H = XRayTransform(Parallel2dProjector(x.shape, angles, det_count=det_count)) + y = H @ x + assert y.shape[1] == det_count + + # dither off + H = XRayTransform(Parallel2dProjector(x.shape, angles, dither=False)) + y = H @ x