diff --git a/examples/examples_requirements.txt b/examples/examples_requirements.txt index 60b77e7e7..ff7a1dcfb 100644 --- a/examples/examples_requirements.txt +++ b/examples/examples_requirements.txt @@ -1,7 +1,7 @@ astra-toolbox colour_demosaicing xdesign>=0.5.5 -ray[tune]>=2.0.0 +ray[tune,train]>=2.5.0 hyperopt bm3d>=4.0.0 bm4d>=4.2.2 diff --git a/examples/scripts/ct_abel_tv_admm.py b/examples/scripts/ct_abel_tv_admm.py index 2abf2c4d1..2adc141ce 100644 --- a/examples/scripts/ct_abel_tv_admm.py +++ b/examples/scripts/ct_abel_tv_admm.py @@ -57,7 +57,7 @@ better performance than isotropic TV for this problem, is used here. """ f = loss.SquaredL2Loss(y=y, A=A) -λ = 2.35e1 # L1 norm regularization parameter +λ = 2.35e1 # ℓ1 norm regularization parameter g = λ * functional.L1Norm() # Note the use of anisotropic TV C = linop.FiniteDifference(input_shape=x_gt.shape) diff --git a/examples/scripts/ct_abel_tv_admm_tune.py b/examples/scripts/ct_abel_tv_admm_tune.py index 5f4fc68f9..c60ade412 100644 --- a/examples/scripts/ct_abel_tv_admm_tune.py +++ b/examples/scripts/ct_abel_tv_admm_tune.py @@ -14,14 +14,14 @@ `ray.tune` class API is used in this example. This script is hard-coded to run on CPU only to avoid the large number of -warnings that are emitted when GPU resources are requested but not available, -and due to the difficulty of supressing these warnings in a way that does -not force use of the CPU only. To enable GPU usage, comment out the -`os.environ` statements near the beginning of the script, and change the -value of the "gpu" entry in the `resources` dict from 0 to 1. Note that -two environment variables are set to suppress the warnings because -`JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but this change -has yet to be correctly implemented +warnings that are emitted when GPU resources are requested but not +available, and due to the difficulty of suppressing these warnings in a +way that does not force use of the CPU only. To enable GPU usage, comment +out the `os.environ` statements near the beginning of the script, and +change the value of the "gpu" entry in the `resources` dict from 0 to 1. +Note that two environment variables are set to suppress the warnings +because `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but +this change has yet to be correctly implemented (see [google/jax#6805](https://github.com/google/jax/issues/6805) and [google/jax#10272](https://github.com/google/jax/pull/10272)). """ @@ -34,7 +34,6 @@ import numpy as np -import jax import scico.numpy as snp from scico import functional, linop, loss, metric, plot @@ -82,8 +81,8 @@ def setup(self, config, x_gt, x0, y): this case). The remaining parameters are objects that are passed to the evaluation function via the ray object store. """ - # Put main arrays on jax device. - self.x_gt, self.x0, self.y = jax.device_put([x_gt, x0, y]) + # Get arrays passed by tune call. + self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y) # Set up problem to be solved.
self.A = AbelTransform(self.x_gt.shape) self.f = loss.SquaredL2Loss(y=self.y, A=self.A) diff --git a/examples/scripts/ct_astra_3d_tv_admm.py b/examples/scripts/ct_astra_3d_tv_admm.py index 8ccbe65a7..bb64ea61b 100644 --- a/examples/scripts/ct_astra_3d_tv_admm.py +++ b/examples/scripts/ct_astra_3d_tv_admm.py @@ -23,10 +23,9 @@ import numpy as np -import jax - from mpl_toolkits.axes_grid1 import make_axes_locatable +import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_tangle_phantom from scico.linop.xray.astra import XRayTransform @@ -36,13 +35,11 @@ """ Create a ground truth image and projector. """ - Nx = 128 Ny = 256 Nz = 64 -tangle = create_tangle_phantom(Nx, Ny, Nz) -tangle = jax.device_put(tangle) +tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) n_projection = 10 # number of projections angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles @@ -53,7 +50,7 @@ """ Set up ADMM solver object. """ -λ = 2e0 # L1 norm regularization parameter +λ = 2e0 # ℓ2,1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance @@ -80,6 +77,7 @@ itstat_options={"display": True, "period": 5}, ) + """ Run the solver. """ @@ -93,6 +91,7 @@ % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)) ) + """ Show the recovered image. """ diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_astra_modl_train_foam2.py index 8d1f46ee8..66a137e9c 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_astra_modl_train_foam2.py @@ -168,7 +168,7 @@ stats_object_ini = None checkpoint_files = [] -for (dirpath, dirnames, filenames) in os.walk(workdir2): +for dirpath, dirnames, filenames in os.walk(workdir2): checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"] if len(checkpoint_files) > 0: diff --git a/examples/scripts/ct_astra_noreg_pcg.py b/examples/scripts/ct_astra_noreg_pcg.py index ed9cf3b04..9e78f59fd 100644 --- a/examples/scripts/ct_astra_noreg_pcg.py +++ b/examples/scripts/ct_astra_noreg_pcg.py @@ -23,7 +23,6 @@ import numpy as np -import jax import jax.numpy as jnp from xdesign import Foam, discrete_phantom @@ -38,7 +37,7 @@ """ N = 256 # phantom size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU +x_gt = jnp.array(x_gt) # convert to jax type """ diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py index 4612c0911..69520f872 100644 --- a/examples/scripts/ct_astra_tv_admm.py +++ b/examples/scripts/ct_astra_tv_admm.py @@ -21,8 +21,6 @@ import numpy as np -import jax - from mpl_toolkits.axes_grid1 import make_axes_locatable from xdesign import Foam, discrete_phantom @@ -38,7 +36,7 @@ N = 512 # phantom size np.random.seed(1234) x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU +x_gt = snp.array(x_gt) # convert to jax type """ @@ -53,7 +51,7 @@ """ Set up ADMM solver object. 
""" -λ = 2e0 # L1 norm regularization parameter +λ = 2e0 # ℓ1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py index 319d7bc96..b3dc439c2 100644 --- a/examples/scripts/ct_astra_weighted_tv_admm.py +++ b/examples/scripts/ct_astra_weighted_tv_admm.py @@ -23,8 +23,6 @@ import numpy as np -import jax - from xdesign import Soil, discrete_phantom import scico.numpy as snp @@ -42,7 +40,7 @@ x_gt = discrete_phantom(Soil(porosity=0.80), size=384) x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64))) x_gt = np.clip(x_gt, 0, np.inf) # clip to positive values -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU +x_gt = snp.array(x_gt) # convert to jax type """ @@ -72,7 +70,7 @@ counts = np.random.poisson(Io * snp.exp(-𝛼 * A @ x_gt)) counts = np.clip(counts, a_min=1, a_max=np.inf) # replace any 0s count with 1 y = -1 / 𝛼 * np.log(counts / Io) -y = jax.device_put(y) # convert back to float32 +y = snp.array(y) # convert back to float32 as a jax array """ @@ -140,7 +138,7 @@ def postprocess(x): """ lambda_weighted = 5e1 -weights = jax.device_put(counts / Io) +weights = snp.array(counts / Io) f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights)) admm_weighted = ADMM( diff --git a/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py index 9d47b3e31..80299a1ae 100644 --- a/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py @@ -26,8 +26,6 @@ import numpy as np -import jax - import matplotlib.pyplot as plt import svmbir from matplotlib.ticker import MaxNLocator @@ -137,10 +135,12 @@ def add_poisson_noise(sino, max_intensity): """ -Push arrays to device. +Convert numpy arrays to jax arrays. 
""" -y_fan, x0_fan, weights_fan = jax.device_put([y_fan, x_mrf_fan, weights_fan]) -x0_parallel = jax.device_put(x_mrf_parallel) +y_fan = snp.array(y_fan) +x0_fan = snp.array(x_mrf_fan) +weights_fan = snp.array(weights_fan) +x0_parallel = snp.array(x_mrf_parallel) """ @@ -179,7 +179,7 @@ def add_poisson_noise(sino, max_intensity): x0=x0_fan, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), - itstat_options={"display": True}, + itstat_options={"display": True, "period": 5}, ) solver_extloss_parallel = ADMM( f=None, @@ -189,7 +189,7 @@ def add_poisson_noise(sino, max_intensity): x0=x0_parallel, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), - itstat_options={"display": True}, + itstat_options={"display": True, "period": 5}, ) @@ -267,7 +267,7 @@ def add_poisson_noise(sino, max_intensity): fig=fig, ax=ax[0], ) -ax[0].set_ylim([5e-3, 1e0]) +ax[0].set_ylim([5e-3, 5e0]) ax[0].xaxis.set_major_locator(MaxNLocator(integer=True)) plot.plot( snp.vstack((hist_extloss_fan.Prml_Rsdl, hist_extloss_fan.Dual_Rsdl)).T, @@ -278,7 +278,7 @@ def add_poisson_noise(sino, max_intensity): fig=fig, ax=ax[1], ) -ax[1].set_ylim([5e-3, 1e0]) +ax[1].set_ylim([5e-3, 5e0]) ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) fig.show() diff --git a/examples/scripts/ct_projector_comparison.py b/examples/scripts/ct_projector_comparison.py index 1027e6c8a..58a31d4cd 100644 --- a/examples/scripts/ct_projector_comparison.py +++ b/examples/scripts/ct_projector_comparison.py @@ -34,7 +34,7 @@ det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) +x_gt = jnp.array(x_gt) """ @@ -44,7 +44,6 @@ num_angles = 500 angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) - timer = Timer() projectors = {} @@ -92,7 +91,7 @@ """ y = np.zeros(H.output_shape, dtype=np.float32) y[num_angles // 3, det_count // 2] = 1.0 -y = jax.device_put(y) +y = jnp.array(y) HTys = {} for name, H in projectors.items(): diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py index 168e6252d..390925d11 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py @@ -23,8 +23,6 @@ import numpy as np -import jax - import matplotlib.pyplot as plt import svmbir from xdesign import Foam, discrete_phantom @@ -88,7 +86,9 @@ """ Set up an ADMM solver. """ -y, x0, weights = jax.device_put([y, x_mrf, weights]) +y = snp.array(y) +x0 = snp.array(x_mrf) +weights = snp.array(weights) ρ = 15 # ADMM penalty parameter σ = density * 0.18 # denoiser sigma diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py index fef15119f..a6e663a09 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py @@ -32,8 +32,6 @@ import numpy as np -import jax - import matplotlib.pyplot as plt import svmbir from matplotlib.ticker import MaxNLocator @@ -100,9 +98,11 @@ """ -Push arrays to device. +Convert numpy arrays to jax arrays. 
""" -y, x0, weights = jax.device_put([y, x_mrf, weights]) +y = snp.array(y) +x0 = snp.array(x_mrf) +weights = snp.array(weights) """ @@ -129,7 +129,7 @@ x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), - itstat_options={"display": True}, + itstat_options={"display": True, "period": 5}, ) @@ -161,7 +161,7 @@ x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), - itstat_options={"display": True}, + itstat_options={"display": True, "period": 5}, ) @@ -219,7 +219,7 @@ fig=fig, ax=ax[0], ) -ax[0].set_ylim([5e-3, 1e0]) +ax[0].set_ylim([5e-3, 5e0]) ax[0].xaxis.set_major_locator(MaxNLocator(integer=True)) plot.plot( snp.vstack((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T, @@ -230,7 +230,7 @@ fig=fig, ax=ax[1], ) -ax[1].set_ylim([5e-3, 1e0]) +ax[1].set_ylim([5e-3, 5e0]) ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) fig.show() diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py index 2a2b3e230..8592b44ff 100644 --- a/examples/scripts/ct_svmbir_tv_multi.py +++ b/examples/scripts/ct_svmbir_tv_multi.py @@ -22,8 +22,6 @@ import numpy as np -import jax - import matplotlib.pyplot as plt import svmbir from xdesign import Foam, discrete_phantom @@ -65,7 +63,7 @@ expected_counts = max_intensity * np.exp(-sino) noisy_counts = np.random.poisson(expected_counts).astype(np.float32) noisy_counts[noisy_counts == 0] = 1 # deal with 0s -y = -np.log(noisy_counts / max_intensity) +y = -snp.log(noisy_counts / max_intensity) """ @@ -87,9 +85,10 @@ """ Set up problem. """ -y, x0, weights = jax.device_put([y, x_mrf, weights]) +x0 = snp.array(x_mrf) +weights = snp.array(weights) -λ = 1e-1 # L1 norm regularization parameter +λ = 1e-1 # ℓ1 norm regularization parameter f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) g = λ * functional.L21Norm() # regularization functional diff --git a/examples/scripts/deconv_circ_tv_admm.py b/examples/scripts/deconv_circ_tv_admm.py index d4676a4b6..b2ba83202 100644 --- a/examples/scripts/deconv_circ_tv_admm.py +++ b/examples/scripts/deconv_circ_tv_admm.py @@ -20,8 +20,6 @@ """ -import jax - from xdesign import SiemensStar, discrete_phantom import scico.numpy as snp @@ -36,7 +34,6 @@ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU """ diff --git a/examples/scripts/deconv_microscopy_allchn_tv_admm.py b/examples/scripts/deconv_microscopy_allchn_tv_admm.py index 244157f4f..13ec27b87 100644 --- a/examples/scripts/deconv_microscopy_allchn_tv_admm.py +++ b/examples/scripts/deconv_microscopy_allchn_tv_admm.py @@ -30,8 +30,6 @@ import numpy as np -import jax - import ray import scico.numpy as snp from scico import functional, linop, loss, plot @@ -108,9 +106,9 @@ @ray.remote(num_cpus=ncpu, num_gpus=ngpu) def deconvolve_channel(channel): """Deconvolve a single channel.""" - y_pad = jax.device_put(ray.get(y_pad_list)[channel]) - psf = jax.device_put(ray.get(psf_list)[channel]) - mask = jax.device_put(ray.get(mask_store)) + y_pad = ray.get(y_pad_list)[channel] + psf = ray.get(psf_list)[channel] + mask = ray.get(mask_store) M = linop.Diagonal(mask) C0 = linop.CircularConvolve( h=psf, input_shape=mask.shape, h_center=snp.array(psf.shape) / 2 - 0.5 # forward operator diff --git a/examples/scripts/deconv_microscopy_tv_admm.py b/examples/scripts/deconv_microscopy_tv_admm.py index 917f560fb..eee4d8f42 100644 --- 
a/examples/scripts/deconv_microscopy_tv_admm.py +++ b/examples/scripts/deconv_microscopy_tv_admm.py @@ -52,7 +52,6 @@ y -= y.min() y /= y.max() - psf /= psf.sum() diff --git a/examples/scripts/deconv_modl_train_foam1.py b/examples/scripts/deconv_modl_train_foam1.py index b0d437eab..96d0b7f15 100644 --- a/examples/scripts/deconv_modl_train_foam1.py +++ b/examples/scripts/deconv_modl_train_foam1.py @@ -163,7 +163,7 @@ stats_object_ini = None checkpoint_files = [] -for (dirpath, dirnames, filenames) in os.walk(workdir2): +for dirpath, dirnames, filenames in os.walk(workdir2): checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"] if len(checkpoint_files) > 0: diff --git a/examples/scripts/deconv_odp_train_foam1.py b/examples/scripts/deconv_odp_train_foam1.py index c96bf9019..38748d806 100644 --- a/examples/scripts/deconv_odp_train_foam1.py +++ b/examples/scripts/deconv_odp_train_foam1.py @@ -108,10 +108,10 @@ """ Define configuration dictionary for model and training loop. -Parameters have been selected for demonstration purposes and -relatively short training. The model depth is akin to the number of -unrolled iterations in the ODP model. The block depth controls the number -of layers at each unrolled iteration. The number of filters is uniform +Parameters have been selected for demonstration purposes and relatively +short training. The model depth is akin to the number of unrolled +iterations in the ODP model. The block depth controls the number of +layers at each unrolled iteration. The number of filters is uniform throughout the iterations. Better performance may be obtained by increasing depth, block depth, number of filters or training epochs, but may require longer training times. diff --git a/examples/scripts/deconv_ppp_bm3d_admm.py b/examples/scripts/deconv_ppp_bm3d_admm.py index e1c9a7322..ffebfd546 100644 --- a/examples/scripts/deconv_ppp_bm3d_admm.py +++ b/examples/scripts/deconv_ppp_bm3d_admm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -31,7 +29,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_ppp_bm3d_pgm.py b/examples/scripts/deconv_ppp_bm3d_pgm.py index 3448b308d..f8fae40b6 100644 --- a/examples/scripts/deconv_ppp_bm3d_pgm.py +++ b/examples/scripts/deconv_ppp_bm3d_pgm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -31,7 +29,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_ppp_bm4d_admm.py b/examples/scripts/deconv_ppp_bm4d_admm.py index 469b468a8..af600b6f9 100644 --- a/examples/scripts/deconv_ppp_bm4d_admm.py +++ b/examples/scripts/deconv_ppp_bm4d_admm.py @@ -17,8 +17,6 @@ import numpy as np -import jax - import scico.numpy as snp from scico import functional, linop, loss, metric, plot, random from scico.examples import create_3d_foam_phantom, downsample_volume, tile_volume_slices @@ -34,7 +32,7 @@ upsamp = 2 x_gt_hires = create_3d_foam_phantom((upsamp * Nz, upsamp * Ny, upsamp * Nx), N_sphere=100) x_gt = 
downsample_volume(x_gt_hires, upsamp) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_ppp_dncnn_admm.py b/examples/scripts/deconv_ppp_dncnn_admm.py index fd6e36900..d131d0300 100644 --- a/examples/scripts/deconv_ppp_dncnn_admm.py +++ b/examples/scripts/deconv_ppp_dncnn_admm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -31,7 +29,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_ppp_dncnn_padmm.py b/examples/scripts/deconv_ppp_dncnn_padmm.py index 1707158b9..a91fbb75b 100644 --- a/examples/scripts/deconv_ppp_dncnn_padmm.py +++ b/examples/scripts/deconv_ppp_dncnn_padmm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -31,7 +29,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_tv_admm.py b/examples/scripts/deconv_tv_admm.py index 5d651bd9d..874b70e3c 100644 --- a/examples/scripts/deconv_tv_admm.py +++ b/examples/scripts/deconv_tv_admm.py @@ -22,7 +22,6 @@ ADMM is used in a [companion example](deconv_tv_padmm.rst). """ -import jax from xdesign import SiemensStar, discrete_phantom @@ -38,7 +37,6 @@ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU """ diff --git a/examples/scripts/deconv_tv_admm_tune.py b/examples/scripts/deconv_tv_admm_tune.py index 809029042..a313a1028 100644 --- a/examples/scripts/deconv_tv_admm_tune.py +++ b/examples/scripts/deconv_tv_admm_tune.py @@ -32,7 +32,6 @@ os.environ["JAX_PLATFORM_NAME"] = "cpu" os.environ["JAX_PLATFORMS"] = "cpu" -import jax from xdesign import SiemensStar, discrete_phantom @@ -80,8 +79,6 @@ def eval_params(config, x_gt, psf, y): """ # Extract solver parameters from config dict. λ, ρ = config["lambda"], config["rho"] - # Put main arrays on jax device. - x_gt, psf, y = jax.device_put([x_gt, psf, y]) # Set up problem to be solved. A = linop.Convolve(h=psf, input_shape=x_gt.shape) f = loss.SquaredL2Loss(y=y, A=A) diff --git a/examples/scripts/deconv_tv_padmm.py b/examples/scripts/deconv_tv_padmm.py index 8cab50330..af7c4718b 100644 --- a/examples/scripts/deconv_tv_padmm.py +++ b/examples/scripts/deconv_tv_padmm.py @@ -22,7 +22,6 @@ ADMM is used in a [companion example](deconv_tv_admm.rst). 
""" -import jax from xdesign import SiemensStar, discrete_phantom @@ -38,7 +37,6 @@ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU """ @@ -92,7 +90,7 @@ """ f = functional.ZeroFunctional() g0 = loss.SquaredL2Loss(y=y) -λ = 2.0e-2 # L1 norm regularization parameter +λ = 2.0e-2 # ℓ2,1 norm regularization parameter g1 = λ * functional.L21Norm() g = functional.SeparableFunctional((g0, g1)) diff --git a/examples/scripts/demosaic_ppp_bm3d_admm.py b/examples/scripts/demosaic_ppp_bm3d_admm.py index 953b67e88..15f49a50d 100644 --- a/examples/scripts/demosaic_ppp_bm3d_admm.py +++ b/examples/scripts/demosaic_ppp_bm3d_admm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from bm3d import bm3d_rgb from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007 @@ -32,8 +30,7 @@ """ Read a ground truth image. """ -img = kodim23(asfloat=True)[160:416, 60:316] -img = jax.device_put(img) # convert to jax type, push to GPU +img = snp.array(kodim23(asfloat=True)[160:416, 60:316]) """ @@ -93,7 +90,7 @@ def demosaic(cfaimg): """ Compute a baseline demosaicing solution. """ -imgb = jax.device_put(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32)) +imgb = snp.array(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32)) """ diff --git a/examples/scripts/denoise_dncnn_train_bsds.py b/examples/scripts/denoise_dncnn_train_bsds.py index c92dbf9d7..4fd1330a2 100644 --- a/examples/scripts/denoise_dncnn_train_bsds.py +++ b/examples/scripts/denoise_dncnn_train_bsds.py @@ -129,6 +129,7 @@ time_eval = time() - start_time output = np.clip(output, a_min=0, a_max=1.0) + """ Compare trained model in terms of reconstruction time and data fidelity. """ diff --git a/examples/scripts/denoise_dncnn_universal.py b/examples/scripts/denoise_dncnn_universal.py index 406baffaf..5799053cb 100644 --- a/examples/scripts/denoise_dncnn_universal.py +++ b/examples/scripts/denoise_dncnn_universal.py @@ -23,10 +23,9 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom +import scico.numpy as snp import scico.random from scico import metric, plot from scico.denoiser import DnCNN @@ -37,7 +36,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ @@ -47,7 +46,6 @@ for σ in [0.06, 0.10, 0.20]: print("------+---------+-------------------------+-------------------------") for variant in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]: - # Instantiate a DnCNN. denoiser = DnCNN(variant=variant) diff --git a/examples/scripts/denoise_l1tv_admm.py b/examples/scripts/denoise_l1tv_admm.py index d69168418..457bec0da 100644 --- a/examples/scripts/denoise_l1tv_admm.py +++ b/examples/scripts/denoise_l1tv_admm.py @@ -20,7 +20,6 @@ operator, and $\mathbf{x}$ is the denoised image. """ -import jax from xdesign import SiemensStar, discrete_phantom @@ -39,7 +38,6 @@ phantom = SiemensStar(16) x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) x_gt = 0.5 * x_gt / x_gt.max() -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU y = spnoise(x_gt, 0.5) diff --git a/examples/scripts/denoise_tv_admm.py b/examples/scripts/denoise_tv_admm.py index 3d9b211a3..8e35c3e96 100644 --- a/examples/scripts/denoise_tv_admm.py +++ b/examples/scripts/denoise_tv_admm.py @@ -25,7 +25,6 @@ edges that are not vertical or horizontal. 
""" -import jax from xdesign import SiemensStar, discrete_phantom @@ -41,7 +40,6 @@ N = 256 # image size phantom = SiemensStar(16) x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU x_gt = x_gt / x_gt.max() diff --git a/examples/scripts/denoise_tv_multi.py b/examples/scripts/denoise_tv_multi.py index 3e126c961..05663a926 100644 --- a/examples/scripts/denoise_tv_multi.py +++ b/examples/scripts/denoise_tv_multi.py @@ -20,7 +20,6 @@ vectors at each point in the image $\mathbf{x}$. """ -import jax from xdesign import SiemensStar, discrete_phantom @@ -37,7 +36,6 @@ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU """ diff --git a/examples/scripts/denoise_tv_pgm.py b/examples/scripts/denoise_tv_pgm.py index d39b3ab33..fbdf6f017 100644 --- a/examples/scripts/denoise_tv_pgm.py +++ b/examples/scripts/denoise_tv_pgm.py @@ -29,7 +29,6 @@ from typing import Callable, Optional, Union -import jax import jax.numpy as jnp from xdesign import SiemensStar, discrete_phantom @@ -38,7 +37,6 @@ import scico.random from scico import functional, linop, loss, operator, plot from scico.numpy import Array, BlockArray -from scico.numpy.util import ensure_on_device from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize from scico.util import device_info @@ -48,7 +46,6 @@ N = 256 # image size phantom = SiemensStar(16) x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU x_gt = x_gt / x_gt.max() @@ -88,13 +85,11 @@ def __init__( A: Optional[Union[Callable, operator.Operator]] = None, lmbda: float = 0.5, ): - y = ensure_on_device(y) self.functional = functional.SquaredL2Norm() super().__init__(y=y, A=A, scale=1.0) self.lmbda = lmbda def __call__(self, x: Union[Array, BlockArray]) -> float: - xint = self.y - self.lmbda * self.A(x) return -1.0 * self.functional(xint - jnp.clip(xint, 0.0, 1.0)) + self.functional(xint) @@ -107,7 +102,6 @@ def __call__(self, x: Union[Array, BlockArray]) -> float: # Evaluation of functional set to zero. class IsoProjector(functional.Functional): - has_eval = True has_prox = True @@ -160,7 +154,6 @@ def prox(self, v: Array, lam: float, **kwargs) -> Array: # Evaluation of functional set to zero. class AnisoProjector(functional.Functional): - has_eval = True has_prox = True @@ -168,7 +161,6 @@ def __call__(self, x: Union[Array, BlockArray]) -> float: return 0.0 def prox(self, v: Array, lam: float, **kwargs) -> Array: - return v / jnp.maximum(jnp.ones(v.shape), jnp.abs(v)) diff --git a/examples/scripts/diffusercam_tv_admm.py b/examples/scripts/diffusercam_tv_admm.py index 780b3ee78..103286750 100644 --- a/examples/scripts/diffusercam_tv_admm.py +++ b/examples/scripts/diffusercam_tv_admm.py @@ -41,8 +41,6 @@ import numpy as np -import jax - import scico.numpy as snp from scico import plot from scico.examples import ucb_diffusercam_data @@ -88,8 +86,8 @@ `JAX_ENABLE_X64=True` and change `dtype` below to `np.float64`. 
""" dtype = np.float32 -y = jax.device_put(y.astype(dtype)) -psf = jax.device_put(psf.astype(dtype)) +y = snp.array(y.astype(dtype)) +psf = snp.array(psf.astype(dtype)) """ diff --git a/examples/scripts/sparsecode_admm.py b/examples/scripts/sparsecode_admm.py index 32f3eafde..829ccbff1 100644 --- a/examples/scripts/sparsecode_admm.py +++ b/examples/scripts/sparsecode_admm.py @@ -21,8 +21,7 @@ import numpy as np -import jax - +import scico.numpy as snp from scico import functional, linop, loss, plot from scico.optimize.admm import ADMM, MatrixSubproblemSolver from scico.util import device_info @@ -45,8 +44,8 @@ xt[idx] = np.random.rand(s) y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal -xt = jax.device_put(xt) # convert to jax array, push to GPU -y = jax.device_put(y) # convert to jax array, push to GPU +xt = snp.array(xt) # convert to jax array +y = snp.array(y) # convert to jax array """ diff --git a/examples/scripts/sparsecode_conv_admm.py b/examples/scripts/sparsecode_conv_admm.py index e1624bdc3..611e7238b 100644 --- a/examples/scripts/sparsecode_conv_admm.py +++ b/examples/scripts/sparsecode_conv_admm.py @@ -24,8 +24,6 @@ import numpy as np -import jax - import scico.numpy as snp from scico import plot from scico.examples import create_conv_sparse_phantom @@ -55,8 +53,8 @@ """ Convert numpy arrays to jax arrays. """ -h = jax.device_put(h) -x0 = jax.device_put(x0) +h = snp.array(h) +x0 = snp.array(x0) """ diff --git a/examples/scripts/sparsecode_conv_md_admm.py b/examples/scripts/sparsecode_conv_md_admm.py index 39587ab03..9f70e9ce3 100644 --- a/examples/scripts/sparsecode_conv_md_admm.py +++ b/examples/scripts/sparsecode_conv_md_admm.py @@ -36,8 +36,6 @@ import numpy as np -import jax - import scico.numpy as snp from scico import plot from scico.examples import create_conv_sparse_phantom @@ -67,8 +65,8 @@ """ Convert numpy arrays to jax arrays. """ -h = jax.device_put(h) -x0 = jax.device_put(x0) +h = snp.array(h) +x0 = snp.array(x0) """ diff --git a/examples/scripts/sparsecode_pgm.py b/examples/scripts/sparsecode_pgm.py index 8cf0874c1..f5d34e08e 100644 --- a/examples/scripts/sparsecode_pgm.py +++ b/examples/scripts/sparsecode_pgm.py @@ -19,8 +19,7 @@ import numpy as np -import jax - +import scico.numpy as snp from scico import functional, linop, loss, plot from scico.optimize.pgm import AcceleratedPGM from scico.util import device_info @@ -44,8 +43,8 @@ x_gt[idx[0:s]] = np.random.randn(s) y = D @ x_gt + σ * np.random.randn(m) # synthetic signal -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU -y = jax.device_put(y) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array +y = snp.array(y) # convert to jax array """ diff --git a/examples/scripts/sparsecode_poisson_pgm.py b/examples/scripts/sparsecode_poisson_pgm.py index b9f38b1b8..08cccfed1 100644 --- a/examples/scripts/sparsecode_poisson_pgm.py +++ b/examples/scripts/sparsecode_poisson_pgm.py @@ -32,7 +32,6 @@ $\mathbf{x}^{(0)}$. """ -import jax import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt @@ -68,6 +67,7 @@ D0 = D[:, :n0] D1 = D[:, n0:] + # Define composed operator. 
class ForwardOperator(Operator): @@ -80,7 +80,6 @@ class ForwardOperator(Operator): """ def __init__(self, input_shape: Shape, D0, D1, jit: bool = True): - self.D0 = D0 self.D1 = D1 @@ -105,9 +104,6 @@ def _eval(self, x: BlockArray) -> BlockArray: lam = A(x_gt) y, key = scico.random.poisson(lam, shape=lam.shape, key=key) # synthetic signal -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU -y = jax.device_put(y) # convert to jax array, push to GPU - """ Set up the loss function and the regularization. @@ -120,11 +116,10 @@ def _eval(self, x: BlockArray) -> BlockArray: """ -Define common setup: maximum of iterations and initial estimation of solution. +Define common setup: maximum number of iterations and initial estimate of the solution. """ maxiter = 50 x0, key = scico.random.uniform(((n0,), (n1,)), key=key) -x0 = jax.device_put(x0) # Initial solution estimate """ @@ -329,7 +324,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop): print("Running solver with step size of class: ", str_ss) print("L0 " + str_L0 + ": ", L0, "\n") -x = solver.solve() # Run the solver. +x = solver.solve() # run the solver hist = solver.itstat_object.history(transpose=True) plot_results(hist, str_ss, L0, x, x_gt, A) diff --git a/examples/scripts/superres_ppp_dncnn_admm.py b/examples/scripts/superres_ppp_dncnn_admm.py index fd06cd64a..8132376ec 100644 --- a/examples/scripts/superres_ppp_dncnn_admm.py +++ b/examples/scripts/superres_ppp_dncnn_admm.py @@ -14,7 +14,6 @@ superresolution problem. """ -import jax import scico import scico.numpy as snp @@ -39,8 +38,7 @@ def downsample_image(img, rate): """ Read a ground truth image. """ -img = kodim23(asfloat=True)[160:416, 60:316] -img = jax.device_put(img) +img = snp.array(kodim23(asfloat=True)[160:416, 60:316]) """ diff --git a/scico/flax/train/input_pipeline.py b/scico/flax/train/input_pipeline.py index abb966233..757ac7a5e 100644 --- a/scico/flax/train/input_pipeline.py +++ b/scico/flax/train/input_pipeline.py @@ -26,7 +26,7 @@ from .typed_dict import DataSetDict DType = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] class IterateData: diff --git a/scico/flax/train/state.py b/scico/flax/train/state.py index 21dab6ec0..9c6952cb4 100644 --- a/scico/flax/train/state.py +++ b/scico/flax/train/state.py @@ -21,7 +21,7 @@ from .typed_dict import ConfigDict, ModelVarDict ModuleDef = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any ArrayTree = optax.Params diff --git a/scico/flax/train/steps.py b/scico/flax/train/steps.py index 8b3df81ac..8901e1881 100644 --- a/scico/flax/train/steps.py +++ b/scico/flax/train/steps.py @@ -19,7 +19,7 @@ from .state import TrainState from .typed_dict import DataSetDict, MetricsDict -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any diff --git a/scico/flax/train/trainer.py b/scico/flax/train/trainer.py index 8ce144214..93cf97b84 100644 --- a/scico/flax/train/trainer.py +++ b/scico/flax/train/trainer.py @@ -47,10 +47,11 @@ from .typed_dict import ConfigDict, DataSetDict, MetricsDict, ModelVarDict ModuleDef = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any DType = Any + # sync across replicas def sync_batch_stats(state: TrainState) -> TrainState: """Sync the batch statistics across replicas.""" diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py index c14363c0d..01f8789de 100644 --- a/scico/linop/_convolve.py +++ 
b/scico/linop/_convolve.py @@ -17,13 +17,10 @@ import numpy as np -import jax -import jax.numpy as jnp from jax.dtypes import result_type from jax.scipy.signal import convolve import scico.numpy as snp -from scico.numpy.util import ensure_on_device from scico.typing import DType, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -65,7 +62,7 @@ def __init__( if h.ndim != len(input_shape): raise ValueError(f"h.ndim = {h.ndim} must equal len(input_shape) = {len(input_shape)}.") - self.h = ensure_on_device(h) + self.h = h if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.") @@ -193,12 +190,10 @@ def __init__( if x.ndim != len(input_shape): raise ValueError(f"x.ndim = {x.ndim} must equal len(input_shape) = {len(input_shape)}.") - if isinstance(x, jnp.ndarray): - self.x = x - elif isinstance(x, np.ndarray): # TODO: this should not be handled at the LinOp level - self.x = jax.device_put(x) - else: + # Ensure that x is a numpy or jax array. + if not snp.util.is_arraylike(x): raise TypeError(f"Expected numpy or jax array, got {type(x)}.") + self.x = x if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.") diff --git a/scico/linop/_diag.py b/scico/linop/_diag.py index 0e9251f57..0e14a0068 100644 --- a/scico/linop/_diag.py +++ b/scico/linop/_diag.py @@ -17,7 +17,7 @@ import scico.numpy as snp from scico.numpy import Array, BlockArray -from scico.numpy.util import broadcast_nested_shapes, ensure_on_device, is_nested +from scico.numpy.util import broadcast_nested_shapes, is_nested from scico.operator._operator import _wrap_mul_div_scalar from scico.typing import BlockShape, DType, Shape @@ -48,8 +48,7 @@ def __init__( input_dtype: `dtype` of input argument. The default, ``None``, means `diagonal.dtype`. """ - - self.diagonal = ensure_on_device(diagonal) + self.diagonal = diagonal if input_shape is None: input_shape = self.diagonal.shape diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index 951c6957e..662df0e49 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -79,10 +79,10 @@ def __init__(self, A: ArrayLike, input_cols: int = 0): """ self.A: snp.Array #: Dense array implementing this matrix - # if A is an ndarray, make sure it gets converted to a jax array + # Ensure that A is a numpy or jax array. if not snp.util.is_arraylike(A): raise TypeError(f"Expected numpy or jax array, got {type(A)}.") - self.A = jnp.array(A) + self.A = A # Can only do rank-2 arrays if A.ndim != 2: diff --git a/scico/linop/abel.py b/scico/linop/abel.py index 23c3d3d48..6aa2846ca 100644 --- a/scico/linop/abel.py +++ b/scico/linop/abel.py @@ -127,7 +127,7 @@ def _pyabel_daun_get_proj_matrix(img_shape: Shape) -> jax.Array: direction="forward", verbose=False, ) - return jax.device_put(proj_matrix) + return jnp.array(proj_matrix) # Read abel.tools.symmetry module into a string. 
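The change just above in `scico/linop/abel.py`, like the `jax.device_put` removals throughout the example scripts, standardizes on constructing jax arrays directly. A minimal sketch of the equivalence (illustrative only, not part of the patch; the array names are hypothetical):

import numpy as np
import jax.numpy as jnp
import scico.numpy as snp

x_np = np.random.randn(16, 16).astype(np.float32)  # host-side numpy array

# Old idiom removed by this patch: explicit transfer to the default device.
#   x_jax = jax.device_put(x_np)

# New idiom: array construction commits the data to the default jax
# backend (CPU or GPU) as a side effect, so no separate transfer is needed.
x_jax = jnp.array(x_np)   # plain jax
x_snp = snp.array(x_np)   # scico.numpy wrapper; equivalent here
assert isinstance(x_jax, jnp.ndarray)

A practical benefit is that scripts no longer need to import `jax` solely for device placement.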
diff --git a/scico/linop/optics.py b/scico/linop/optics.py index b508a7230..f2a09799d 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -56,8 +56,6 @@ import numpy as np from numpy.lib.scimath import sqrt # complex sqrt -import jax - from typing_extensions import TypeGuard import scico.numpy as snp @@ -289,9 +287,7 @@ def __init__( input_shape=input_shape, dx=dx, k0=k0, z=z, pad_factor=pad_factor, **kwargs ) - self.phase = jax.device_put( - np.exp(1j * z * sqrt(self.k0**2 - self.kp**2)).astype(np.complex64) - ) + self.phase = snp.exp(1j * z * sqrt(self.k0**2 - self.kp**2)).astype(np.complex64) self.D = Diagonal(self.phase) self._set_adjoint() @@ -386,9 +382,7 @@ def __init__( input_shape=input_shape, dx=dx, k0=k0, z=z, pad_factor=pad_factor, **kwargs ) - self.phase = jax.device_put( - np.exp(1j * z * (self.k0 - self.kp**2 / (2 * self.k0))).astype(np.complex64) - ) + self.phase = snp.exp(1j * z * (self.k0 - self.kp**2 / (2 * self.k0))).astype(np.complex64) self.D = Diagonal(self.phase) self._set_adjoint() diff --git a/scico/linop/xray/svmbir.py b/scico/linop/xray/svmbir.py index 2175184a5..8e757da84 100644 --- a/scico/linop/xray/svmbir.py +++ b/scico/linop/xray/svmbir.py @@ -214,7 +214,7 @@ def _proj( delta_channel: Optional[float] = None, delta_pixel: Optional[float] = None, ) -> snp.Array: - return jax.device_put( + return snp.array( svmbir.project( np.array(x), np.array(angles), @@ -264,8 +264,8 @@ def _bproj( magnification: Optional[float] = None, delta_channel: Optional[float] = None, delta_pixel: Optional[float] = None, - ): - return jax.device_put( + ) -> snp.Array: + return snp.array( svmbir.backproject( np.array(y), np.array(angles), @@ -432,7 +432,7 @@ def prox(self, v: snp.Array, lam: float = 1, **kwargs) -> snp.Array: if np.sum(np.isnan(result)): raise ValueError("Result contains NaNs.") - return jax.device_put(result.reshape(self.A.input_shape)) + return snp.array(result.reshape(self.A.input_shape)) class SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss): diff --git a/scico/loss.py b/scico/loss.py index 0e3cb9964..1a5970f1a 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -18,7 +18,7 @@ import scico.numpy as snp from scico import functional, linop, operator from scico.numpy import Array, BlockArray -from scico.numpy.util import ensure_on_device, no_nan_divide +from scico.numpy.util import no_nan_divide from scico.scipy.special import gammaln # type: ignore from scico.solver import cg @@ -67,7 +67,7 @@ def __init__( be defined in a derived class. scale: Scaling parameter. Default: 1.0. """ - self.y = ensure_on_device(y) + self.y = y if A is None: # y and x must have same shape A = linop.Identity(input_shape=self.y.shape, input_dtype=self.y.dtype) # type: ignore @@ -166,7 +166,6 @@ def __init__( W: Optional[linop.Diagonal] = None, prox_kwargs: Optional[dict] = None, ): - r""" Args: y: Measurement. @@ -175,8 +174,6 @@ def __init__( W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ - y = ensure_on_device(y) - self.W: linop.Diagonal if W is None: @@ -289,7 +286,6 @@ def __init__( `self.A` is a :class:`.Identity`. scale: Scaling parameter. Default: 0.5. """ - y = ensure_on_device(y) super().__init__(y=y, A=A, scale=scale) #: Constant term, :math:`\ln(y!)`, in Poisson log likehood. @@ -326,7 +322,6 @@ def __init__( scale: float = 0.5, W: Optional[linop.Diagonal] = None, ): - r""" Args: y: Measurement. @@ -335,8 +330,6 @@ def __init__( W: Weighting diagonal operator. Must be non-negative. 
If ``None``, defaults to :class:`.Identity`. """ - y = ensure_on_device(y) - if W is None: self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape) elif isinstance(W, linop.Diagonal): @@ -533,7 +526,6 @@ def __init__( scale: float = 0.5, W: Optional[linop.Diagonal] = None, ): - r""" Args: y: Measurement. @@ -542,8 +534,6 @@ def __init__( W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ - y = ensure_on_device(y) - if W is None: self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape) elif isinstance(W, linop.Diagonal): diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 50fefdd4e..54a1c497d 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -10,7 +10,6 @@ from __future__ import annotations -import warnings from math import prod from typing import Any, List, Optional, Tuple, Union @@ -24,45 +23,6 @@ from ._blockarray import BlockArray -def ensure_on_device( - *arrays: Union[np.ndarray, snp.Array, BlockArray] -) -> Union[snp.Array, BlockArray]: - """Cast numpy arrays to jax arrays. - - Cast numpy arrays to jax arrays and leave jax arrays and BlockArrays, - as they are. This is intended to be used when initializing optimizers - and functionals so that all arrays are either jax arrays or - BlockArrays. - - Args: - *arrays: One or more input arrays (numpy array, jax array, or - BlockArray). - - Returns: - Array or arrays, modified where appropriate. - - Raises: - TypeError: If the arrays contain anything that is neither - numpy array, jax array, nor BlockArray. - """ - arrays = list(arrays) - - for i, array in enumerate(arrays): - if isinstance(array, np.ndarray): - warnings.warn( - f"Argument {i+1} of {len(arrays)} is a numpy array. " - "Will cast it to a jax array. " - f"To suppress this warning cast all numpy arrays to jax arrays.", - stacklevel=2, - ) - - arrays[i] = jax.device_put(arrays[i]) - - if len(arrays) == 1: - return arrays[0] - return arrays - - def parse_axes( axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None ) -> List[int]: diff --git a/scico/optimize/_admm.py b/scico/optimize/_admm.py index 55862f0cb..eccc3acd1 100644 --- a/scico/optimize/_admm.py +++ b/scico/optimize/_admm.py @@ -18,7 +18,6 @@ from scico.linop import LinearOperator from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device from ._admmaux import ( FBlockCircularConvolveSolver, @@ -142,7 +141,7 @@ def __init__( input_shape = C_list[0].input_shape dtype = C_list[0].input_dtype x0 = snp.zeros(input_shape, dtype=dtype) - self.x = ensure_on_device(x0) + self.x = x0 self.z_list, self.z_list_old = self.z_init(self.x) self.u_list = self.u_init(self.x) diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 5aa742feb..8cf4eac3d 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -30,7 +30,7 @@ ) from scico.loss import SquaredL2Loss from scico.numpy import Array, BlockArray -from scico.numpy.util import ensure_on_device, is_real_dtype +from scico.numpy.util import is_real_dtype from scico.solver import ConvATADSolver, MatrixATADSolver from scico.solver import cg as scico_cg from scico.solver import minimize @@ -99,8 +99,6 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: Computed solution. 
""" - x0 = ensure_on_device(x0) - @jax.jit def obj(x): out = 0.0 @@ -260,7 +258,6 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: Returns: Computed solution. """ - x0 = ensure_on_device(x0) rhs = self.compute_rhs() x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs) # type: ignore return x diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index 3e7ef86e3..26b049311 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -18,7 +18,6 @@ from scico.linop import LinearOperator from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device from ._common import Optimizer @@ -115,7 +114,7 @@ def __init__( input_shape = C.input_shape dtype = C.input_dtype x0 = snp.zeros(input_shape, dtype=dtype) - self.x = ensure_on_device(x0) + self.x = x0 self.z, self.z_old = self.z_init(self.x) self.u = self.u_init(self.x) diff --git a/scico/optimize/_padmm.py b/scico/optimize/_padmm.py index 793abf844..ee3d5c516 100644 --- a/scico/optimize/_padmm.py +++ b/scico/optimize/_padmm.py @@ -20,7 +20,6 @@ from scico.linop import Identity, LinearOperator, operator_norm from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device from scico.typing import BlockShape, DType, PRNGKey, Shape from ._common import Optimizer @@ -104,14 +103,14 @@ def __init__( if x0 is None: x0 = snp.zeros(xshape, dtype=xdtype) - self.x = ensure_on_device(x0) + self.x = x0 if z0 is None: z0 = snp.zeros(zshape, dtype=zdtype) - self.z = ensure_on_device(z0) + self.z = z0 self.z_old = self.z if u0 is None: u0 = snp.zeros(ushape, dtype=udtype) - self.u = ensure_on_device(u0) + self.u = u0 self.u_old = self.u super().__init__(**kwargs) diff --git a/scico/optimize/_pgm.py b/scico/optimize/_pgm.py index 7e18ae691..8375c0b2d 100644 --- a/scico/optimize/_pgm.py +++ b/scico/optimize/_pgm.py @@ -19,7 +19,6 @@ from scico.functional import Functional from scico.loss import Loss from scico.numpy import Array, BlockArray -from scico.numpy.util import ensure_on_device from ._common import Optimizer from ._pgmaux import ( @@ -84,7 +83,7 @@ def x_step(v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]: self.x_step = jax.jit(x_step) - self.x: Union[Array, BlockArray] = ensure_on_device(x0) # current estimate of solution + self.x: Union[Array, BlockArray] = x0 # current estimate of solution super().__init__(**kwargs) @@ -183,7 +182,6 @@ def __init__( **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. 
""" - x0 = ensure_on_device(x0) super().__init__(f=f, g=g, L0=L0, x0=x0, step_size=step_size, **kwargs) self.v = x0 diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 60782d021..ba36331c0 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -18,7 +18,6 @@ from scico.linop import LinearOperator, jacobian, operator_norm from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device from scico.operator import Operator from scico.typing import PRNGKey @@ -132,13 +131,13 @@ def __init__( input_shape = C.input_shape dtype = C.input_dtype x0 = snp.zeros(input_shape, dtype=dtype) - self.x = ensure_on_device(x0) + self.x = x0 self.x_old = self.x if z0 is None: input_shape = C.output_shape dtype = C.output_dtype z0 = snp.zeros(input_shape, dtype=dtype) - self.z = ensure_on_device(z0) + self.z = z0 self.z_old = self.z super().__init__(**kwargs) diff --git a/scico/ray/tune.py b/scico/ray/tune.py index 65435336a..aeb74d7e8 100644 --- a/scico/ray/tune.py +++ b/scico/ray/tune.py @@ -18,6 +18,8 @@ try: import ray.tune + + os.environ["RAY_AIR_NEW_OUTPUT"] = "0" except ImportError: raise ImportError("Could not import ray.tune; please install it.") import ray.air @@ -75,7 +77,7 @@ def run( config: Optional[Dict[str, Any]] = None, hyperopt: bool = True, verbose: bool = True, - local_dir: Optional[str] = None, + storage_path: Optional[str] = None, ) -> ray.tune.ExperimentAnalysis: """Simplified wrapper for `ray.tune.run`_. @@ -109,7 +111,7 @@ def run( running, and terminated trials are indicated by "P:", "R:", and "T:" respectively, followed by the current best metric value and the parameters at which it was reported. - local_dir: Directory in which to save tuning results. Defaults to + storage_path: Directory in which to save tuning results. Defaults to a subdirectory "/ray_results" within the path returned by `tempfile.gettempdir()`, corresponding e.g. to "/tmp//ray_results" under Linux. @@ -136,12 +138,12 @@ def run( name = run_or_experiment.__name__ name += "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if local_dir is None: + if storage_path is None: try: user = getpass.getuser() except Exception: # pragma: no cover user = "NOUSER" - local_dir = os.path.join(tempfile.gettempdir(), user, "ray_results") + storage_path = os.path.join(tempfile.gettempdir(), user, "ray_results") # Record original logger.info logger_info = ray.tune.tune.logger.info @@ -160,7 +162,7 @@ def logger_info_filter(msg, *args, **kwargs): name=name, time_budget_s=time_budget_s, num_samples=num_samples, - local_dir=local_dir, + storage_path=storage_path, resources_per_trial=resources_per_trial, max_concurrent_trials=max_concurrent_trials, reuse_actors=True, @@ -193,7 +195,7 @@ def __init__( reuse_actors: bool = True, hyperopt: bool = True, verbose: bool = True, - local_dir: Optional[str] = None, + storage_path: Optional[str] = None, **kwargs, ): """ @@ -226,7 +228,7 @@ def __init__( running, and terminated trials are indicated by "P:", "R:", and "T:" respectively, followed by the current best metric value and the parameters at which it was reported. - local_dir: Directory in which to save tuning results. Defaults + storage_path: Directory in which to save tuning results. Defaults to a subdirectory "/ray_results" within the path returned by `tempfile.gettempdir()`, corresponding e.g. to "/tmp//ray_results" under Linux. 
@@ -263,15 +265,15 @@ def __init__( setattr(tune_config, k, v) name = trainable.__name__ + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if local_dir is None: + if storage_path is None: try: user = getpass.getuser() except Exception: # pragma: no cover user = "NOUSER" - local_dir = os.path.join(tempfile.gettempdir(), user, "ray_results") + storage_path = os.path.join(tempfile.gettempdir(), user, "ray_results") run_config = kwargs.pop("run_config", None) - run_config_kwargs = {"name": name, "local_dir": local_dir, "verbose": 0} + run_config_kwargs = {"name": name, "storage_path": storage_path, "verbose": 0} if verbose: run_config_kwargs.update({"verbose": 1, "progress_reporter": _CustomReporter()}) if num_iterations is not None or time_budget is not None: diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 9b4f43f8a..a48fa632b 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -89,9 +89,6 @@ def test_separable_prox(test_separable_obj): def test_separable_grad(test_separable_obj): # Used to restore the warnings after the context is used with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - # Verifies that there is a warning on f.grad and fg.grad np.testing.assert_warns(test_separable_obj.f.grad(test_separable_obj.v1)) np.testing.assert_warns(test_separable_obj.fg.grad(test_separable_obj.vb)) @@ -115,7 +112,6 @@ def __init__(self, delta=1.0): class TestNormProx: - normlist = [ functional.L0Norm, functional.L1Norm, @@ -179,7 +175,6 @@ def test_scaled_attrs(self, norm, test_prox_obj): @pytest.mark.parametrize("norm", normlist) def test_scaled_eval(self, norm, alpha, test_prox_obj): - unscaled = norm() scaled = test_prox_obj.scalar * norm() @@ -242,7 +237,6 @@ def test_proj_obj(request): class TestProj: - cnstrlist = [functional.NonNegativeIndicator, functional.L2BallIndicator] sdistlist = [functional.SetDistance, functional.SquaredSetDistance] diff --git a/scico/test/functional/test_separable.py b/scico/test/functional/test_separable.py index 0d4473ce0..0af94d2ca 100644 --- a/scico/test/functional/test_separable.py +++ b/scico/test/functional/test_separable.py @@ -54,9 +54,6 @@ def test_separable_prox(test_separable_obj): def test_separable_grad(test_separable_obj): # Used to restore the warnings after the context is used with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - # Verifies that there is a warning on f.grad and fg.grad np.testing.assert_warns(test_separable_obj.f.grad(test_separable_obj.v1)) np.testing.assert_warns(test_separable_obj.fg.grad(test_separable_obj.vb)) diff --git a/scico/test/linop/test_convolve.py b/scico/test/linop/test_convolve.py index 1979a4626..31714a9a6 100644 --- a/scico/test/linop/test_convolve.py +++ b/scico/test/linop/test_convolve.py @@ -1,10 +1,8 @@ import operator as op -import warnings import numpy as np import jax -import jax.numpy as jnp import jax.scipy.signal as signal import pytest @@ -39,7 +37,6 @@ def test_eval(self, input_shape, input_dtype, mode, jit): @pytest.mark.parametrize("mode", ["full", "valid", "same"]) @pytest.mark.parametrize("jit", [False, True]) def test_adjoint(self, input_shape, mode, jit, input_dtype): - ndim = len(input_shape) filter_shape = (3, 4)[:ndim] x, key = randn(input_shape, dtype=input_dtype, key=self.key) @@ -167,17 +164,6 @@ def 
test_dimension_mismatch(testobj): Convolve(input_shape=(16, 16), h=testobj.psf_A) -def test_ndarray_h(): - # Used to restore the warnings after the context is used - with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - - h = np.random.randn(3, 3).astype(np.float32) - A = Convolve(input_shape=(16, 16), h=h) - assert isinstance(A.h, jnp.ndarray) - - class TestConvolveByX: def setup_method(self, method): self.key = jax.random.PRNGKey(12345) @@ -204,7 +190,6 @@ def test_eval(self, input_shape, input_dtype, mode, jit): @pytest.mark.parametrize("mode", ["full", "valid", "same"]) @pytest.mark.parametrize("jit", [False, True]) def test_adjoint(self, input_shape, mode, jit, input_dtype): - ndim = len(input_shape) x_shape = (3, 4)[:ndim] x, key = randn(input_shape, dtype=input_dtype, key=self.key) @@ -330,14 +315,3 @@ def test_dimension_mismatch(cbx_testobj): with pytest.raises(ValueError): # 2-dim input shape, 1-dim xer ConvolveByX(input_shape=(16, 16), x=cbx_testobj.x_A) - - -def test_ndarray_x(): - # Used to restore the warnings after the context is used - with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - - x = np.random.randn(3, 3).astype(np.float32) - A = ConvolveByX(input_shape=(16, 16), x=x) - assert isinstance(A.x, jnp.ndarray) diff --git a/scico/test/linop/test_matrix.py b/scico/test/linop/test_matrix.py index 178c1fce5..ec7a24296 100644 --- a/scico/test/linop/test_matrix.py +++ b/scico/test/linop/test_matrix.py @@ -239,10 +239,14 @@ def test_matmul_identity(self): I = linop.Identity(input_shape=(6,)) assert Ao == Ao @ I - def test_init_devicearray(self): - A = np.random.randn(4, 6) - Ao = MatrixOperator(A) - assert isinstance(Ao.A, jnp.ndarray) + def test_init_array(self): + Am = np.random.randn(4, 6) + A = MatrixOperator(Am) + assert isinstance(A.A, np.ndarray) + + A = MatrixOperator(jnp.array(Am)) + assert isinstance(A.A, jnp.ndarray) + np.testing.assert_array_equal(A.A, jnp.array(A)) with pytest.raises(TypeError): MatrixOperator([1.0, 3.0]) diff --git a/scico/test/linop/test_stack.py b/scico/test/linop/test_stack.py index e410bd9ac..cd59b73ba 100644 --- a/scico/test/linop/test_stack.py +++ b/scico/test/linop/test_stack.py @@ -27,8 +27,8 @@ def test_construct(self, jit): H = VerticalStack([A, B], jit=jit) # in general, returns a BlockArray - A = Convolve(jax.device_put(np.ones((3, 3))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((3, 3)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], jit=jit) x = np.ones((7, 11)) y = H @ x @@ -39,8 +39,8 @@ def test_construct(self, jit): assert np.allclose(y[1], B @ x) # by default, collapse to jax array when possible - A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((2, 2)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], jit=jit) x = np.ones((7, 11)) y = H @ x @@ -51,8 +51,8 @@ def test_construct(self, jit): assert np.allclose(y[1], B @ x) # let user turn off collapsing - A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((2, 2)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse=False, jit=jit) x = np.ones((7, 11)) y = H @ x @@ -62,14 +62,14 @@ def 
diff --git a/scico/test/linop/test_matrix.py b/scico/test/linop/test_matrix.py
index 178c1fce5..ec7a24296 100644
--- a/scico/test/linop/test_matrix.py
+++ b/scico/test/linop/test_matrix.py
@@ -239,10 +239,14 @@ def test_matmul_identity(self):
         I = linop.Identity(input_shape=(6,))
         assert Ao == Ao @ I

-    def test_init_devicearray(self):
-        A = np.random.randn(4, 6)
-        Ao = MatrixOperator(A)
-        assert isinstance(Ao.A, jnp.ndarray)
+    def test_init_array(self):
+        Am = np.random.randn(4, 6)
+        A = MatrixOperator(Am)
+        assert isinstance(A.A, np.ndarray)
+
+        A = MatrixOperator(jnp.array(Am))
+        assert isinstance(A.A, jnp.ndarray)
+        np.testing.assert_array_equal(A.A, jnp.array(A))

         with pytest.raises(TypeError):
             MatrixOperator([1.0, 3.0])
diff --git a/scico/test/linop/test_stack.py b/scico/test/linop/test_stack.py
index e410bd9ac..cd59b73ba 100644
--- a/scico/test/linop/test_stack.py
+++ b/scico/test/linop/test_stack.py
@@ -27,8 +27,8 @@ def test_construct(self, jit):
         H = VerticalStack([A, B], jit=jit)

         # in general, returns a BlockArray
-        A = Convolve(jax.device_put(np.ones((3, 3))), (7, 11))
-        B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
+        A = Convolve(snp.ones((3, 3)), (7, 11))
+        B = Convolve(snp.ones((2, 2)), (7, 11))
         H = VerticalStack([A, B], jit=jit)
         x = np.ones((7, 11))
         y = H @ x
@@ -39,8 +39,8 @@ def test_construct(self, jit):
         assert np.allclose(y[1], B @ x)

         # by default, collapse to jax array when possible
-        A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
-        B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
+        A = Convolve(snp.ones((2, 2)), (7, 11))
+        B = Convolve(snp.ones((2, 2)), (7, 11))
         H = VerticalStack([A, B], jit=jit)
         x = np.ones((7, 11))
         y = H @ x
@@ -51,8 +51,8 @@ def test_construct(self, jit):
         assert np.allclose(y[1], B @ x)

         # let user turn off collapsing
-        A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
-        B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
+        A = Convolve(snp.ones((2, 2)), (7, 11))
+        B = Convolve(snp.ones((2, 2)), (7, 11))
         H = VerticalStack([A, B], collapse=False, jit=jit)
         x = np.ones((7, 11))
         y = H @ x
@@ -62,14 +62,14 @@ def test_construct(self, jit):
     @pytest.mark.parametrize("jit", [False, True])
     def test_adjoint(self, collapse, jit):
         # general case
-        A = Convolve(jax.device_put(np.ones((3, 3))), (7, 11))
-        B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
+        A = Convolve(snp.ones((3, 3)), (7, 11))
+        B = Convolve(snp.ones((2, 2)), (7, 11))
         H = VerticalStack([A, B], collapse=collapse, jit=jit)
         adjoint_test(H, self.key)

         # collapsible case
-        A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
-        B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
+        A = Convolve(snp.ones((2, 2)), (7, 11))
+        B = Convolve(snp.ones((2, 2)), (7, 11))
         H = VerticalStack([A, B], collapse=collapse, jit=jit)
         adjoint_test(H, self.key)
@@ -77,12 +77,12 @@ def test_adjoint(self, collapse, jit):
     @pytest.mark.parametrize("jit", [False, True])
     def test_algebra(self, collapse, jit):
         # adding
-        A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
-        B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11))
+        A = Convolve(snp.ones((2, 2)), (7, 11))
+        B = Convolve(snp.ones((2, 2)), (7, 11))
         H = VerticalStack([A, B], collapse=collapse, jit=jit)

-        A = Convolve(jax.device_put(np.random.rand(2, 2)), (7, 11))
-        B = Convolve(jax.device_put(np.random.rand(2, 2)), (7, 11))
+        A = Convolve(snp.array(np.random.rand(2, 2)), (7, 11))
+        B = Convolve(snp.array(np.random.rand(2, 2)), (7, 11))
         G = VerticalStack([A, B], collapse=collapse, jit=jit)

         x = np.ones((7, 11))
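The replacement `test_init_array` pins down the new `MatrixOperator` behavior: the wrapped array keeps the type it was constructed with instead of being unconditionally converted to a jax array. A minimal sketch of the contract it checks, directly mirroring the hunk above:

```python
import numpy as np

import jax.numpy as jnp

from scico.linop import MatrixOperator

Am = np.random.randn(4, 6)

A = MatrixOperator(Am)  # numpy input stays numpy
assert isinstance(A.A, np.ndarray)

A = MatrixOperator(jnp.array(Am))  # jax input stays jax
assert isinstance(A.A, jnp.ndarray)
```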
diff --git a/scico/test/linop/xray/test_svmbir.py b/scico/test/linop/xray/test_svmbir.py
index 55865a8bc..9674269c0 100644
--- a/scico/test/linop/xray/test_svmbir.py
+++ b/scico/test/linop/xray/test_svmbir.py
@@ -249,7 +249,7 @@ def test_prox_cg(
     mask = np.ones(im.shape) > 0

     W = svmbir.calc_weights(y, weight_type=weight_type).astype("float32")
-    W = jax.device_put(W)
+    W = snp.array(W)

     λ = 0.01
     if is_masked:
@@ -297,7 +297,7 @@ def test_approx_prox(
     y = A @ im

     W = svmbir.calc_weights(y, weight_type=weight_type).astype("float32")
-    W = jax.device_put(W)
+    W = snp.array(W)

     λ = 0.01
     v, _ = scico.random.normal(im.shape, dtype=im.dtype)
diff --git a/scico/test/optimize/test_admm.py b/scico/test/optimize/test_admm.py
index b246c7416..8795efb9a 100644
--- a/scico/test/optimize/test_admm.py
+++ b/scico/test/optimize/test_admm.py
@@ -1,7 +1,5 @@
 import numpy as np

-import jax
-
 import pytest

 import scico.numpy as snp
@@ -20,7 +18,7 @@ class TestMisc:
     def setup_method(self, method):
         np.random.seed(12345)
-        self.y = jax.device_put(np.random.randn(16, 17).astype(np.float32))
+        self.y = snp.array(np.random.randn(16, 17).astype(np.float32))

     def test_admm(self):
         maxiter = 2
@@ -112,14 +110,14 @@ def setup_method(self, method):
         MB = 5
         N = 6
         # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
-        Amx = np.random.randn(MA, N)
-        Bmx = np.random.randn(MB, N)
-        y = np.random.randn(MA)
+        Amx = np.random.randn(MA, N).astype(np.float32)
+        Bmx = np.random.randn(MB, N).astype(np.float32)
+        y = np.random.randn(MA).astype(np.float32)
         𝛼 = np.pi  # sort of random number chosen to test non-default scale factor
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = snp.array(y)
         self.𝛼 = 𝛼
         self.λ = λ
         # Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = 𝛼 A^T y
@@ -219,16 +217,16 @@ def setup_method(self, method):
         MB = 5
         N = 6
         # Set up arrays for problem argmin (𝛼/2) ||A x - y||_W^2 + (λ/2) ||B x||_2^2
-        Amx = np.random.randn(MA, N)
-        W = np.abs(np.random.randn(MA, 1))
-        Bmx = np.random.randn(MB, N)
-        y = np.random.randn(MA)
+        Amx = np.random.randn(MA, N).astype(np.float32)
+        W = np.abs(np.random.randn(MA, 1).astype(np.float32))
+        Bmx = np.random.randn(MB, N).astype(np.float32)
+        y = np.random.randn(MA).astype(np.float32)
         𝛼 = np.pi  # sort of random number chosen to test non-default scale factor
         λ = np.e
         self.Amx = Amx
-        self.W = jax.device_put(W)
+        self.W = snp.array(W)
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = snp.array(y)
         self.𝛼 = 𝛼
         self.λ = λ
         # Solution of problem is given by linear system
diff --git a/scico/test/optimize/test_ladmm.py b/scico/test/optimize/test_ladmm.py
index 6ef71b108..6c8a6c708 100644
--- a/scico/test/optimize/test_ladmm.py
+++ b/scico/test/optimize/test_ladmm.py
@@ -1,7 +1,5 @@
 import numpy as np

-import jax
-
 import pytest

 import scico.numpy as snp
@@ -13,7 +11,7 @@ class TestMisc:
     def setup_method(self, method):
         np.random.seed(12345)
-        self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
+        self.y = snp.array(np.random.randn(32, 33).astype(np.float32))
         self.maxiter = 2
         self.μ = 1e-1
         self.ν = 1e-1
@@ -122,7 +120,7 @@ def setup_method(self, method):
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = snp.array(y)
         self.λ = λ
         # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
         self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
@@ -161,7 +159,7 @@ def setup_method(self, method):
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = snp.array(y)
         self.λ = λ
         # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
         self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
diff --git a/scico/test/optimize/test_padmm.py b/scico/test/optimize/test_padmm.py
index a54a0125f..bd8618a22 100644
--- a/scico/test/optimize/test_padmm.py
+++ b/scico/test/optimize/test_padmm.py
@@ -1,7 +1,5 @@
 import numpy as np

-import jax
-
 import pytest

 import scico.numpy as snp
@@ -13,7 +11,7 @@ class TestMisc:
     def setup_method(self, method):
         np.random.seed(12345)
-        self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
+        self.y = snp.array(np.random.randn(32, 33).astype(np.float32))
         self.maxiter = 2
         self.ρ = 1e0
         self.μ = 1e0
@@ -199,7 +197,7 @@ def setup_method(self, method):
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = snp.array(y)
         self.λ = λ
         # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
         self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
@@ -267,7 +265,7 @@ def setup_method(self, method):
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = snp.array(y)
         self.λ = λ
         # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
         self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
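One recurring detail in these setup methods is the new `.astype(np.float32)` casts on the raw `numpy` arrays. The patch itself gives no rationale, but a plausible one (an assumption, not stated in the source) is that jax arrays default to single precision, so reference quantities built from `float64` host arrays would otherwise be compared against solver output at mismatched precision. The idiom is simply:

```python
import numpy as np

import scico.numpy as snp

# np.random.randn returns float64; cast before wrapping as a jax array so
# host-side reference computations use the same precision as the solver.
y = snp.array(np.random.randn(16, 17).astype(np.float32))
```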
diff --git a/scico/test/optimize/test_pdhg.py b/scico/test/optimize/test_pdhg.py
index 62391bd70..61a9353d0 100644
--- a/scico/test/optimize/test_pdhg.py
+++ b/scico/test/optimize/test_pdhg.py
@@ -1,7 +1,5 @@
 import numpy as np

-import jax
-
 import pytest

 import scico.numpy as snp
@@ -13,7 +11,7 @@ class TestMisc:
     def setup_method(self, method):
         np.random.seed(12345)
-        self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32))
+        self.y = snp.array(np.random.randn(32, 33).astype(np.float32))
         self.maxiter = 2
         self.τ = 1e-1
         self.σ = 1e-1
@@ -128,7 +126,7 @@ def setup_method(self, method):
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = snp.array(y)
         self.λ = λ
         # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
         self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
@@ -189,7 +187,7 @@ def setup_method(self, method):
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = snp.array(y)
         self.λ = λ
         # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
         self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
diff --git a/scico/test/optimize/test_pgm.py b/scico/test/optimize/test_pgm.py
index 63116381f..1231a826f 100644
--- a/scico/test/optimize/test_pgm.py
+++ b/scico/test/optimize/test_pgm.py
@@ -4,6 +4,7 @@

 import pytest

+import scico.numpy as snp
 from scico import functional, linop, loss, random
 from scico.optimize import PGM, AcceleratedPGM
 from scico.optimize.pgm import (
@@ -20,9 +21,9 @@ def setup_method(self, method):
         M = 5
         N = 4
         # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
-        Amx = np.random.randn(M, N)
+        Amx = np.random.randn(M, N).astype(np.float32)
         Bmx = np.identity(N)
-        y = jax.device_put(np.random.randn(M))
+        y = snp.array(np.random.randn(M).astype(np.float32))
         λ = 1e0
         self.Amx = Amx
         self.y = y
@@ -196,7 +197,7 @@ def setup_method(self, method):
         # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||x||_2^2
         Amx, key = random.randn((M, N), dtype=np.complex64, key=None)
         Bmx = np.identity(N)
-        y = jax.device_put(np.random.randn(M))
+        y = snp.array(np.random.randn(M))
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
diff --git a/scico/test/test_numpy_util.py b/scico/test/test_numpy_util.py
index be1c9f5ab..faab01dc5 100644
--- a/scico/test/test_numpy_util.py
+++ b/scico/test/test_numpy_util.py
@@ -1,16 +1,10 @@
-import warnings
-
 import numpy as np

-import jax.numpy as jnp
-
 import pytest

 import scico.numpy as snp
-from scico.numpy import BlockArray
 from scico.numpy.util import (
     complex_dtype,
-    ensure_on_device,
     indexed_shape,
     is_complex_dtype,
     is_nested,
@@ -24,31 +18,6 @@
 from scico.random import randn


-def test_ensure_on_device():
-    # Used to restore the warnings after the context is used
-    with warnings.catch_warnings():
-        # Ignores warning raised by ensure_on_device
-        warnings.filterwarnings(action="ignore", category=UserWarning)
-
-        NP = np.ones(2)
-        SNP = snp.ones(2)
-        BA = snp.blockarray([NP, SNP])
-        NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA)
-
-        assert isinstance(NP_, jnp.ndarray)
-
-        assert isinstance(SNP_, jnp.ndarray)
-        assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer()
-
-        assert isinstance(BA_, BlockArray)
-        assert isinstance(BA_[0], jnp.ndarray)
-        assert isinstance(BA_[1], jnp.ndarray)
-        assert BA[1].unsafe_buffer_pointer() == BA_[1].unsafe_buffer_pointer()
-
-        NP_ = ensure_on_device(NP)
-        assert isinstance(NP_, jnp.ndarray)
-
-
 def test_no_nan_divide_array():
     x, key = randn((4,), dtype=np.float32)
     y, key = randn(x.shape, dtype=np.float32, key=key)
diff --git a/scico/test/test_ray_tune.py b/scico/test/test_ray_tune.py
index dde5b1d37..8b33f47ef 100644
--- a/scico/test/test_ray_tune.py
+++ b/scico/test/test_ray_tune.py
@@ -7,7 +7,7 @@
 try:
     import ray

-    from scico.ray import train, tune
+    from scico.ray import report, tune

     ray.init(num_cpus=1)
 except ImportError as e:
@@ -18,7 +18,7 @@ def test_random_run():
     def eval_params(config):
         x, y = config["x"], config["y"]
         cost = x**2 + (y - 0.5) ** 2
-        train.report({"cost": cost})
+        report({"cost": cost})

     config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
     resources = {"gpu": 0, "cpu": 1}
@@ -32,7 +32,7 @@ def eval_params(config):
         resources_per_trial=resources,
         hyperopt=False,
         verbose=False,
-        local_dir=os.path.join(tempfile.gettempdir(), "ray_test"),
+        storage_path=os.path.join(tempfile.gettempdir(), "ray_test"),
     )
     best_config = analysis.get_best_config(metric="cost", mode="min")
     assert np.abs(best_config["x"]) < 0.25
@@ -43,7 +43,7 @@ def test_random_tune():
     def eval_params(config):
         x, y = config["x"], config["y"]
         cost = x**2 + (y - 0.5) ** 2
-        train.report({"cost": cost})
+        report({"cost": cost})

     config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
     resources = {"gpu": 0, "cpu": 1}
@@ -56,7 +56,7 @@ def eval_params(config):
         num_samples=100,
         hyperopt=False,
         verbose=False,
-        local_dir=os.path.join(tempfile.gettempdir(), "ray_test"),
+        storage_path=os.path.join(tempfile.gettempdir(), "ray_test"),
     )
     results = tuner.fit()
     best_config = results.get_best_result().config
@@ -68,7 +68,7 @@ def test_hyperopt_run():
     def eval_params(config):
         x, y = config["x"], config["y"]
         cost = x**2 + (y - 0.5) ** 2
-        train.report({"cost": cost})
+        report({"cost": cost})

     config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
     resources = {"gpu": 0, "cpu": 1}
@@ -91,7 +91,7 @@ def test_hyperopt_tune():
     def eval_params(config):
         x, y = config["x"], config["y"]
         cost = x**2 + (y - 0.5) ** 2
-        train.report({"cost": cost})
+        report({"cost": cost})

     config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
     resources = {"gpu": 0, "cpu": 1}
@@ -115,7 +115,7 @@ def test_hyperopt_tune_alt_init():
     def eval_params(config):
         x, y = config["x"], config["y"]
         cost = x**2 + (y - 0.5) ** 2
-        train.report({"cost": cost})
+        report({"cost": cost})

     config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}

     tuner = tune.Tuner(
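The reporting change exercised throughout `test_ray_tune.py` is a rename at the import level; the evaluation function body is otherwise unchanged. A minimal sketch of the old and new forms (old form shown as comments, both taken from the hunks above):

```python
# Old interface:
#   from scico.ray import train
#   ...
#   train.report({"cost": cost})

# New interface:
from scico.ray import report


def eval_params(config):
    x, y = config["x"], config["y"]
    cost = x**2 + (y - 0.5) ** 2
    report({"cost": cost})
```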
diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py
index f220482df..ebbcbf9c5 100644
--- a/scico/test/test_solver.py
+++ b/scico/test/test_solver.py
@@ -15,8 +15,8 @@ def setup_method(self, method):
     def test_wrap_func_and_grad(self):
         N = 8

-        A = jax.device_put(np.random.randn(N, N))
-        x = jax.device_put(np.random.randn(N))
+        A = snp.array(np.random.randn(N, N))
+        x = snp.array(np.random.randn(N))

         f = lambda x: 0.5 * snp.linalg.norm(A @ x) ** 2
@@ -117,10 +117,10 @@ def test_preconditioned_cg(self):
     def test_lstsq_func(self):
         N = 24
         M = 32
-        Ac = jax.device_put(np.random.randn(N, M).astype(np.float32))
+        Ac = snp.array(np.random.randn(N, M).astype(np.float32))
         Am = Ac.dot(Ac.T)
         A = Am.dot
-        x = jax.device_put(np.random.randn(N).astype(np.float32))
+        x = snp.array(np.random.randn(N).astype(np.float32))
         b = Am.dot(x)
         x0 = snp.zeros((N,), dtype=np.float32)
         tol = 1e-6
@@ -134,9 +134,9 @@ def test_lstsq_func(self):
     def test_lstsq_op(self):
         N = 32
         M = 24
-        Ac = jax.device_put(np.random.randn(N, M).astype(np.float32))
+        Ac = snp.array(np.random.randn(N, M).astype(np.float32))
         A = linop.MatrixOperator(Ac)
-        x = jax.device_put(np.random.randn(M).astype(np.float32))
+        x = snp.array(np.random.randn(M).astype(np.float32))
         b = Ac.dot(x)
         tol = 1e-7
         try: