Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix XRayTransform boundary conditions #561

Merged
merged 3 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,20 @@ def _project(
"""
nx = im.shape
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)

# avoid incompatible types in the .add (scatter operation)
weights = weights.astype(im.dtype)

# Handle out of bounds indices by setting weight to zero
weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)
y = (
jnp.zeros((len(angles), ny), dtype=im.dtype)
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
.add(im * weights)
.add(im * weights_valid)
)

y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights))
weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)
y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid)

return y

Expand All @@ -194,14 +194,15 @@ def _back_project(
"""
ny = y.shape[1]
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)
# Handle out of bounds indices by setting weight to zero
weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)

# the idea: [y[0, inds[0]], y[1, inds[1]], ...]
HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0)
HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0)

weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)
HTy = HTy + jnp.sum(
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0
)

return HTy
Expand Down Expand Up @@ -401,7 +402,7 @@ def _back_project_single(

@staticmethod
def _calc_weights(
input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0
input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0
) -> snp.Array:
# pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5)
x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...)
Expand All @@ -419,13 +420,46 @@ def _calc_weights(
left_edge = Px - w / 2
to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w)
ul_ind = jnp.floor(left_edge).astype("int32")
ul_ind = jnp.where(ul_ind < 0, max(output_shape), ul_ind) # otherwise negative values wrap

ul_weight = to_next[0] * to_next[1] * (1 / w**2)
ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2)
ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2)
lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2)

# set weights to zero out of bounds
ul_weight = jnp.where(
(ul_ind[0] >= 0)
* (ul_ind[0] < det_shape[0])
* (ul_ind[1] >= 0)
* (ul_ind[1] < det_shape[1]),
ul_weight,
0.0,
)
ur_weight = jnp.where(
(ul_ind[0] + 1 >= 0)
* (ul_ind[0] + 1 < det_shape[0])
* (ul_ind[1] >= 0)
* (ul_ind[1] < det_shape[1]),
ur_weight,
0.0,
)
ll_weight = jnp.where(
(ul_ind[0] >= 0)
* (ul_ind[0] < det_shape[0])
* (ul_ind[1] + 1 >= 0)
* (ul_ind[1] + 1 < det_shape[1]),
ll_weight,
0.0,
)
lr_weight = jnp.where(
(ul_ind[0] + 1 >= 0)
* (ul_ind[0] + 1 < det_shape[0])
* (ul_ind[1] + 1 >= 0)
* (ul_ind[1] + 1 < det_shape[1]),
lr_weight,
0.0,
)

return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

import scico
from scico.linop.xray import XRayTransform2D, XRayTransform3D
import scico.linop
from scico.linop.xray import XRayTransform2D


@pytest.mark.filterwarnings("error")
Expand Down Expand Up @@ -71,45 +72,12 @@ def test_apply_adjoint():
assert y.shape[1] == det_count


def test_3d_scaling():
x = jnp.zeros((4, 4, 1))
x = x.at[1:3, 1:3, 0].set(1.0)

input_shape = x.shape
output_shape = x.shape[:2]

# default spacing
M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)

# fmt: off
truth = jnp.array(
[[[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0]]]
) # fmt: on
np.testing.assert_allclose(H @ x, truth)

# bigger voxels in the x (first index) direction
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, output_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0]
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
# fmt: off
truth = jnp.array(
[[[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ]]]
) # fmt: on
np.testing.assert_allclose(H @ x, truth)

# bigger detector pixels in the x (first index) direction
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, output_shape, "X", [0.0], det_spacing=[2.0, 1.0]
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
# fmt: off
truth = None # fmt: on # TODO: Check this case more closely.
# np.testing.assert_allclose(H @ x, truth)
def test_matched_adjoint():
"""See https://github.com/lanl/scico/issues/560."""
N = 16
det_count = int(N * 1.05 / np.sqrt(2.0))
dx = 1.0 / np.sqrt(2)
n_projection = 3
angles = np.linspace(0, np.pi, n_projection, endpoint=False)
A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx)
assert scico.linop.valid_adjoint(A, A.T, eps=1e-5)
66 changes: 66 additions & 0 deletions scico/test/linop/xray/test_xray_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np

import jax.numpy as jnp

import scico.linop
from scico.linop.xray import XRayTransform3D


def test_matched_adjoint():
"""See https://github.com/lanl/scico/issues/560."""
N = 16
det_count = int(N * 1.05 / np.sqrt(2.0))
n_projection = 3

input_shape = (N, N, N)
det_shape = (det_count, det_count)

M = XRayTransform3D.matrices_from_euler_angles(
input_shape, det_shape, "X", np.linspace(0, np.pi, n_projection, endpoint=False)
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)

assert scico.linop.valid_adjoint(H, H.T, eps=1e-5)


def test_scaling():
x = jnp.zeros((4, 4, 1))
x = x.at[1:3, 1:3, 0].set(1.0)

input_shape = x.shape
det_shape = x.shape[:2]

# default spacing
M = XRayTransform3D.matrices_from_euler_angles(input_shape, det_shape, "X", [0.0])
H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
# fmt: off
truth = jnp.array(
[[[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0]]]
) # fmt: on
np.testing.assert_allclose(H @ x, truth)

# bigger voxels in the x (first index) direction
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, det_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0]
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
# fmt: off
truth = jnp.array(
[[[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ],
[0. , 0.5, 0.5, 0. ]]]
) # fmt: on
np.testing.assert_allclose(H @ x, truth)

# bigger detector pixels in the x (first index) direction
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, det_shape, "X", [0.0], det_spacing=[2.0, 1.0]
)
H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
# fmt: off
truth = None # fmt: on # TODO: Check this case more closely.
# np.testing.assert_allclose(H @ x, truth)
Loading