Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update examples #540

Merged
merged 38 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
f3ad2b9
Minor comment edits
bwohlberg Jun 4, 2024
c2efef5
Fix missing typing import
bwohlberg Jun 5, 2024
fd012fa
Docs fixes
bwohlberg Jun 5, 2024
d5820e3
Improve docs
bwohlberg Jun 6, 2024
d518af1
Docs fixes
bwohlberg Jun 6, 2024
e01f1b6
Add missing dependency (due to recent changes in other packages) for …
bwohlberg Jun 6, 2024
9c4131b
Update submodule
bwohlberg Jun 6, 2024
e40615a
Fix script to notebook conversion bug
bwohlberg Jun 6, 2024
cbec0cf
Update submodule
Jun 6, 2024
5124e62
Update submodule
Jun 6, 2024
b64f977
Update submodule
bwohlberg Jun 6, 2024
a46200c
Update submodule
Jun 6, 2024
9c5d6a7
Add missing dependency for building notebooks
bwohlberg Jun 7, 2024
58673b3
Merge branch 'main' into brendt/examples
bwohlberg Jun 13, 2024
4c3a195
Update submodule
Jun 13, 2024
5e38ca8
Improve example script docstring
bwohlberg Jun 13, 2024
abd683d
Update submodule
bwohlberg Jun 13, 2024
80ac628
Avoid ray complaints
bwohlberg Jun 13, 2024
822cef8
Remove second ray.init call
bwohlberg Jun 13, 2024
f72adbe
Update submodule
bwohlberg Jun 17, 2024
782de2c
Merge branch 'main' into brendt/examples
bwohlberg Jun 17, 2024
6bee347
Update submodule
Jun 18, 2024
020214a
Update submodule
bwohlberg Jun 18, 2024
c1e5e07
Merge branch 'brendt/examples' of github.com:lanl/scico into brendt/e…
bwohlberg Jun 18, 2024
503b6d5
Replace deprecated matplotlib method
bwohlberg Jun 19, 2024
9917021
Merge branch 'main' into brendt/examples
bwohlberg Jul 9, 2024
e7a27a9
Update submodule
Jul 9, 2024
2829fcb
Update submodule
Jul 9, 2024
0eb31fd
Remove largely redundant example
bwohlberg Jul 9, 2024
ab90d40
Update example index
bwohlberg Jul 9, 2024
be7fb76
Rename example script
bwohlberg Jul 9, 2024
d4285d1
Edit script title and fix docs
bwohlberg Jul 9, 2024
f17ac18
Update example index
bwohlberg Jul 9, 2024
275ef2c
Update submodule
bwohlberg Jul 9, 2024
92d4dff
Merge branch 'brendt/examples' of github.com:lanl/scico into brendt/e…
bwohlberg Jul 9, 2024
803b836
Update submodule
bwohlberg Jul 19, 2024
2bf79a9
void doctest errors resulting from unimportable astra or svmbir
bwohlberg Jul 19, 2024
1c0cb42
Avoid doctest errors resulting from unimportable astra or svmbir
bwohlberg Jul 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ Computed Tomography
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_projector_comparison
examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm

Deconvolution
Expand Down Expand Up @@ -137,6 +136,7 @@ Total Variation
examples/ct_astra_3d_tv_admm
examples/ct_astra_3d_tv_padmm
examples/ct_astra_weighted_tv_admm
examples/ct_multi_tv_admm
examples/ct_svmbir_tv_multi
examples/deconv_circ_tv_admm
examples/deconv_tv_admm
Expand Down Expand Up @@ -208,6 +208,7 @@ ADMM
examples/ct_tv_admm
examples/ct_astra_3d_tv_admm
examples/ct_astra_weighted_tv_admm
examples/ct_multi_tv_admm
examples/ct_svmbir_tv_multi
examples/ct_svmbir_ppp_bm3d_admm_cg
examples/ct_svmbir_ppp_bm3d_admm_prox
Expand Down
9 changes: 5 additions & 4 deletions examples/jnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ def py_file_to_string(src):
if import_seen:
# Once an import statement has been seen, break on encountering a line that
# is neither an import statement nor a newline, nor a component of an import
# statement extended over multiple lines, nor an os.environ statement, nor
# components of a try/except construction (note that handling of these final
# two cases is probably not very robust).
# statement extended over multiple lines, nor an os.environ statement, nor a
# ray.init statement, nor components of a try/except construction (note that
# handling of these final two cases is probably not very robust).
if not re.match(
r"(^import|^from|^\n$|^\W+[^\W]|^\)$|^os.environ|^try:$|^except)", line
r"(^import|^from|^\n$|^\W+[^\W]|^\)$|^os.environ|^ray.init|^try:$|^except)",
line,
):
lines.append(line)
break
Expand Down
2 changes: 2 additions & 0 deletions examples/notebooks_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
-r examples-requirements.txt
ipykernel
ipywidgets
nbformat
nbconvert
nb_conda_kernels
Expand Down
6 changes: 4 additions & 2 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ Computed Tomography
CT Training and Reconstructions with UNet
`ct_projector_comparison.py <ct_projector_comparison.py>`_
X-ray Transform Comparison
`ct_multi_cs_tv_admm.py <ct_multi_cs_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram)
`ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)

Expand Down Expand Up @@ -176,6 +174,8 @@ Total Variation
3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver)
`ct_astra_weighted_tv_admm.py <ct_astra_weighted_tv_admm.py>`_
TV-Regularized Low-Dose CT Reconstruction
`ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)
`ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_
TV-Regularized CT Reconstruction (Multiple Algorithms)
`deconv_circ_tv_admm.py <deconv_circ_tv_admm.py>`_
Expand Down Expand Up @@ -275,6 +275,8 @@ ADMM
3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver)
`ct_astra_weighted_tv_admm.py <ct_astra_weighted_tv_admm.py>`_
TV-Regularized Low-Dose CT Reconstruction
`ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)
`ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_
TV-Regularized CT Reconstruction (Multiple Algorithms)
`ct_svmbir_ppp_bm3d_admm_cg.py <ct_svmbir_ppp_bm3d_admm_cg.py>`_
Expand Down
186 changes: 0 additions & 186 deletions examples/scripts/ct_multi_cs_tv_admm.py

This file was deleted.

73 changes: 53 additions & 20 deletions examples/scripts/ct_multi_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
where $A$ is the X-ray transform (the CT forward projection operator),
$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and
$\mathbf{x}$ is the desired image. The solution is computed and compared
for all three 2D CT projectors available in scico.
for all three 2D CT projectors available in scico, using a sinogram
computed with the astra projector.
"""

import numpy as np
Expand All @@ -37,44 +38,64 @@
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))

det_count = N
det_spacing = np.sqrt(2)


"""
Define CT geometry and construct array of (approximately) equivalent projectors.
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
projectors = {
"astra": astra.XRayTransform2D(x_gt.shape, N, 1.0, angles - np.pi / 2.0), # astra
"svmbir": svmbir.XRayTransform(x_gt.shape, 2 * np.pi - angles, N), # svmbir
"scico": XRayTransform(Parallel2dProjector((N, N), angles, det_count=N)), # scico
"astra": astra.XRayTransform2D(
x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0
), # astra
"svmbir": svmbir.XRayTransform(
x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing
), # svmbir
"scico": XRayTransform(
Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing)
), # scico
}


"""
Compute common sinogram using astra projector.
"""
A = projectors["astra"]
noise = np.random.normal(size=(n_projection, det_count)).astype(np.float32)
y = A @ x_gt + 2.0 * noise


"""
Construct initial solution for regularized problem.
"""
x0 = A.fbp(y)


"""
Solve the same problem using the different projectors.
"""
print(f"Solving on {device_info()}")
y, x_rec, hist = {}, {}, {}
noise = np.random.normal(size=(n_projection, N)).astype(np.float32)
for p in ("astra", "svmbir", "scico"):
x_rec, hist = {}, {}
for p in projectors.keys():
print(f"\nSolving with {p} projector")
A = projectors[p]
y[p] = A @ x_gt + 2.0 * noise # sinogram

# Set up ADMM solver object.
λ = 2e0 # L1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
λ = 2e1 # L1 norm regularization parameter
ρ = 1e3 # ADMM penalty parameter
maxiter = 100 # number of ADMM iterations
cg_tol = 1e-4 # CG relative tolerance
cg_maxiter = 25 # maximum CG iterations per ADMM iteration
cg_maxiter = 50 # maximum CG iterations per ADMM iteration

# The append=0 option makes the results of horizontal and vertical
# finite differences the same shape, which is required for the L21Norm,
# which is used so that g(Cx) corresponds to isotropic TV.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
g = λ * functional.L21Norm()
f = loss.SquaredL2Loss(y=y[p], A=A)
x0 = snp.clip(A.T(y[p]), 0, 1.0)
A = projectors[p]
f = loss.SquaredL2Loss(y=y, A=A)

# Set up the solver.
solver = ADMM(
Expand All @@ -91,15 +112,25 @@
# Run the solver.
solver.solve()
hist[p] = solver.itstat_object.history(transpose=True)
x_rec[p] = snp.clip(solver.x, 0, 1.0)
x_rec[p] = solver.x

if p == "scico":
x_rec[p] = x_rec[p] * det_spacing # to match ASTRA's scaling


"""
Compare reconstruction results.
"""
print("Reconstruction SNR:")
for p in projectors.keys():
print(f" {(p + ':'):7s} {metric.snr(x_gt, x_rec[p]):5.2f} dB")


"""
Compare sinograms.
Display sinogram.
"""
fig, ax = plot.subplots(nrows=3, ncols=1, figsize=(15, 10))
for idx, name in enumerate(projectors.keys()):
plot.imview(y[name], title=f"{name} sinogram", cbar=None, fig=fig, ax=ax[idx])
fig, ax = plot.subplots(nrows=1, ncols=1, figsize=(15, 3))
plot.imview(y, title="sinogram", fig=fig, ax=ax)
fig.show()


Expand Down Expand Up @@ -147,6 +178,8 @@
fig=fig,
ax=ax[n + 1],
)
for ax in ax:
ax.get_images()[0].set_clim(-0.1, 1.1)
fig.show()


Expand Down
Loading
Loading