From 085bbcc4bfcd9b480ccbdab0f4048791415b8535 Mon Sep 17 00:00:00 2001 From: Mike McCann <57153404+Michael-T-McCann@users.noreply.github.com> Date: Wed, 27 Sep 2023 12:56:47 -0600 Subject: [PATCH 1/3] Add a jax-based X-ray projector (#433) * Add 2d projector and code to time it * Clean up * Add back projection * Add test * Add timing results to example * Start to add new example * Update data * Address mypy * Address isort * Try to fix tables in notebook * Update submodule * Rename parameter * Update ray syntax to get tests passing in CI --------- Co-authored-by: Michael McCann Co-authored-by: Michael McCann Co-authored-by: Brendt Wohlberg --- data | 2 +- docs/source/examples.rst | 1 + examples/scripts/README.rst | 2 + examples/scripts/ct_projector_comparison.py | 209 ++++++++++++++++++++ examples/scripts/index.rst | 1 + scico/linop/__init__.py | 3 + scico/linop/_xray.py | 125 ++++++++++++ scico/test/linop/test_xray.py | 26 +++ scico/test/test_ray_tune.py | 21 +- 9 files changed, 376 insertions(+), 14 deletions(-) create mode 100644 examples/scripts/ct_projector_comparison.py create mode 100644 scico/linop/_xray.py create mode 100644 scico/test/linop/test_xray.py diff --git a/data b/data index c43239596..0d9f1fef8 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit c43239596865c61fbfaf77d6e7ed82c7afd65ea5 +Subproject commit 0d9f1fef8df6eebb98d154e1e6d1ab8357914a88 diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 368d8fa98..c6cc70e9a 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -35,6 +35,7 @@ Computed Tomography examples/ct_astra_modl_train_foam2 examples/ct_astra_odp_train_foam2 examples/ct_astra_unet_train_foam2 + examples/ct_projector_comparison Deconvolution diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index f82dd3aa7..3910e9671 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -35,6 +35,8 @@ Computed Tomography CT Training and Reconstructions with ODP `ct_astra_unet_train_foam2.py `_ CT Training and Reconstructions with UNet + `ct_projector_comparison.py `_ + X-ray Projector Comparison Deconvolution diff --git a/examples/scripts/ct_projector_comparison.py b/examples/scripts/ct_projector_comparison.py new file mode 100644 index 000000000..ab8c43cfa --- /dev/null +++ b/examples/scripts/ct_projector_comparison.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + + +r""" +X-ray Projector Comparison +========================== + +This example compares SCICO's native X-ray projection algorithm +to that of the ASTRA Toolbox. +""" + +import numpy as np + +import jax +import jax.numpy as jnp + +from xdesign import Foam, discrete_phantom + +from scico import plot +from scico.linop import ParallelFixedAxis2dProjector, XRayProject +from scico.linop.radon_astra import TomographicProjector +from scico.util import Timer + +""" +Create a ground truth image. +""" + +N = 512 + + +det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) + +x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) +x_gt = jax.device_put(x_gt) + +""" +Time projector instantiation. +""" + +num_angles = 500 +angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) + + +timer = Timer() + +projectors = {} +timer.start("scico_init") +projectors["scico"] = XRayProject(ParallelFixedAxis2dProjector((N, N), angles)) +timer.stop("scico_init") + +timer.start("astra_init") +projectors["astra"] = TomographicProjector( + (N, N), detector_spacing=1.0, det_count=det_count, angles=angles +) +timer.stop("astra_init") + +""" +Time first projector application, which might include JIT overhead. +""" + +ys = {} +for name, H in projectors.items(): + timer_label = f"{name}_first_proj" + timer.start(timer_label) + ys[name] = H @ x_gt + jax.block_until_ready(ys[name]) + timer.stop(timer_label) + + +""" +Compute average time for a projector application. +""" + +num_repeats = 3 +for name, H in projectors.items(): + timer_label = f"{name}_avg_proj" + timer.start(timer_label) + for _ in range(num_repeats): + ys[name] = H @ x_gt + jax.block_until_ready(ys[name]) + timer.stop(timer_label) + timer.td[timer_label] /= num_repeats + +""" +Display timing results. + +On our server, the SCICO projection is more than twice +as fast as ASTRA when both are run on the GPU, and about +10% slower when both are run the CPU. + +On our server, using the GPU: +``` +Label Accum. Current +------------------------------------------- +astra_avg_proj 4.62e-02 s Stopped +astra_first_proj 6.92e-02 s Stopped +astra_init 1.36e-03 s Stopped +scico_avg_proj 1.61e-02 s Stopped +scico_first_proj 2.95e-02 s Stopped +scico_init 1.37e+01 s Stopped +``` + +Using the CPU: +``` +Label Accum. Current +------------------------------------------- +astra_avg_proj 9.11e-01 s Stopped +astra_first_proj 9.16e-01 s Stopped +astra_init 1.06e-03 s Stopped +scico_avg_proj 1.03e+00 s Stopped +scico_first_proj 1.04e+00 s Stopped +scico_init 1.00e+01 s Stopped +``` +""" + +print(timer) + +""" +Show projections. +""" + +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) +plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0]) +plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1]) +fig.show() + + +""" +Time first back projection, which might include JIT overhead. +""" +timer = Timer() + +y = np.zeros(H.output_shape, dtype=np.float32) +y[num_angles // 3, det_count // 2] = 1.0 +y = jax.device_put(y) + +HTys = {} +for name, H in projectors.items(): + timer_label = f"{name}_first_BP" + timer.start(timer_label) + HTys[name] = H.T @ y + jax.block_until_ready(ys[name]) + timer.stop(timer_label) + + +""" +Compute average time for back projection. +""" +num_repeats = 3 +for name, H in projectors.items(): + timer_label = f"{name}_avg_BP" + timer.start(timer_label) + for _ in range(num_repeats): + HTys[name] = H.T @ y + jax.block_until_ready(ys[name]) + timer.stop(timer_label) + timer.td[timer_label] /= num_repeats + +""" +Display back projection timing results. + +On our server, the SCICO back projection is slow +the first time it is run, probably due to JIT overhead. +After the first run, it is an order of magnitude +faster than ASTRA when both are run on the GPU, +and about three times faster when both are run on the CPU. + +On our server, using the GPU: +``` +Label Accum. Current +----------------------------------------- +astra_avg_BP 3.71e-02 s Stopped +astra_first_BP 4.20e-02 s Stopped +scico_avg_BP 1.05e-03 s Stopped +scico_first_BP 7.63e+00 s Stopped +``` + +Using the CPU: +``` +Label Accum. Current +----------------------------------------- +astra_avg_BP 9.34e-01 s Stopped +astra_first_BP 9.39e-01 s Stopped +scico_avg_BP 2.62e-01 s Stopped +scico_first_BP 1.00e+01 s Stopped +``` +""" + +print(timer) + +""" +Show back projections of a single detector element, +i.e., a line. +""" + +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) +plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0]) +plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1]) +for ax_i in ax: + ax_i.set_xlim(2 * N / 5, N - 2 * N / 5) + ax_i.set_ylim(2 * N / 5, N - 2 * N / 5) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index 683bf0893..f03f8fa28 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -22,6 +22,7 @@ Computed Tomography - ct_astra_modl_train_foam2.py - ct_astra_odp_train_foam2.py - ct_astra_unet_train_foam2.py + - ct_projector_comparison.py Deconvolution diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index 8fca29a1f..0c14de950 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -19,6 +19,7 @@ from ._matrix import MatrixOperator from ._stack import DiagonalStack, VerticalStack from ._util import jacobian, operator_norm, power_iteration, valid_adjoint +from ._xray import ParallelFixedAxis2dProjector, XRayProject __all__ = [ "CircularConvolve", @@ -38,6 +39,8 @@ "Sum", "Transpose", "LinearOperator", + "XRayProject", + "ParallelFixedAxis2dProjector", "ComposedLinearOperator", "linop_from_function", "operator_norm", diff --git a/scico/linop/_xray.py b/scico/linop/_xray.py new file mode 100644 index 000000000..40c649cf4 --- /dev/null +++ b/scico/linop/_xray.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2020-2023 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + + +""" +X-ray projector classes. +""" +from functools import partial +from typing import Optional + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike + +from scico.typing import Shape + +from ._linop import LinearOperator + + +class XRayProject(LinearOperator): + """X-ray projection operator. + + Wraps an X-ray projector object in a SCICO + :class:`LinearOperator`. + """ + + def __init__(self, projector): + r""" + Args: + projector: instance of an X-ray projector object to wrap, + currently the only option is + :class:`ParallelFixedAxis2dProjector` + """ + self._eval = projector.project + + super().__init__( + input_shape=projector.im_shape, + output_shape=(len(projector.angles), *projector.det_shape), + ) + + +class ParallelFixedAxis2dProjector: + """Parallel ray, single axis, 2D X-ray projector.""" + + def __init__( + self, + im_shape: Shape, + angles: ArrayLike, + det_length: Optional[int] = None, + dither: bool = True, + ): + r""" + Args: + im_shape: Shape of input array. + angles: (num_angles,) array of angles in radians. + det_length: Length of detector, in ``None``, defaults to the + length of diagonal of `im_shape`. + dither: If ``True`` randomly shift pixel locations to + reduce projection artifacts caused by aliasing. + """ + self.im_shape = im_shape + self.angles = angles + + im_shape = np.array(im_shape) + + x0 = -(im_shape - 1) / 2 + + if det_length is None: + det_length = int(np.ceil(np.linalg.norm(im_shape))) + self.det_shape = (det_length,) + + y0 = -det_length / 2 + + @jax.vmap + def compute_inds(angle: float) -> ArrayLike: + """Project pixel positions on to a detector at the given + angle, determine which detector element they contribute to. + """ + x = jnp.stack( + jnp.meshgrid( + *( + jnp.arange(shape_i) * step_i + start_i + for start_i, step_i, shape_i in zip(x0, [1, 1], im_shape) + ), + indexing="ij", + ), + axis=-1, + ) + + # dither + if dither: + key = jax.random.PRNGKey(0) + x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5) + + # project + Px = x[..., 0] * jnp.cos(angle) + x[..., 1] * jnp.sin(angle) + + # quantize + inds = jnp.floor((Px - y0)).astype(int) + + # map negative inds to y_size, which is out of bounds and will be ignored + # otherwise they index from the end like x[-1] + inds = jnp.where(inds < 0, det_length, inds) + + return inds + + inds = compute_inds(angles) # (len(angles), *im_shape) + + @partial(jax.vmap, in_axes=(None, 0)) + def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike: + """Compute the projection at a single angle.""" + return jnp.zeros(det_length).at[inds].add(im) + + @jax.jit + def project(im: ArrayLike) -> ArrayLike: + """Compute the projection for all angles.""" + return project_inds(im, inds) + + self.project = project diff --git a/scico/test/linop/test_xray.py b/scico/test/linop/test_xray.py new file mode 100644 index 000000000..bb827988f --- /dev/null +++ b/scico/test/linop/test_xray.py @@ -0,0 +1,26 @@ +import jax.numpy as jnp + +from scico.linop import ParallelFixedAxis2dProjector, XRayProject + + +def test_apply(): + im_shape = (12, 13) + num_angles = 10 + x = jnp.ones(im_shape) + + angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) + + # general projection + H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles)) + y = H @ x + assert y.shape[0] == (num_angles) + + # fixed det_length + det_length = 14 + H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, det_length=det_length)) + y = H @ x + assert y.shape[1] == det_length + + # dither off + H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, dither=False)) + y = H @ x diff --git a/scico/test/test_ray_tune.py b/scico/test/test_ray_tune.py index 682592934..dde5b1d37 100644 --- a/scico/test/test_ray_tune.py +++ b/scico/test/test_ray_tune.py @@ -7,19 +7,18 @@ try: import ray - from scico.ray import report, tune + from scico.ray import train, tune ray.init(num_cpus=1) except ImportError as e: pytest.skip("ray.tune not installed", allow_module_level=True) -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_random_run(): - def eval_params(config, reporter): + def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - reporter(cost=cost) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -40,12 +39,11 @@ def eval_params(config, reporter): assert np.abs(best_config["y"] - 0.5) < 0.25 -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_random_tune(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - report({"cost": cost}) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -66,12 +64,11 @@ def eval_params(config): assert np.abs(best_config["y"] - 0.5) < 0.25 -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_hyperopt_run(): - def eval_params(config, reporter): + def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - reporter(cost=cost) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -90,12 +87,11 @@ def eval_params(config, reporter): assert np.abs(best_config["y"] - 0.5) < 0.25 -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_hyperopt_tune(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - report({"cost": cost}) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -115,12 +111,11 @@ def eval_params(config): assert np.abs(best_config["y"] - 0.5) < 0.25 -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_hyperopt_tune_alt_init(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - report({"cost": cost}) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} tuner = tune.Tuner( From 4deba7bbbb97d489d86c4235e76f8360a1e66a35 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 28 Sep 2023 17:30:00 -0600 Subject: [PATCH 2/3] Bump max jaxlib/jax version (#452) --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 618bcbe11..e62fcb7cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,8 @@ scipy>=1.6.0 tifffile imageio>=2.17 matplotlib -jaxlib>=0.4.3,<=0.4.14 -jax>=0.4.3,<=0.4.14 +jaxlib>=0.4.3,<=0.4.16 +jax>=0.4.3,<=0.4.16 flax>=0.6.1,<=0.6.9 bm3d>=4.0.0 bm4d>=4.2.2 From 216ffc8548d49a38d6f27127f1f4e1bc12ce1069 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 4 Oct 2023 15:52:27 -0600 Subject: [PATCH 3/3] Resolve warnings with jaxlib/jax 0.4.17 (#454) --- scico/numpy/_wrapped_function_lists.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scico/numpy/_wrapped_function_lists.py b/scico/numpy/_wrapped_function_lists.py index 94ce8f4a4..6e9c3a163 100644 --- a/scico/numpy/_wrapped_function_lists.py +++ b/scico/numpy/_wrapped_function_lists.py @@ -102,7 +102,6 @@ "ediff1d", "gradient", "cross", - "trapz", "exp", "expm1", "exp2", @@ -227,7 +226,6 @@ "hstack", "dstack", "column_stack", - "row_stack", "split", "array_split", "dsplit",