Skip to content

Commit

Permalink
Merge branch 'main' into brendt/rename
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Nov 2, 2023
2 parents 52efdcb + bcb1eab commit 48d7ba8
Show file tree
Hide file tree
Showing 73 changed files with 190 additions and 392 deletions.
2 changes: 1 addition & 1 deletion examples/examples_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
astra-toolbox
colour_demosaicing
xdesign>=0.5.5
ray[tune]>=2.0.0
ray[tune,train]>=2.5.0
hyperopt
bm3d>=4.0.0
bm4d>=4.2.2
2 changes: 1 addition & 1 deletion examples/scripts/ct_abel_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
better performance than isotropic TV for this problem, is used here.
"""
f = loss.SquaredL2Loss(y=y, A=A)
λ = 2.35e1 # L1 norm regularization parameter
λ = 2.35e1 # ℓ1 norm regularization parameter
g = λ * functional.L1Norm() # Note the use of anisotropic TV
C = linop.FiniteDifference(input_shape=x_gt.shape)

Expand Down
21 changes: 10 additions & 11 deletions examples/scripts/ct_abel_tv_admm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
`ray.tune` class API is used in this example.
This script is hard-coded to run on CPU only to avoid the large number of
warnings that are emitted when GPU resources are requested but not available,
and due to the difficulty of supressing these warnings in a way that does
not force use of the CPU only. To enable GPU usage, comment out the
`os.environ` statements near the beginning of the script, and change the
value of the "gpu" entry in the `resources` dict from 0 to 1. Note that
two environment variables are set to suppress the warnings because
`JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but this change
has yet to be correctly implemented
warnings that are emitted when GPU resources are requested but not
available, and due to the difficulty of supressing these warnings in a
way that does not force use of the CPU only. To enable GPU usage, comment
out the `os.environ` statements near the beginning of the script, and
change the value of the "gpu" entry in the `resources` dict from 0 to 1.
Note that two environment variables are set to suppress the warnings
because `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but
this change has yet to be correctly implemented
(see [google/jax#6805](https://github.com/google/jax/issues/6805) and
[google/jax#10272](https://github.com/google/jax/pull/10272).
"""
Expand All @@ -34,7 +34,6 @@

import numpy as np

import jax

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
Expand Down Expand Up @@ -82,8 +81,8 @@ def setup(self, config, x_gt, x0, y):
this case). The remaining parameters are objects that are passed
to the evaluation function via the ray object store.
"""
# Put main arrays on jax device.
self.x_gt, self.x0, self.y = jax.device_put([x_gt, x0, y])
# Get arrays passed by tune call.
self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y)
# Set up problem to be solved.
self.A = AbelTransform(self.x_gt.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
Expand Down
11 changes: 5 additions & 6 deletions examples/scripts/ct_astra_3d_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@

import numpy as np

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_tangle_phantom
from scico.linop.xray.astra import XRayTransform
Expand All @@ -36,13 +35,11 @@
"""
Create a ground truth image and projector.
"""

Nx = 128
Ny = 256
Nz = 64

tangle = create_tangle_phantom(Nx, Ny, Nz)
tangle = jax.device_put(tangle)
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))

n_projection = 10 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
Expand All @@ -53,7 +50,7 @@
"""
Set up ADMM solver object.
"""
λ = 2e0 # L1 norm regularization parameter
λ = 2e0 # ℓ2,1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
cg_tol = 1e-4 # CG relative tolerance
Expand All @@ -80,6 +77,7 @@
itstat_options={"display": True, "period": 5},
)


"""
Run the solver.
"""
Expand All @@ -93,6 +91,7 @@
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))
)


"""
Show the recovered image.
"""
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_astra_modl_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
stats_object_ini = None

checkpoint_files = []
for (dirpath, dirnames, filenames) in os.walk(workdir2):
for dirpath, dirnames, filenames in os.walk(workdir2):
checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"]

if len(checkpoint_files) > 0:
Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/ct_astra_noreg_pcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import numpy as np

import jax
import jax.numpy as jnp

from xdesign import Foam, discrete_phantom
Expand All @@ -38,7 +37,7 @@
"""
N = 256 # phantom size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU
x_gt = jnp.array(x_gt) # convert to jax type


"""
Expand Down
6 changes: 2 additions & 4 deletions examples/scripts/ct_astra_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import numpy as np

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable
from xdesign import Foam, discrete_phantom

Expand All @@ -38,7 +36,7 @@
N = 512 # phantom size
np.random.seed(1234)
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU
x_gt = snp.array(x_gt) # convert to jax type


"""
Expand All @@ -53,7 +51,7 @@
"""
Set up ADMM solver object.
"""
λ = 2e0 # L1 norm regularization parameter
λ = 2e0 # ℓ1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
cg_tol = 1e-4 # CG relative tolerance
Expand Down
8 changes: 3 additions & 5 deletions examples/scripts/ct_astra_weighted_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

import numpy as np

import jax

from xdesign import Soil, discrete_phantom

import scico.numpy as snp
Expand All @@ -42,7 +40,7 @@
x_gt = discrete_phantom(Soil(porosity=0.80), size=384)
x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64)))
x_gt = np.clip(x_gt, 0, np.inf) # clip to positive values
x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU
x_gt = snp.array(x_gt) # convert to jax type


"""
Expand Down Expand Up @@ -72,7 +70,7 @@
counts = np.random.poisson(Io * snp.exp(-𝛼 * A @ x_gt))
counts = np.clip(counts, a_min=1, a_max=np.inf) # replace any 0s count with 1
y = -1 / 𝛼 * np.log(counts / Io)
y = jax.device_put(y) # convert back to float32
y = snp.array(y) # convert back to float32 as a jax array


"""
Expand Down Expand Up @@ -140,7 +138,7 @@ def postprocess(x):
"""
lambda_weighted = 5e1

weights = jax.device_put(counts / Io)
weights = snp.array(counts / Io)
f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))

admm_weighted = ADMM(
Expand Down
18 changes: 9 additions & 9 deletions examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@

import numpy as np

import jax

import matplotlib.pyplot as plt
import svmbir
from matplotlib.ticker import MaxNLocator
Expand Down Expand Up @@ -137,10 +135,12 @@ def add_poisson_noise(sino, max_intensity):


"""
Push arrays to device.
Convert numpy arrays to jax arrays.
"""
y_fan, x0_fan, weights_fan = jax.device_put([y_fan, x_mrf_fan, weights_fan])
x0_parallel = jax.device_put(x_mrf_parallel)
y_fan = snp.array(y_fan)
x0_fan = snp.array(x_mrf_fan)
weights_fan = snp.array(weights_fan)
x0_parallel = snp.array(x_mrf_parallel)


"""
Expand Down Expand Up @@ -179,7 +179,7 @@ def add_poisson_noise(sino, max_intensity):
x0=x0_fan,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
itstat_options={"display": True},
itstat_options={"display": True, "period": 5},
)
solver_extloss_parallel = ADMM(
f=None,
Expand All @@ -189,7 +189,7 @@ def add_poisson_noise(sino, max_intensity):
x0=x0_parallel,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
itstat_options={"display": True},
itstat_options={"display": True, "period": 5},
)


Expand Down Expand Up @@ -267,7 +267,7 @@ def add_poisson_noise(sino, max_intensity):
fig=fig,
ax=ax[0],
)
ax[0].set_ylim([5e-3, 1e0])
ax[0].set_ylim([5e-3, 5e0])
ax[0].xaxis.set_major_locator(MaxNLocator(integer=True))
plot.plot(
snp.vstack((hist_extloss_fan.Prml_Rsdl, hist_extloss_fan.Dual_Rsdl)).T,
Expand All @@ -278,7 +278,7 @@ def add_poisson_noise(sino, max_intensity):
fig=fig,
ax=ax[1],
)
ax[1].set_ylim([5e-3, 1e0])
ax[1].set_ylim([5e-3, 5e0])
ax[1].xaxis.set_major_locator(MaxNLocator(integer=True))
fig.show()

Expand Down
5 changes: 2 additions & 3 deletions examples/scripts/ct_projector_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))

x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt)
x_gt = jnp.array(x_gt)


"""
Expand All @@ -44,7 +44,6 @@
num_angles = 500
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)


timer = Timer()

projectors = {}
Expand Down Expand Up @@ -92,7 +91,7 @@
"""
y = np.zeros(H.output_shape, dtype=np.float32)
y[num_angles // 3, det_count // 2] = 1.0
y = jax.device_put(y)
y = jnp.array(y)

HTys = {}
for name, H in projectors.items():
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

import numpy as np

import jax

import matplotlib.pyplot as plt
import svmbir
from xdesign import Foam, discrete_phantom
Expand Down Expand Up @@ -88,7 +86,9 @@
"""
Set up an ADMM solver.
"""
y, x0, weights = jax.device_put([y, x_mrf, weights])
y = snp.array(y)
x0 = snp.array(x_mrf)
weights = snp.array(weights)

ρ = 15 # ADMM penalty parameter
σ = density * 0.18 # denoiser sigma
Expand Down
16 changes: 8 additions & 8 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@

import numpy as np

import jax

import matplotlib.pyplot as plt
import svmbir
from matplotlib.ticker import MaxNLocator
Expand Down Expand Up @@ -100,9 +98,11 @@


"""
Push arrays to device.
Convert numpy arrays to jax arrays.
"""
y, x0, weights = jax.device_put([y, x_mrf, weights])
y = snp.array(y)
x0 = snp.array(x_mrf)
weights = snp.array(weights)


"""
Expand All @@ -129,7 +129,7 @@
x0=x0,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
itstat_options={"display": True},
itstat_options={"display": True, "period": 5},
)


Expand Down Expand Up @@ -161,7 +161,7 @@
x0=x0,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
itstat_options={"display": True},
itstat_options={"display": True, "period": 5},
)


Expand Down Expand Up @@ -219,7 +219,7 @@
fig=fig,
ax=ax[0],
)
ax[0].set_ylim([5e-3, 1e0])
ax[0].set_ylim([5e-3, 5e0])
ax[0].xaxis.set_major_locator(MaxNLocator(integer=True))
plot.plot(
snp.vstack((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T,
Expand All @@ -230,7 +230,7 @@
fig=fig,
ax=ax[1],
)
ax[1].set_ylim([5e-3, 1e0])
ax[1].set_ylim([5e-3, 5e0])
ax[1].xaxis.set_major_locator(MaxNLocator(integer=True))
fig.show()

Expand Down
9 changes: 4 additions & 5 deletions examples/scripts/ct_svmbir_tv_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

import numpy as np

import jax

import matplotlib.pyplot as plt
import svmbir
from xdesign import Foam, discrete_phantom
Expand Down Expand Up @@ -65,7 +63,7 @@
expected_counts = max_intensity * np.exp(-sino)
noisy_counts = np.random.poisson(expected_counts).astype(np.float32)
noisy_counts[noisy_counts == 0] = 1 # deal with 0s
y = -np.log(noisy_counts / max_intensity)
y = -snp.log(noisy_counts / max_intensity)


"""
Expand All @@ -87,9 +85,10 @@
"""
Set up problem.
"""
y, x0, weights = jax.device_put([y, x_mrf, weights])
x0 = snp.array(x_mrf)
weights = snp.array(weights)

λ = 1e-1 # L1 norm regularization parameter
λ = 1e-1 # ℓ1 norm regularization parameter

f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g = λ * functional.L21Norm() # regularization functional
Expand Down
Loading

0 comments on commit 48d7ba8

Please sign in to comment.