diff --git a/examples/scripts/ct_large_projection.py b/examples/scripts/ct_large_projection.py new file mode 100644 index 000000000..60726c0b1 --- /dev/null +++ b/examples/scripts/ct_large_projection.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +Large-scale CT Projection +========================= + +This example demonstrates using SCICO's X-ray projector on a large-scale +volume. + +""" + +import numpy as np + +import jax + +from scico.examples import create_block_phantom +from scico.linop import Parallel3dProjector, XRayTransform + +N = 256 +num_views = 3 + +in_shape = (N, N, N) +x = create_block_phantom(in_shape) + +det_shape = (N, N) + +rot_X = 90.0 - 16.0 +rot_Y = np.linspace(0, 180, num_views, endpoint=False) +angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1) +matrices = Parallel3dProjector.matrices_from_euler_angles( + in_shape, det_shape, "XY", angles, degrees=True +) + + +H = XRayTransform(Parallel3dProjector(in_shape, matrices, det_shape)) + +proj = H @ x +jax.block_until_ready(proj) diff --git a/scico/examples.py b/scico/examples.py index 4b68f7410..955b7d5d3 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -12,10 +12,14 @@ import os import tempfile import zipfile +from functools import partial from typing import List, Optional, Tuple, Union import numpy as np +import jax +import jax.numpy as jnp + import imageio.v3 as iio import scico.numpy as snp @@ -551,6 +555,7 @@ def create_tangle_phantom(nx: int, ny: int, nz: int) -> snp.Array: return (values < 2.0).astype(float) +@partial(jax.jit, static_argnums=0) def create_block_phantom(out_shape: Shape) -> snp.Array: """Construct a blocky 3D phantom. @@ -562,7 +567,7 @@ def create_block_phantom(out_shape: Shape) -> snp.Array: """ # make the phantom at a low resolution - low_res = np.array( + low_res = jnp.array( [ [ [0.0, 0.0, 0.0], @@ -581,11 +586,11 @@ def create_block_phantom(out_shape: Shape) -> snp.Array: ], ] ) - low_res = np.pad(low_res, 1) - - # upsample it to the requested resolution - full_res = zoom(low_res, np.array(out_shape) / low_res.shape, order=0) - return full_res + positions = jnp.stack( + jnp.meshgrid(*[jnp.linspace(-0.5, 2.5, s) for s in out_shape], indexing="ij") + ) + indices = jnp.round(positions).astype(int) + return low_res[indices[0], indices[1], indices[2]] def spnoise( diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 876f2ad01..fbf8577a7 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -135,7 +135,7 @@ def _project(im, x0, dx, y0, ny, angles): y0: Location of the edge of the first detector bin. ny: Number of detector bins. angles: (num_angles,) array of angles in radians. Pixels are - projected onto units vectors pointing in these directions. + projected onto unit vectors pointing in these directions. """ nx = im.shape inds, weights = _calc_weights(x0, dx, nx, angles, y0) @@ -228,38 +228,53 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike: r""" Args: im: Input image. - matrices: (num_angles, 2, 4) array of homogeneous projection matrices. + matrix: (num_views, 2, 4) array of homogeneous projection matrices. det_shape: Shape of detector. """ + num_views = len(matrices) + proj = jnp.zeros((num_views,) + det_shape, dtype=im.dtype) + for view_ind, matrix in enumerate(matrices): + proj = proj.at[view_ind].set( + Parallel3dProjector._project_single(im, matrix, proj[view_ind]) + ) + return proj - x = jnp.mgrid[: im.shape[0], : im.shape[1], : im.shape[2]] - # (v, 2, 3) X (3, x0, x1, x2) + (v, 2) -> (v, 2, x0, x1, x2) - Px = ( - jnp.tensordot(matrices[..., :3], x, axes=[2, 0]) - + matrices[..., 3, np.newaxis, np.newaxis, np.newaxis] - ) + @staticmethod + @partial(jax.jit, donate_argnames="proj") + def _project_single(im: ArrayLike, matrix: ArrayLike, proj: ArrayLike) -> ArrayLike: + r""" + Args: + im: Input image. + matrix: (2, 4) homogeneous projection matrix. + det_shape: Shape of detector. + """ + + x = jnp.mgrid[: im.shape[0], : im.shape[1], : im.shape[2]] # (3, ...) + + Px = jnp.stack( + ( + matrix[0, 0] * x[0] + matrix[0, 1] * x[1] + matrix[0, 2] * x[2] + matrix[0, 3], + matrix[1, 0] * x[0] + matrix[1, 1] * x[1] + matrix[1, 2] * x[2] + matrix[1, 3], + ) + ) # (2, ...) # calculate weight on 4 intersecting pixels w = 0.5 # assumed <= 1.0 left_edge = Px - w / 2 to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w) ul_ind = jnp.floor(left_edge).astype("int32") + det_shape = proj.shape ul_ind = jnp.where(ul_ind < 0, max(det_shape), ul_ind) # otherwise negative values wrap - ul_weight = to_next[:, 0] * to_next[:, 1] * (1 / w**2) - ur_weight = (w - to_next[:, 0]) * to_next[:, 1] * (1 / w**2) - ll_weight = to_next[:, 0] * (w - to_next[:, 1]) * (1 / w**2) - lr_weight = (w - to_next[:, 0]) * (w - to_next[:, 1]) * (1 / w**2) + ul_weight = to_next[0] * to_next[1] * (1 / w**2) + ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2) + ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2) + lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2) - num_views = len(matrices) - proj = jnp.zeros((num_views,) + det_shape, dtype=im.dtype) - view_ind = jnp.expand_dims(jnp.arange(num_views), range(1, 4)) - proj = proj.at[view_ind, ul_ind[:, 0], ul_ind[:, 1]].add(ul_weight * im, mode="drop") - proj = proj.at[view_ind, ul_ind[:, 0] + 1, ul_ind[:, 1]].add(ur_weight * im, mode="drop") - proj = proj.at[view_ind, ul_ind[:, 0], ul_ind[:, 1] + 1].add(ll_weight * im, mode="drop") - proj = proj.at[view_ind, ul_ind[:, 0] + 1, ul_ind[:, 1] + 1].add( - lr_weight * im, mode="drop" - ) + proj = proj.at[ul_ind[0], ul_ind[1]].add(ul_weight * im, mode="drop") + proj = proj.at[ul_ind[0] + 1, ul_ind[1]].add(ur_weight * im, mode="drop") + proj = proj.at[ul_ind[0], ul_ind[1] + 1].add(ll_weight * im, mode="drop") + proj = proj.at[ul_ind[0] + 1, ul_ind[1] + 1].add(lr_weight * im, mode="drop") return proj @staticmethod