Skip to content

Commit

Permalink
Fix XRayTransform boundary conditions (#561)
Browse files Browse the repository at this point in the history
* Add failing tests
* Fix boundary handling
  • Loading branch information
Michael-T-McCann authored Oct 15, 2024
1 parent 008697c commit d3193ea
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 55 deletions.
58 changes: 46 additions & 12 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,20 @@ def _project(
"""
nx = im.shape
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)

# avoid incompatible types in the .add (scatter operation)
weights = weights.astype(im.dtype)

# Handle out of bounds indices by setting weight to zero
weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)
y = (
jnp.zeros((len(angles), ny), dtype=im.dtype)
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
.add(im * weights)
.add(im * weights_valid)
)

y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights))
weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)
y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid)

return y

Expand All @@ -194,14 +194,15 @@ def _back_project(
"""
ny = y.shape[1]
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)
# Handle out of bounds indices by setting weight to zero
weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)

# the idea: [y[0, inds[0]], y[1, inds[1]], ...]
HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0)
HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0)

weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)
HTy = HTy + jnp.sum(
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0
)

return HTy
Expand Down Expand Up @@ -401,7 +402,7 @@ def _back_project_single(

@staticmethod
def _calc_weights(
input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0
input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0
) -> snp.Array:
# pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5)
x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...)
Expand All @@ -419,13 +420,46 @@ def _calc_weights(
left_edge = Px - w / 2
to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w)
ul_ind = jnp.floor(left_edge).astype("int32")
ul_ind = jnp.where(ul_ind < 0, max(output_shape), ul_ind) # otherwise negative values wrap

ul_weight = to_next[0] * to_next[1] * (1 / w**2)
ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2)
ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2)
lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2)

# set weights to zero out of bounds
ul_weight = jnp.where(
(ul_ind[0] >= 0)
* (ul_ind[0] < det_shape[0])
* (ul_ind[1] >= 0)
* (ul_ind[1] < det_shape[1]),
ul_weight,
0.0,
)
ur_weight = jnp.where(
(ul_ind[0] + 1 >= 0)
* (ul_ind[0] + 1 < det_shape[0])
* (ul_ind[1] >= 0)
* (ul_ind[1] < det_shape[1]),
ur_weight,
0.0,
)
ll_weight = jnp.where(
(ul_ind[0] >= 0)
* (ul_ind[0] < det_shape[0])
* (ul_ind[1] + 1 >= 0)
* (ul_ind[1] + 1 < det_shape[1]),
ll_weight,
0.0,
)
lr_weight = jnp.where(
(ul_ind[0] + 1 >= 0)
* (ul_ind[0] + 1 < det_shape[0])
* (ul_ind[1] + 1 >= 0)
* (ul_ind[1] + 1 < det_shape[1]),
lr_weight,
0.0,
)

return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

import scico
from scico.linop.xray import XRayTransform2D, XRayTransform3D
import scico.linop
from scico.linop.xray import XRayTransform2D


@pytest.mark.filterwarnings("error")
Expand Down Expand Up @@ -71,45 +72,12 @@ def test_apply_adjoint():
assert y.shape[1] == det_count


def test_3d_scaling():
x = jnp.zeros((4, 4, 1))
x = x.at[1:3, 1:3, 0].set(1.0)

input_shape = x.shape
output_shape = x.shape[:2]

# default spacing
M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)

# fmt: off
truth = jnp.array(
[[[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0]]]
) # fmt: on
np.testing.assert_allclose(H @ x, truth)

# bigger voxels in the x (first index) direction
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, output_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0]
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
# fmt: off
truth = jnp.array(
[[[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ]]]
) # fmt: on
np.testing.assert_allclose(H @ x, truth)

# bigger detector pixels in the x (first index) direction
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, output_shape, "X", [0.0], det_spacing=[2.0, 1.0]
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
# fmt: off
truth = None # fmt: on # TODO: Check this case more closely.
# np.testing.assert_allclose(H @ x, truth)
def test_matched_adjoint():
"""See https://github.com/lanl/scico/issues/560."""
N = 16
det_count = int(N * 1.05 / np.sqrt(2.0))
dx = 1.0 / np.sqrt(2)
n_projection = 3
angles = np.linspace(0, np.pi, n_projection, endpoint=False)
A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx)
assert scico.linop.valid_adjoint(A, A.T, eps=1e-5)
66 changes: 66 additions & 0 deletions scico/test/linop/xray/test_xray_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np

import jax.numpy as jnp

import scico.linop
from scico.linop.xray import XRayTransform3D


def test_matched_adjoint():
"""See https://github.com/lanl/scico/issues/560."""
N = 16
det_count = int(N * 1.05 / np.sqrt(2.0))
n_projection = 3

input_shape = (N, N, N)
det_shape = (det_count, det_count)

M = XRayTransform3D.matrices_from_euler_angles(
input_shape, det_shape, "X", np.linspace(0, np.pi, n_projection, endpoint=False)
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)

assert scico.linop.valid_adjoint(H, H.T, eps=1e-5)


def test_scaling():
x = jnp.zeros((4, 4, 1))
x = x.at[1:3, 1:3, 0].set(1.0)

input_shape = x.shape
det_shape = x.shape[:2]

# default spacing
M = XRayTransform3D.matrices_from_euler_angles(input_shape, det_shape, "X", [0.0])
H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
# fmt: off
truth = jnp.array(
[[[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0]]]
) # fmt: on
np.testing.assert_allclose(H @ x, truth)

# bigger voxels in the x (first index) direction
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, det_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0]
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
# fmt: off
truth = jnp.array(
[[[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ]]]
) # fmt: on
np.testing.assert_allclose(H @ x, truth)

# bigger detector pixels in the x (first index) direction
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, det_shape, "X", [0.0], det_spacing=[2.0, 1.0]
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
# fmt: off
truth = None # fmt: on # TODO: Check this case more closely.
# np.testing.assert_allclose(H @ x, truth)

0 comments on commit d3193ea

Please sign in to comment.