From aadb4ada0683af06dd18ba1f9841237d4499c81e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 30 Oct 2023 15:45:32 -0600 Subject: [PATCH 1/3] Address DeprecationWarning: jax.random.KeyArray is deprecated. Use jax.Array for annotations (#455) --- scico/flax/train/input_pipeline.py | 2 +- scico/flax/train/state.py | 2 +- scico/flax/train/steps.py | 2 +- scico/flax/train/trainer.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/scico/flax/train/input_pipeline.py b/scico/flax/train/input_pipeline.py index abb966233..757ac7a5e 100644 --- a/scico/flax/train/input_pipeline.py +++ b/scico/flax/train/input_pipeline.py @@ -26,7 +26,7 @@ from .typed_dict import DataSetDict DType = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] class IterateData: diff --git a/scico/flax/train/state.py b/scico/flax/train/state.py index 21dab6ec0..9c6952cb4 100644 --- a/scico/flax/train/state.py +++ b/scico/flax/train/state.py @@ -21,7 +21,7 @@ from .typed_dict import ConfigDict, ModelVarDict ModuleDef = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any ArrayTree = optax.Params diff --git a/scico/flax/train/steps.py b/scico/flax/train/steps.py index 8b3df81ac..8901e1881 100644 --- a/scico/flax/train/steps.py +++ b/scico/flax/train/steps.py @@ -19,7 +19,7 @@ from .state import TrainState from .typed_dict import DataSetDict, MetricsDict -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any diff --git a/scico/flax/train/trainer.py b/scico/flax/train/trainer.py index 8ce144214..93cf97b84 100644 --- a/scico/flax/train/trainer.py +++ b/scico/flax/train/trainer.py @@ -47,10 +47,11 @@ from .typed_dict import ConfigDict, DataSetDict, MetricsDict, ModelVarDict ModuleDef = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any DType = Any + # sync across replicas def sync_batch_stats(state: TrainState) -> TrainState: """Sync the batch statistics across replicas.""" From d0e6c4c5d7a9d938733b3c548fdbe614a50ebaae Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 2 Nov 2023 12:31:05 -0600 Subject: [PATCH 2/3] Updates required by recent changes in `ray` (#462) * Resolve ray.tune deprecation warnings * Resolve ray-project/ray#38202 * Change ray version requirements * Fix tests --- examples/examples_requirements.txt | 2 +- scico/ray/tune.py | 22 ++++++++++++---------- scico/test/test_ray_tune.py | 16 ++++++++-------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/examples/examples_requirements.txt b/examples/examples_requirements.txt index 60b77e7e7..ff7a1dcfb 100644 --- a/examples/examples_requirements.txt +++ b/examples/examples_requirements.txt @@ -1,7 +1,7 @@ astra-toolbox colour_demosaicing xdesign>=0.5.5 -ray[tune]>=2.0.0 +ray[tune,train]>=2.5.0 hyperopt bm3d>=4.0.0 bm4d>=4.2.2 diff --git a/scico/ray/tune.py b/scico/ray/tune.py index 65435336a..aeb74d7e8 100644 --- a/scico/ray/tune.py +++ b/scico/ray/tune.py @@ -18,6 +18,8 @@ try: import ray.tune + + os.environ["RAY_AIR_NEW_OUTPUT"] = "0" except ImportError: raise ImportError("Could not import ray.tune; please install it.") import ray.air @@ -75,7 +77,7 @@ def run( config: Optional[Dict[str, Any]] = None, hyperopt: bool = True, verbose: bool = True, - local_dir: Optional[str] = None, + storage_path: Optional[str] = None, ) -> ray.tune.ExperimentAnalysis: """Simplified wrapper for `ray.tune.run`_. @@ -109,7 +111,7 @@ def run( running, and terminated trials are indicated by "P:", "R:", and "T:" respectively, followed by the current best metric value and the parameters at which it was reported. - local_dir: Directory in which to save tuning results. Defaults to + storage_path: Directory in which to save tuning results. Defaults to a subdirectory "/ray_results" within the path returned by `tempfile.gettempdir()`, corresponding e.g. to "/tmp//ray_results" under Linux. @@ -136,12 +138,12 @@ def run( name = run_or_experiment.__name__ name += "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if local_dir is None: + if storage_path is None: try: user = getpass.getuser() except Exception: # pragma: no cover user = "NOUSER" - local_dir = os.path.join(tempfile.gettempdir(), user, "ray_results") + storage_path = os.path.join(tempfile.gettempdir(), user, "ray_results") # Record original logger.info logger_info = ray.tune.tune.logger.info @@ -160,7 +162,7 @@ def logger_info_filter(msg, *args, **kwargs): name=name, time_budget_s=time_budget_s, num_samples=num_samples, - local_dir=local_dir, + storage_path=storage_path, resources_per_trial=resources_per_trial, max_concurrent_trials=max_concurrent_trials, reuse_actors=True, @@ -193,7 +195,7 @@ def __init__( reuse_actors: bool = True, hyperopt: bool = True, verbose: bool = True, - local_dir: Optional[str] = None, + storage_path: Optional[str] = None, **kwargs, ): """ @@ -226,7 +228,7 @@ def __init__( running, and terminated trials are indicated by "P:", "R:", and "T:" respectively, followed by the current best metric value and the parameters at which it was reported. - local_dir: Directory in which to save tuning results. Defaults + storage_path: Directory in which to save tuning results. Defaults to a subdirectory "/ray_results" within the path returned by `tempfile.gettempdir()`, corresponding e.g. to "/tmp//ray_results" under Linux. @@ -263,15 +265,15 @@ def __init__( setattr(tune_config, k, v) name = trainable.__name__ + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if local_dir is None: + if storage_path is None: try: user = getpass.getuser() except Exception: # pragma: no cover user = "NOUSER" - local_dir = os.path.join(tempfile.gettempdir(), user, "ray_results") + storage_path = os.path.join(tempfile.gettempdir(), user, "ray_results") run_config = kwargs.pop("run_config", None) - run_config_kwargs = {"name": name, "local_dir": local_dir, "verbose": 0} + run_config_kwargs = {"name": name, "storage_path": storage_path, "verbose": 0} if verbose: run_config_kwargs.update({"verbose": 1, "progress_reporter": _CustomReporter()}) if num_iterations is not None or time_budget is not None: diff --git a/scico/test/test_ray_tune.py b/scico/test/test_ray_tune.py index dde5b1d37..8b33f47ef 100644 --- a/scico/test/test_ray_tune.py +++ b/scico/test/test_ray_tune.py @@ -7,7 +7,7 @@ try: import ray - from scico.ray import train, tune + from scico.ray import report, tune ray.init(num_cpus=1) except ImportError as e: @@ -18,7 +18,7 @@ def test_random_run(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - train.report({"cost": cost}) + report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -32,7 +32,7 @@ def eval_params(config): resources_per_trial=resources, hyperopt=False, verbose=False, - local_dir=os.path.join(tempfile.gettempdir(), "ray_test"), + storage_path=os.path.join(tempfile.gettempdir(), "ray_test"), ) best_config = analysis.get_best_config(metric="cost", mode="min") assert np.abs(best_config["x"]) < 0.25 @@ -43,7 +43,7 @@ def test_random_tune(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - train.report({"cost": cost}) + report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -56,7 +56,7 @@ def eval_params(config): num_samples=100, hyperopt=False, verbose=False, - local_dir=os.path.join(tempfile.gettempdir(), "ray_test"), + storage_path=os.path.join(tempfile.gettempdir(), "ray_test"), ) results = tuner.fit() best_config = results.get_best_result().config @@ -68,7 +68,7 @@ def test_hyperopt_run(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - train.report({"cost": cost}) + report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -91,7 +91,7 @@ def test_hyperopt_tune(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - train.report({"cost": cost}) + report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -115,7 +115,7 @@ def test_hyperopt_tune_alt_init(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - train.report({"cost": cost}) + report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} tuner = tune.Tuner( From bcb1eabe6ea0c4c4cbe58201075e7c57eded9b8d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 2 Nov 2023 13:20:35 -0600 Subject: [PATCH 3/3] Remove `ensure_on_device` function (#463) * Remove jax.device_put calls and some clean up * Remove ensure_on_device calls * Minor edit * Update submodule * Remove unnecessary forced conversion to jax array * Delete ensure_on_device function * Fix compatibility issue with older jax versions * Update submodule * Clean up * Remove jax.device_put calls * Address PR review comment * Update submodule --- data | 2 +- examples/scripts/ct_abel_tv_admm.py | 2 +- examples/scripts/ct_abel_tv_admm_tune.py | 21 +++++----- examples/scripts/ct_astra_3d_tv_admm.py | 11 +++-- examples/scripts/ct_astra_modl_train_foam2.py | 2 +- examples/scripts/ct_astra_noreg_pcg.py | 3 +- examples/scripts/ct_astra_tv_admm.py | 6 +-- examples/scripts/ct_astra_weighted_tv_admm.py | 8 ++-- .../ct_fan_svmbir_ppp_bm3d_admm_prox.py | 18 ++++----- examples/scripts/ct_projector_comparison.py | 36 ++++++++--------- .../scripts/ct_svmbir_ppp_bm3d_admm_cg.py | 6 +-- .../scripts/ct_svmbir_ppp_bm3d_admm_prox.py | 16 ++++---- examples/scripts/ct_svmbir_tv_multi.py | 9 ++--- examples/scripts/deconv_circ_tv_admm.py | 3 -- .../deconv_microscopy_allchn_tv_admm.py | 8 ++-- examples/scripts/deconv_microscopy_tv_admm.py | 1 - examples/scripts/deconv_modl_train_foam1.py | 2 +- examples/scripts/deconv_odp_train_foam1.py | 8 ++-- examples/scripts/deconv_ppp_bm3d_admm.py | 4 +- examples/scripts/deconv_ppp_bm3d_pgm.py | 4 +- examples/scripts/deconv_ppp_bm4d_admm.py | 4 +- examples/scripts/deconv_ppp_dncnn_admm.py | 4 +- examples/scripts/deconv_ppp_dncnn_padmm.py | 4 +- examples/scripts/deconv_tv_admm.py | 2 - examples/scripts/deconv_tv_admm_tune.py | 3 -- examples/scripts/deconv_tv_padmm.py | 4 +- examples/scripts/demosaic_ppp_bm3d_admm.py | 7 +--- examples/scripts/denoise_dncnn_train_bsds.py | 1 + examples/scripts/denoise_dncnn_universal.py | 6 +-- examples/scripts/denoise_l1tv_admm.py | 2 - examples/scripts/denoise_tv_admm.py | 2 - examples/scripts/denoise_tv_multi.py | 2 - examples/scripts/denoise_tv_pgm.py | 8 ---- examples/scripts/diffusercam_tv_admm.py | 6 +-- examples/scripts/sparsecode_admm.py | 7 ++-- examples/scripts/sparsecode_conv_admm.py | 6 +-- examples/scripts/sparsecode_conv_md_admm.py | 6 +-- examples/scripts/sparsecode_pgm.py | 7 ++-- examples/scripts/sparsecode_poisson_pgm.py | 11 ++--- examples/scripts/superres_ppp_dncnn_admm.py | 4 +- scico/linop/_convolve.py | 13 ++---- scico/linop/_diag.py | 5 +-- scico/linop/_matrix.py | 4 +- scico/linop/abel.py | 2 +- scico/linop/optics.py | 10 +---- scico/linop/radon_svmbir.py | 8 ++-- scico/loss.py | 14 +------ scico/numpy/util.py | 40 ------------------- scico/optimize/_admm.py | 3 +- scico/optimize/_admmaux.py | 5 +-- scico/optimize/_ladmm.py | 3 +- scico/optimize/_padmm.py | 7 ++-- scico/optimize/_pgm.py | 4 +- scico/optimize/_primaldual.py | 5 +-- scico/test/functional/test_core.py | 6 --- scico/test/functional/test_separable.py | 3 -- scico/test/linop/test_convolve.py | 26 ------------ scico/test/linop/test_matrix.py | 12 ++++-- scico/test/linop/test_radon_svmbir.py | 4 +- scico/test/linop/test_stack.py | 28 ++++++------- scico/test/optimize/test_admm.py | 24 +++++------ scico/test/optimize/test_ladmm.py | 8 ++-- scico/test/optimize/test_padmm.py | 8 ++-- scico/test/optimize/test_pdhg.py | 8 ++-- scico/test/optimize/test_pgm.py | 7 ++-- scico/test/test_numpy_util.py | 31 -------------- scico/test/test_solver.py | 12 +++--- 67 files changed, 179 insertions(+), 387 deletions(-) diff --git a/data b/data index 1f1e9f83b..b63329c3b 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 1f1e9f83bb52bf9a08115ab71d8bb32a05c4ff0c +Subproject commit b63329c3b1b89fbebc4cb3ec892badee0b989e40 diff --git a/examples/scripts/ct_abel_tv_admm.py b/examples/scripts/ct_abel_tv_admm.py index 89014cb30..97ca30169 100644 --- a/examples/scripts/ct_abel_tv_admm.py +++ b/examples/scripts/ct_abel_tv_admm.py @@ -57,7 +57,7 @@ better performance than isotropic TV for this problem, is used here. """ f = loss.SquaredL2Loss(y=y, A=A) -λ = 2.35e1 # L1 norm regularization parameter +λ = 2.35e1 # ℓ1 norm regularization parameter g = λ * functional.L1Norm() # Note the use of anisotropic TV C = linop.FiniteDifference(input_shape=x_gt.shape) diff --git a/examples/scripts/ct_abel_tv_admm_tune.py b/examples/scripts/ct_abel_tv_admm_tune.py index db29029b1..ab7ffd18f 100644 --- a/examples/scripts/ct_abel_tv_admm_tune.py +++ b/examples/scripts/ct_abel_tv_admm_tune.py @@ -14,14 +14,14 @@ `ray.tune` class API is used in this example. This script is hard-coded to run on CPU only to avoid the large number of -warnings that are emitted when GPU resources are requested but not available, -and due to the difficulty of supressing these warnings in a way that does -not force use of the CPU only. To enable GPU usage, comment out the -`os.environ` statements near the beginning of the script, and change the -value of the "gpu" entry in the `resources` dict from 0 to 1. Note that -two environment variables are set to suppress the warnings because -`JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but this change -has yet to be correctly implemented +warnings that are emitted when GPU resources are requested but not +available, and due to the difficulty of supressing these warnings in a +way that does not force use of the CPU only. To enable GPU usage, comment +out the `os.environ` statements near the beginning of the script, and +change the value of the "gpu" entry in the `resources` dict from 0 to 1. +Note that two environment variables are set to suppress the warnings +because `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but +this change has yet to be correctly implemented (see [google/jax#6805](https://github.com/google/jax/issues/6805) and [google/jax#10272](https://github.com/google/jax/pull/10272). """ @@ -34,7 +34,6 @@ import numpy as np -import jax import scico.numpy as snp from scico import functional, linop, loss, metric, plot @@ -82,8 +81,8 @@ def setup(self, config, x_gt, x0, y): this case). The remaining parameters are objects that are passed to the evaluation function via the ray object store. """ - # Put main arrays on jax device. - self.x_gt, self.x0, self.y = jax.device_put([x_gt, x0, y]) + # Get arrays passed by tune call. + self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y) # Set up problem to be solved. self.A = AbelProjector(self.x_gt.shape) self.f = loss.SquaredL2Loss(y=self.y, A=self.A) diff --git a/examples/scripts/ct_astra_3d_tv_admm.py b/examples/scripts/ct_astra_3d_tv_admm.py index c37647d32..3abb9ae89 100644 --- a/examples/scripts/ct_astra_3d_tv_admm.py +++ b/examples/scripts/ct_astra_3d_tv_admm.py @@ -23,10 +23,9 @@ import numpy as np -import jax - from mpl_toolkits.axes_grid1 import make_axes_locatable +import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_tangle_phantom from scico.linop.radon_astra import TomographicProjector @@ -36,13 +35,11 @@ """ Create a ground truth image and projector. """ - Nx = 128 Ny = 256 Nz = 64 -tangle = create_tangle_phantom(Nx, Ny, Nz) -tangle = jax.device_put(tangle) +tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) n_projection = 10 # number of projections angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles @@ -55,7 +52,7 @@ """ Set up ADMM solver object. """ -λ = 2e0 # L1 norm regularization parameter +λ = 2e0 # ℓ2,1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance @@ -82,6 +79,7 @@ itstat_options={"display": True, "period": 5}, ) + """ Run the solver. """ @@ -95,6 +93,7 @@ % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)) ) + """ Show the recovered image. """ diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_astra_modl_train_foam2.py index 6d2b2f4ea..4214888ab 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_astra_modl_train_foam2.py @@ -168,7 +168,7 @@ stats_object_ini = None checkpoint_files = [] -for (dirpath, dirnames, filenames) in os.walk(workdir2): +for dirpath, dirnames, filenames in os.walk(workdir2): checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"] if len(checkpoint_files) > 0: diff --git a/examples/scripts/ct_astra_noreg_pcg.py b/examples/scripts/ct_astra_noreg_pcg.py index b0b7e2372..fc5dd6f08 100644 --- a/examples/scripts/ct_astra_noreg_pcg.py +++ b/examples/scripts/ct_astra_noreg_pcg.py @@ -23,7 +23,6 @@ import numpy as np -import jax import jax.numpy as jnp from xdesign import Foam, discrete_phantom @@ -38,7 +37,7 @@ """ N = 256 # phantom size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU +x_gt = jnp.array(x_gt) # convert to jax type """ diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py index dc694eca6..1f12f7ab3 100644 --- a/examples/scripts/ct_astra_tv_admm.py +++ b/examples/scripts/ct_astra_tv_admm.py @@ -21,8 +21,6 @@ import numpy as np -import jax - from mpl_toolkits.axes_grid1 import make_axes_locatable from xdesign import Foam, discrete_phantom @@ -38,7 +36,7 @@ N = 512 # phantom size np.random.seed(1234) x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU +x_gt = snp.array(x_gt) # convert to jax type """ @@ -53,7 +51,7 @@ """ Set up ADMM solver object. """ -λ = 2e0 # L1 norm regularization parameter +λ = 2e0 # ℓ1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py index abae983ce..3f14b828d 100644 --- a/examples/scripts/ct_astra_weighted_tv_admm.py +++ b/examples/scripts/ct_astra_weighted_tv_admm.py @@ -23,8 +23,6 @@ import numpy as np -import jax - from xdesign import Soil, discrete_phantom import scico.numpy as snp @@ -42,7 +40,7 @@ x_gt = discrete_phantom(Soil(porosity=0.80), size=384) x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64))) x_gt = np.clip(x_gt, 0, np.inf) # clip to positive values -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU +x_gt = snp.array(x_gt) # convert to jax type """ @@ -72,7 +70,7 @@ counts = np.random.poisson(Io * snp.exp(-𝛼 * A @ x_gt)) counts = np.clip(counts, a_min=1, a_max=np.inf) # replace any 0s count with 1 y = -1 / 𝛼 * np.log(counts / Io) -y = jax.device_put(y) # convert back to float32 +y = snp.array(y) # convert back to float32 as a jax array """ @@ -140,7 +138,7 @@ def postprocess(x): """ lambda_weighted = 5e1 -weights = jax.device_put(counts / Io) +weights = snp.array(counts / Io) f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights)) admm_weighted = ADMM( diff --git a/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py index 24ace60ec..1e334ada5 100644 --- a/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py @@ -26,8 +26,6 @@ import numpy as np -import jax - import matplotlib.pyplot as plt import svmbir from matplotlib.ticker import MaxNLocator @@ -137,10 +135,12 @@ def add_poisson_noise(sino, max_intensity): """ -Push arrays to device. +Convert numpy arrays to jax arrays. """ -y_fan, x0_fan, weights_fan = jax.device_put([y_fan, x_mrf_fan, weights_fan]) -x0_parallel = jax.device_put(x_mrf_parallel) +y_fan = snp.array(y_fan) +x0_fan = snp.array(x_mrf_fan) +weights_fan = snp.array(weights_fan) +x0_parallel = snp.array(x_mrf_parallel) """ @@ -179,7 +179,7 @@ def add_poisson_noise(sino, max_intensity): x0=x0_fan, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), - itstat_options={"display": True}, + itstat_options={"display": True, "period": 5}, ) solver_extloss_parallel = ADMM( f=None, @@ -189,7 +189,7 @@ def add_poisson_noise(sino, max_intensity): x0=x0_parallel, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), - itstat_options={"display": True}, + itstat_options={"display": True, "period": 5}, ) @@ -267,7 +267,7 @@ def add_poisson_noise(sino, max_intensity): fig=fig, ax=ax[0], ) -ax[0].set_ylim([5e-3, 1e0]) +ax[0].set_ylim([5e-3, 5e0]) ax[0].xaxis.set_major_locator(MaxNLocator(integer=True)) plot.plot( snp.vstack((hist_extloss_fan.Prml_Rsdl, hist_extloss_fan.Dual_Rsdl)).T, @@ -278,7 +278,7 @@ def add_poisson_noise(sino, max_intensity): fig=fig, ax=ax[1], ) -ax[1].set_ylim([5e-3, 1e0]) +ax[1].set_ylim([5e-3, 5e0]) ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) fig.show() diff --git a/examples/scripts/ct_projector_comparison.py b/examples/scripts/ct_projector_comparison.py index ab8c43cfa..94b8d1d2c 100644 --- a/examples/scripts/ct_projector_comparison.py +++ b/examples/scripts/ct_projector_comparison.py @@ -9,8 +9,8 @@ X-ray Projector Comparison ========================== -This example compares SCICO's native X-ray projection algorithm -to that of the ASTRA Toolbox. +This example compares SCICO's native X-ray projection algorithm to that +of the ASTRA Toolbox. """ import numpy as np @@ -30,12 +30,10 @@ """ N = 512 - - det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) - x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) +x_gt = jnp.array(x_gt) + """ Time projector instantiation. @@ -44,7 +42,6 @@ num_angles = 500 angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) - timer = Timer() projectors = {} @@ -58,10 +55,10 @@ ) timer.stop("astra_init") + """ Time first projector application, which might include JIT overhead. """ - ys = {} for name, H in projectors.items(): timer_label = f"{name}_first_proj" @@ -74,7 +71,6 @@ """ Compute average time for a projector application. """ - num_repeats = 3 for name, H in projectors.items(): timer_label = f"{name}_avg_proj" @@ -85,6 +81,7 @@ timer.stop(timer_label) timer.td[timer_label] /= num_repeats + """ Display timing results. @@ -93,7 +90,7 @@ 10% slower when both are run the CPU. On our server, using the GPU: -``` +``` Label Accum. Current ------------------------------------------- astra_avg_proj 4.62e-02 s Stopped @@ -119,10 +116,10 @@ print(timer) + """ Show projections. """ - fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0]) plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1]) @@ -136,7 +133,7 @@ y = np.zeros(H.output_shape, dtype=np.float32) y[num_angles // 3, det_count // 2] = 1.0 -y = jax.device_put(y) +y = jnp.array(y) HTys = {} for name, H in projectors.items(): @@ -160,14 +157,14 @@ timer.stop(timer_label) timer.td[timer_label] /= num_repeats + """ Display back projection timing results. -On our server, the SCICO back projection is slow -the first time it is run, probably due to JIT overhead. -After the first run, it is an order of magnitude -faster than ASTRA when both are run on the GPU, -and about three times faster when both are run on the CPU. +On our server, the SCICO back projection is slow the first time it is +run, probably due to JIT overhead. After the first run, it is an order of +magnitude faster than ASTRA when both are run on the GPU, and about three +times faster when both are run on the CPU. On our server, using the GPU: ``` @@ -192,11 +189,10 @@ print(timer) + """ -Show back projections of a single detector element, -i.e., a line. +Show back projections of a single detector element, i.e., a line. """ - fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0]) plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1]) diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py index f8b06dfc1..d4b2e6050 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py @@ -23,8 +23,6 @@ import numpy as np -import jax - import matplotlib.pyplot as plt import svmbir from xdesign import Foam, discrete_phantom @@ -88,7 +86,9 @@ """ Set up an ADMM solver. """ -y, x0, weights = jax.device_put([y, x_mrf, weights]) +y = snp.array(y) +x0 = snp.array(x_mrf) +weights = snp.array(weights) ρ = 15 # ADMM penalty parameter σ = density * 0.18 # denoiser sigma diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py index 584ff4bc0..787709b86 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py @@ -32,8 +32,6 @@ import numpy as np -import jax - import matplotlib.pyplot as plt import svmbir from matplotlib.ticker import MaxNLocator @@ -100,9 +98,11 @@ """ -Push arrays to device. +Convert numpy arrays to jax arrays. """ -y, x0, weights = jax.device_put([y, x_mrf, weights]) +y = snp.array(y) +x0 = snp.array(x_mrf) +weights = snp.array(weights) """ @@ -129,7 +129,7 @@ x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), - itstat_options={"display": True}, + itstat_options={"display": True, "period": 5}, ) @@ -161,7 +161,7 @@ x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), - itstat_options={"display": True}, + itstat_options={"display": True, "period": 5}, ) @@ -219,7 +219,7 @@ fig=fig, ax=ax[0], ) -ax[0].set_ylim([5e-3, 1e0]) +ax[0].set_ylim([5e-3, 5e0]) ax[0].xaxis.set_major_locator(MaxNLocator(integer=True)) plot.plot( snp.vstack((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T, @@ -230,7 +230,7 @@ fig=fig, ax=ax[1], ) -ax[1].set_ylim([5e-3, 1e0]) +ax[1].set_ylim([5e-3, 5e0]) ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) fig.show() diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py index 83627acfd..06d99696d 100644 --- a/examples/scripts/ct_svmbir_tv_multi.py +++ b/examples/scripts/ct_svmbir_tv_multi.py @@ -22,8 +22,6 @@ import numpy as np -import jax - import matplotlib.pyplot as plt import svmbir from xdesign import Foam, discrete_phantom @@ -65,7 +63,7 @@ expected_counts = max_intensity * np.exp(-sino) noisy_counts = np.random.poisson(expected_counts).astype(np.float32) noisy_counts[noisy_counts == 0] = 1 # deal with 0s -y = -np.log(noisy_counts / max_intensity) +y = -snp.log(noisy_counts / max_intensity) """ @@ -87,9 +85,10 @@ """ Set up problem. """ -y, x0, weights = jax.device_put([y, x_mrf, weights]) +x0 = snp.array(x_mrf) +weights = snp.array(weights) -λ = 1e-1 # L1 norm regularization parameter +λ = 1e-1 # ℓ1 norm regularization parameter f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) g = λ * functional.L21Norm() # regularization functional diff --git a/examples/scripts/deconv_circ_tv_admm.py b/examples/scripts/deconv_circ_tv_admm.py index d4676a4b6..b2ba83202 100644 --- a/examples/scripts/deconv_circ_tv_admm.py +++ b/examples/scripts/deconv_circ_tv_admm.py @@ -20,8 +20,6 @@ """ -import jax - from xdesign import SiemensStar, discrete_phantom import scico.numpy as snp @@ -36,7 +34,6 @@ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU """ diff --git a/examples/scripts/deconv_microscopy_allchn_tv_admm.py b/examples/scripts/deconv_microscopy_allchn_tv_admm.py index 244157f4f..13ec27b87 100644 --- a/examples/scripts/deconv_microscopy_allchn_tv_admm.py +++ b/examples/scripts/deconv_microscopy_allchn_tv_admm.py @@ -30,8 +30,6 @@ import numpy as np -import jax - import ray import scico.numpy as snp from scico import functional, linop, loss, plot @@ -108,9 +106,9 @@ @ray.remote(num_cpus=ncpu, num_gpus=ngpu) def deconvolve_channel(channel): """Deconvolve a single channel.""" - y_pad = jax.device_put(ray.get(y_pad_list)[channel]) - psf = jax.device_put(ray.get(psf_list)[channel]) - mask = jax.device_put(ray.get(mask_store)) + y_pad = ray.get(y_pad_list)[channel] + psf = ray.get(psf_list)[channel] + mask = ray.get(mask_store) M = linop.Diagonal(mask) C0 = linop.CircularConvolve( h=psf, input_shape=mask.shape, h_center=snp.array(psf.shape) / 2 - 0.5 # forward operator diff --git a/examples/scripts/deconv_microscopy_tv_admm.py b/examples/scripts/deconv_microscopy_tv_admm.py index 917f560fb..eee4d8f42 100644 --- a/examples/scripts/deconv_microscopy_tv_admm.py +++ b/examples/scripts/deconv_microscopy_tv_admm.py @@ -52,7 +52,6 @@ y -= y.min() y /= y.max() - psf /= psf.sum() diff --git a/examples/scripts/deconv_modl_train_foam1.py b/examples/scripts/deconv_modl_train_foam1.py index b0d437eab..96d0b7f15 100644 --- a/examples/scripts/deconv_modl_train_foam1.py +++ b/examples/scripts/deconv_modl_train_foam1.py @@ -163,7 +163,7 @@ stats_object_ini = None checkpoint_files = [] -for (dirpath, dirnames, filenames) in os.walk(workdir2): +for dirpath, dirnames, filenames in os.walk(workdir2): checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"] if len(checkpoint_files) > 0: diff --git a/examples/scripts/deconv_odp_train_foam1.py b/examples/scripts/deconv_odp_train_foam1.py index c96bf9019..38748d806 100644 --- a/examples/scripts/deconv_odp_train_foam1.py +++ b/examples/scripts/deconv_odp_train_foam1.py @@ -108,10 +108,10 @@ """ Define configuration dictionary for model and training loop. -Parameters have been selected for demonstration purposes and -relatively short training. The model depth is akin to the number of -unrolled iterations in the ODP model. The block depth controls the number -of layers at each unrolled iteration. The number of filters is uniform +Parameters have been selected for demonstration purposes and relatively +short training. The model depth is akin to the number of unrolled +iterations in the ODP model. The block depth controls the number of +layers at each unrolled iteration. The number of filters is uniform throughout the iterations. Better performance may be obtained by increasing depth, block depth, number of filters or training epochs, but may require longer training times. diff --git a/examples/scripts/deconv_ppp_bm3d_admm.py b/examples/scripts/deconv_ppp_bm3d_admm.py index e1c9a7322..ffebfd546 100644 --- a/examples/scripts/deconv_ppp_bm3d_admm.py +++ b/examples/scripts/deconv_ppp_bm3d_admm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -31,7 +29,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_ppp_bm3d_pgm.py b/examples/scripts/deconv_ppp_bm3d_pgm.py index 3448b308d..f8fae40b6 100644 --- a/examples/scripts/deconv_ppp_bm3d_pgm.py +++ b/examples/scripts/deconv_ppp_bm3d_pgm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -31,7 +29,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_ppp_bm4d_admm.py b/examples/scripts/deconv_ppp_bm4d_admm.py index 469b468a8..af600b6f9 100644 --- a/examples/scripts/deconv_ppp_bm4d_admm.py +++ b/examples/scripts/deconv_ppp_bm4d_admm.py @@ -17,8 +17,6 @@ import numpy as np -import jax - import scico.numpy as snp from scico import functional, linop, loss, metric, plot, random from scico.examples import create_3d_foam_phantom, downsample_volume, tile_volume_slices @@ -34,7 +32,7 @@ upsamp = 2 x_gt_hires = create_3d_foam_phantom((upsamp * Nz, upsamp * Ny, upsamp * Nx), N_sphere=100) x_gt = downsample_volume(x_gt_hires, upsamp) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_ppp_dncnn_admm.py b/examples/scripts/deconv_ppp_dncnn_admm.py index fd6e36900..d131d0300 100644 --- a/examples/scripts/deconv_ppp_dncnn_admm.py +++ b/examples/scripts/deconv_ppp_dncnn_admm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -31,7 +29,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_ppp_dncnn_padmm.py b/examples/scripts/deconv_ppp_dncnn_padmm.py index 1707158b9..a91fbb75b 100644 --- a/examples/scripts/deconv_ppp_dncnn_padmm.py +++ b/examples/scripts/deconv_ppp_dncnn_padmm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom import scico.numpy as snp @@ -31,7 +29,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ diff --git a/examples/scripts/deconv_tv_admm.py b/examples/scripts/deconv_tv_admm.py index 5d651bd9d..874b70e3c 100644 --- a/examples/scripts/deconv_tv_admm.py +++ b/examples/scripts/deconv_tv_admm.py @@ -22,7 +22,6 @@ ADMM is used in a [companion example](deconv_tv_padmm.rst). """ -import jax from xdesign import SiemensStar, discrete_phantom @@ -38,7 +37,6 @@ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU """ diff --git a/examples/scripts/deconv_tv_admm_tune.py b/examples/scripts/deconv_tv_admm_tune.py index 809029042..a313a1028 100644 --- a/examples/scripts/deconv_tv_admm_tune.py +++ b/examples/scripts/deconv_tv_admm_tune.py @@ -32,7 +32,6 @@ os.environ["JAX_PLATFORM_NAME"] = "cpu" os.environ["JAX_PLATFORMS"] = "cpu" -import jax from xdesign import SiemensStar, discrete_phantom @@ -80,8 +79,6 @@ def eval_params(config, x_gt, psf, y): """ # Extract solver parameters from config dict. λ, ρ = config["lambda"], config["rho"] - # Put main arrays on jax device. - x_gt, psf, y = jax.device_put([x_gt, psf, y]) # Set up problem to be solved. A = linop.Convolve(h=psf, input_shape=x_gt.shape) f = loss.SquaredL2Loss(y=y, A=A) diff --git a/examples/scripts/deconv_tv_padmm.py b/examples/scripts/deconv_tv_padmm.py index 8cab50330..af7c4718b 100644 --- a/examples/scripts/deconv_tv_padmm.py +++ b/examples/scripts/deconv_tv_padmm.py @@ -22,7 +22,6 @@ ADMM is used in a [companion example](deconv_tv_admm.rst). """ -import jax from xdesign import SiemensStar, discrete_phantom @@ -38,7 +37,6 @@ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU """ @@ -92,7 +90,7 @@ """ f = functional.ZeroFunctional() g0 = loss.SquaredL2Loss(y=y) -λ = 2.0e-2 # L1 norm regularization parameter +λ = 2.0e-2 # ℓ2,1 norm regularization parameter g1 = λ * functional.L21Norm() g = functional.SeparableFunctional((g0, g1)) diff --git a/examples/scripts/demosaic_ppp_bm3d_admm.py b/examples/scripts/demosaic_ppp_bm3d_admm.py index 953b67e88..15f49a50d 100644 --- a/examples/scripts/demosaic_ppp_bm3d_admm.py +++ b/examples/scripts/demosaic_ppp_bm3d_admm.py @@ -16,8 +16,6 @@ import numpy as np -import jax - from bm3d import bm3d_rgb from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007 @@ -32,8 +30,7 @@ """ Read a ground truth image. """ -img = kodim23(asfloat=True)[160:416, 60:316] -img = jax.device_put(img) # convert to jax type, push to GPU +img = snp.array(kodim23(asfloat=True)[160:416, 60:316]) """ @@ -93,7 +90,7 @@ def demosaic(cfaimg): """ Compute a baseline demosaicing solution. """ -imgb = jax.device_put(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32)) +imgb = snp.array(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32)) """ diff --git a/examples/scripts/denoise_dncnn_train_bsds.py b/examples/scripts/denoise_dncnn_train_bsds.py index c92dbf9d7..4fd1330a2 100644 --- a/examples/scripts/denoise_dncnn_train_bsds.py +++ b/examples/scripts/denoise_dncnn_train_bsds.py @@ -129,6 +129,7 @@ time_eval = time() - start_time output = np.clip(output, a_min=0, a_max=1.0) + """ Compare trained model in terms of reconstruction time and data fidelity. """ diff --git a/examples/scripts/denoise_dncnn_universal.py b/examples/scripts/denoise_dncnn_universal.py index 406baffaf..5799053cb 100644 --- a/examples/scripts/denoise_dncnn_universal.py +++ b/examples/scripts/denoise_dncnn_universal.py @@ -23,10 +23,9 @@ import numpy as np -import jax - from xdesign import Foam, discrete_phantom +import scico.numpy as snp import scico.random from scico import metric, plot from scico.denoiser import DnCNN @@ -37,7 +36,7 @@ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array """ @@ -47,7 +46,6 @@ for σ in [0.06, 0.10, 0.20]: print("------+---------+-------------------------+-------------------------") for variant in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]: - # Instantiate a DnCNN. denoiser = DnCNN(variant=variant) diff --git a/examples/scripts/denoise_l1tv_admm.py b/examples/scripts/denoise_l1tv_admm.py index d69168418..457bec0da 100644 --- a/examples/scripts/denoise_l1tv_admm.py +++ b/examples/scripts/denoise_l1tv_admm.py @@ -20,7 +20,6 @@ operator, and $\mathbf{x}$ is the denoised image. """ -import jax from xdesign import SiemensStar, discrete_phantom @@ -39,7 +38,6 @@ phantom = SiemensStar(16) x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) x_gt = 0.5 * x_gt / x_gt.max() -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU y = spnoise(x_gt, 0.5) diff --git a/examples/scripts/denoise_tv_admm.py b/examples/scripts/denoise_tv_admm.py index 3d9b211a3..8e35c3e96 100644 --- a/examples/scripts/denoise_tv_admm.py +++ b/examples/scripts/denoise_tv_admm.py @@ -25,7 +25,6 @@ edges that are not vertical or horizontal. """ -import jax from xdesign import SiemensStar, discrete_phantom @@ -41,7 +40,6 @@ N = 256 # image size phantom = SiemensStar(16) x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU x_gt = x_gt / x_gt.max() diff --git a/examples/scripts/denoise_tv_multi.py b/examples/scripts/denoise_tv_multi.py index 3e126c961..05663a926 100644 --- a/examples/scripts/denoise_tv_multi.py +++ b/examples/scripts/denoise_tv_multi.py @@ -20,7 +20,6 @@ vectors at each point in the image $\mathbf{x}$. """ -import jax from xdesign import SiemensStar, discrete_phantom @@ -37,7 +36,6 @@ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU """ diff --git a/examples/scripts/denoise_tv_pgm.py b/examples/scripts/denoise_tv_pgm.py index d39b3ab33..fbdf6f017 100644 --- a/examples/scripts/denoise_tv_pgm.py +++ b/examples/scripts/denoise_tv_pgm.py @@ -29,7 +29,6 @@ from typing import Callable, Optional, Union -import jax import jax.numpy as jnp from xdesign import SiemensStar, discrete_phantom @@ -38,7 +37,6 @@ import scico.random from scico import functional, linop, loss, operator, plot from scico.numpy import Array, BlockArray -from scico.numpy.util import ensure_on_device from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize from scico.util import device_info @@ -48,7 +46,6 @@ N = 256 # image size phantom = SiemensStar(16) x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) -x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU x_gt = x_gt / x_gt.max() @@ -88,13 +85,11 @@ def __init__( A: Optional[Union[Callable, operator.Operator]] = None, lmbda: float = 0.5, ): - y = ensure_on_device(y) self.functional = functional.SquaredL2Norm() super().__init__(y=y, A=A, scale=1.0) self.lmbda = lmbda def __call__(self, x: Union[Array, BlockArray]) -> float: - xint = self.y - self.lmbda * self.A(x) return -1.0 * self.functional(xint - jnp.clip(xint, 0.0, 1.0)) + self.functional(xint) @@ -107,7 +102,6 @@ def __call__(self, x: Union[Array, BlockArray]) -> float: # Evaluation of functional set to zero. class IsoProjector(functional.Functional): - has_eval = True has_prox = True @@ -160,7 +154,6 @@ def prox(self, v: Array, lam: float, **kwargs) -> Array: # Evaluation of functional set to zero. class AnisoProjector(functional.Functional): - has_eval = True has_prox = True @@ -168,7 +161,6 @@ def __call__(self, x: Union[Array, BlockArray]) -> float: return 0.0 def prox(self, v: Array, lam: float, **kwargs) -> Array: - return v / jnp.maximum(jnp.ones(v.shape), jnp.abs(v)) diff --git a/examples/scripts/diffusercam_tv_admm.py b/examples/scripts/diffusercam_tv_admm.py index 780b3ee78..103286750 100644 --- a/examples/scripts/diffusercam_tv_admm.py +++ b/examples/scripts/diffusercam_tv_admm.py @@ -41,8 +41,6 @@ import numpy as np -import jax - import scico.numpy as snp from scico import plot from scico.examples import ucb_diffusercam_data @@ -88,8 +86,8 @@ `JAX_ENABLE_X64=True` and change `dtype` below to `np.float64`. """ dtype = np.float32 -y = jax.device_put(y.astype(dtype)) -psf = jax.device_put(psf.astype(dtype)) +y = snp.array(y.astype(dtype)) +psf = snp.array(psf.astype(dtype)) """ diff --git a/examples/scripts/sparsecode_admm.py b/examples/scripts/sparsecode_admm.py index 32f3eafde..829ccbff1 100644 --- a/examples/scripts/sparsecode_admm.py +++ b/examples/scripts/sparsecode_admm.py @@ -21,8 +21,7 @@ import numpy as np -import jax - +import scico.numpy as snp from scico import functional, linop, loss, plot from scico.optimize.admm import ADMM, MatrixSubproblemSolver from scico.util import device_info @@ -45,8 +44,8 @@ xt[idx] = np.random.rand(s) y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal -xt = jax.device_put(xt) # convert to jax array, push to GPU -y = jax.device_put(y) # convert to jax array, push to GPU +xt = snp.array(xt) # convert to jax array +y = snp.array(y) # convert to jax array """ diff --git a/examples/scripts/sparsecode_conv_admm.py b/examples/scripts/sparsecode_conv_admm.py index e1624bdc3..611e7238b 100644 --- a/examples/scripts/sparsecode_conv_admm.py +++ b/examples/scripts/sparsecode_conv_admm.py @@ -24,8 +24,6 @@ import numpy as np -import jax - import scico.numpy as snp from scico import plot from scico.examples import create_conv_sparse_phantom @@ -55,8 +53,8 @@ """ Convert numpy arrays to jax arrays. """ -h = jax.device_put(h) -x0 = jax.device_put(x0) +h = snp.array(h) +x0 = snp.array(x0) """ diff --git a/examples/scripts/sparsecode_conv_md_admm.py b/examples/scripts/sparsecode_conv_md_admm.py index 39587ab03..9f70e9ce3 100644 --- a/examples/scripts/sparsecode_conv_md_admm.py +++ b/examples/scripts/sparsecode_conv_md_admm.py @@ -36,8 +36,6 @@ import numpy as np -import jax - import scico.numpy as snp from scico import plot from scico.examples import create_conv_sparse_phantom @@ -67,8 +65,8 @@ """ Convert numpy arrays to jax arrays. """ -h = jax.device_put(h) -x0 = jax.device_put(x0) +h = snp.array(h) +x0 = snp.array(x0) """ diff --git a/examples/scripts/sparsecode_pgm.py b/examples/scripts/sparsecode_pgm.py index 8cf0874c1..f5d34e08e 100644 --- a/examples/scripts/sparsecode_pgm.py +++ b/examples/scripts/sparsecode_pgm.py @@ -19,8 +19,7 @@ import numpy as np -import jax - +import scico.numpy as snp from scico import functional, linop, loss, plot from scico.optimize.pgm import AcceleratedPGM from scico.util import device_info @@ -44,8 +43,8 @@ x_gt[idx[0:s]] = np.random.randn(s) y = D @ x_gt + σ * np.random.randn(m) # synthetic signal -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU -y = jax.device_put(y) # convert to jax array, push to GPU +x_gt = snp.array(x_gt) # convert to jax array +y = snp.array(y) # convert to jax array """ diff --git a/examples/scripts/sparsecode_poisson_pgm.py b/examples/scripts/sparsecode_poisson_pgm.py index b9f38b1b8..08cccfed1 100644 --- a/examples/scripts/sparsecode_poisson_pgm.py +++ b/examples/scripts/sparsecode_poisson_pgm.py @@ -32,7 +32,6 @@ $\mathbf{x}^{(0)}$. """ -import jax import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt @@ -68,6 +67,7 @@ D0 = D[:, :n0] D1 = D[:, n0:] + # Define composed operator. class ForwardOperator(Operator): @@ -80,7 +80,6 @@ class ForwardOperator(Operator): """ def __init__(self, input_shape: Shape, D0, D1, jit: bool = True): - self.D0 = D0 self.D1 = D1 @@ -105,9 +104,6 @@ def _eval(self, x: BlockArray) -> BlockArray: lam = A(x_gt) y, key = scico.random.poisson(lam, shape=lam.shape, key=key) # synthetic signal -x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU -y = jax.device_put(y) # convert to jax array, push to GPU - """ Set up the loss function and the regularization. @@ -120,11 +116,10 @@ def _eval(self, x: BlockArray) -> BlockArray: """ -Define common setup: maximum of iterations and initial estimation of solution. +Define common setup: maximum of iterations and initial estimate of solution. """ maxiter = 50 x0, key = scico.random.uniform(((n0,), (n1,)), key=key) -x0 = jax.device_put(x0) # Initial solution estimate """ @@ -329,7 +324,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop): print("Running solver with step size of class: ", str_ss) print("L0 " + str_L0 + ": ", L0, "\n") -x = solver.solve() # Run the solver. +x = solver.solve() # run the solver hist = solver.itstat_object.history(transpose=True) plot_results(hist, str_ss, L0, x, x_gt, A) diff --git a/examples/scripts/superres_ppp_dncnn_admm.py b/examples/scripts/superres_ppp_dncnn_admm.py index fd06cd64a..8132376ec 100644 --- a/examples/scripts/superres_ppp_dncnn_admm.py +++ b/examples/scripts/superres_ppp_dncnn_admm.py @@ -14,7 +14,6 @@ superresolution problem. """ -import jax import scico import scico.numpy as snp @@ -39,8 +38,7 @@ def downsample_image(img, rate): """ Read a ground truth image. """ -img = kodim23(asfloat=True)[160:416, 60:316] -img = jax.device_put(img) +img = snp.array(kodim23(asfloat=True)[160:416, 60:316]) """ diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py index c14363c0d..01f8789de 100644 --- a/scico/linop/_convolve.py +++ b/scico/linop/_convolve.py @@ -17,13 +17,10 @@ import numpy as np -import jax -import jax.numpy as jnp from jax.dtypes import result_type from jax.scipy.signal import convolve import scico.numpy as snp -from scico.numpy.util import ensure_on_device from scico.typing import DType, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -65,7 +62,7 @@ def __init__( if h.ndim != len(input_shape): raise ValueError(f"h.ndim = {h.ndim} must equal len(input_shape) = {len(input_shape)}.") - self.h = ensure_on_device(h) + self.h = h if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.") @@ -193,12 +190,10 @@ def __init__( if x.ndim != len(input_shape): raise ValueError(f"x.ndim = {x.ndim} must equal len(input_shape) = {len(input_shape)}.") - if isinstance(x, jnp.ndarray): - self.x = x - elif isinstance(x, np.ndarray): # TODO: this should not be handled at the LinOp level - self.x = jax.device_put(x) - else: + # Ensure that x is a numpy or jax array. + if not snp.util.is_arraylike(x): raise TypeError(f"Expected numpy or jax array, got {type(x)}.") + self.x = x if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.") diff --git a/scico/linop/_diag.py b/scico/linop/_diag.py index 0e9251f57..0e14a0068 100644 --- a/scico/linop/_diag.py +++ b/scico/linop/_diag.py @@ -17,7 +17,7 @@ import scico.numpy as snp from scico.numpy import Array, BlockArray -from scico.numpy.util import broadcast_nested_shapes, ensure_on_device, is_nested +from scico.numpy.util import broadcast_nested_shapes, is_nested from scico.operator._operator import _wrap_mul_div_scalar from scico.typing import BlockShape, DType, Shape @@ -48,8 +48,7 @@ def __init__( input_dtype: `dtype` of input argument. The default, ``None``, means `diagonal.dtype`. """ - - self.diagonal = ensure_on_device(diagonal) + self.diagonal = diagonal if input_shape is None: input_shape = self.diagonal.shape diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index 951c6957e..662df0e49 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -79,10 +79,10 @@ def __init__(self, A: ArrayLike, input_cols: int = 0): """ self.A: snp.Array #: Dense array implementing this matrix - # if A is an ndarray, make sure it gets converted to a jax array + # Ensure that A is a numpy or jax array. if not snp.util.is_arraylike(A): raise TypeError(f"Expected numpy or jax array, got {type(A)}.") - self.A = jnp.array(A) + self.A = A # Can only do rank-2 arrays if A.ndim != 2: diff --git a/scico/linop/abel.py b/scico/linop/abel.py index 7ea96b9cb..9a94b3735 100644 --- a/scico/linop/abel.py +++ b/scico/linop/abel.py @@ -127,7 +127,7 @@ def _pyabel_daun_get_proj_matrix(img_shape: Shape) -> jax.Array: direction="forward", verbose=False, ) - return jax.device_put(proj_matrix) + return jnp.array(proj_matrix) # Read abel.tools.symmetry module into a string. diff --git a/scico/linop/optics.py b/scico/linop/optics.py index b508a7230..f2a09799d 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -56,8 +56,6 @@ import numpy as np from numpy.lib.scimath import sqrt # complex sqrt -import jax - from typing_extensions import TypeGuard import scico.numpy as snp @@ -289,9 +287,7 @@ def __init__( input_shape=input_shape, dx=dx, k0=k0, z=z, pad_factor=pad_factor, **kwargs ) - self.phase = jax.device_put( - np.exp(1j * z * sqrt(self.k0**2 - self.kp**2)).astype(np.complex64) - ) + self.phase = snp.exp(1j * z * sqrt(self.k0**2 - self.kp**2)).astype(np.complex64) self.D = Diagonal(self.phase) self._set_adjoint() @@ -386,9 +382,7 @@ def __init__( input_shape=input_shape, dx=dx, k0=k0, z=z, pad_factor=pad_factor, **kwargs ) - self.phase = jax.device_put( - np.exp(1j * z * (self.k0 - self.kp**2 / (2 * self.k0))).astype(np.complex64) - ) + self.phase = snp.exp(1j * z * (self.k0 - self.kp**2 / (2 * self.k0))).astype(np.complex64) self.D = Diagonal(self.phase) self._set_adjoint() diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 1929912cb..6d81b0fb7 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -214,7 +214,7 @@ def _proj( delta_channel: Optional[float] = None, delta_pixel: Optional[float] = None, ) -> snp.Array: - return jax.device_put( + return snp.array( svmbir.project( np.array(x), np.array(angles), @@ -264,8 +264,8 @@ def _bproj( magnification: Optional[float] = None, delta_channel: Optional[float] = None, delta_pixel: Optional[float] = None, - ): - return jax.device_put( + ) -> snp.Array: + return snp.array( svmbir.backproject( np.array(y), np.array(angles), @@ -432,7 +432,7 @@ def prox(self, v: snp.Array, lam: float = 1, **kwargs) -> snp.Array: if np.sum(np.isnan(result)): raise ValueError("Result contains NaNs.") - return jax.device_put(result.reshape(self.A.input_shape)) + return snp.array(result.reshape(self.A.input_shape)) class SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss): diff --git a/scico/loss.py b/scico/loss.py index 0e3cb9964..1a5970f1a 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -18,7 +18,7 @@ import scico.numpy as snp from scico import functional, linop, operator from scico.numpy import Array, BlockArray -from scico.numpy.util import ensure_on_device, no_nan_divide +from scico.numpy.util import no_nan_divide from scico.scipy.special import gammaln # type: ignore from scico.solver import cg @@ -67,7 +67,7 @@ def __init__( be defined in a derived class. scale: Scaling parameter. Default: 1.0. """ - self.y = ensure_on_device(y) + self.y = y if A is None: # y and x must have same shape A = linop.Identity(input_shape=self.y.shape, input_dtype=self.y.dtype) # type: ignore @@ -166,7 +166,6 @@ def __init__( W: Optional[linop.Diagonal] = None, prox_kwargs: Optional[dict] = None, ): - r""" Args: y: Measurement. @@ -175,8 +174,6 @@ def __init__( W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ - y = ensure_on_device(y) - self.W: linop.Diagonal if W is None: @@ -289,7 +286,6 @@ def __init__( `self.A` is a :class:`.Identity`. scale: Scaling parameter. Default: 0.5. """ - y = ensure_on_device(y) super().__init__(y=y, A=A, scale=scale) #: Constant term, :math:`\ln(y!)`, in Poisson log likehood. @@ -326,7 +322,6 @@ def __init__( scale: float = 0.5, W: Optional[linop.Diagonal] = None, ): - r""" Args: y: Measurement. @@ -335,8 +330,6 @@ def __init__( W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ - y = ensure_on_device(y) - if W is None: self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape) elif isinstance(W, linop.Diagonal): @@ -533,7 +526,6 @@ def __init__( scale: float = 0.5, W: Optional[linop.Diagonal] = None, ): - r""" Args: y: Measurement. @@ -542,8 +534,6 @@ def __init__( W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ - y = ensure_on_device(y) - if W is None: self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape) elif isinstance(W, linop.Diagonal): diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 50fefdd4e..54a1c497d 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -10,7 +10,6 @@ from __future__ import annotations -import warnings from math import prod from typing import Any, List, Optional, Tuple, Union @@ -24,45 +23,6 @@ from ._blockarray import BlockArray -def ensure_on_device( - *arrays: Union[np.ndarray, snp.Array, BlockArray] -) -> Union[snp.Array, BlockArray]: - """Cast numpy arrays to jax arrays. - - Cast numpy arrays to jax arrays and leave jax arrays and BlockArrays, - as they are. This is intended to be used when initializing optimizers - and functionals so that all arrays are either jax arrays or - BlockArrays. - - Args: - *arrays: One or more input arrays (numpy array, jax array, or - BlockArray). - - Returns: - Array or arrays, modified where appropriate. - - Raises: - TypeError: If the arrays contain anything that is neither - numpy array, jax array, nor BlockArray. - """ - arrays = list(arrays) - - for i, array in enumerate(arrays): - if isinstance(array, np.ndarray): - warnings.warn( - f"Argument {i+1} of {len(arrays)} is a numpy array. " - "Will cast it to a jax array. " - f"To suppress this warning cast all numpy arrays to jax arrays.", - stacklevel=2, - ) - - arrays[i] = jax.device_put(arrays[i]) - - if len(arrays) == 1: - return arrays[0] - return arrays - - def parse_axes( axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None ) -> List[int]: diff --git a/scico/optimize/_admm.py b/scico/optimize/_admm.py index 55862f0cb..eccc3acd1 100644 --- a/scico/optimize/_admm.py +++ b/scico/optimize/_admm.py @@ -18,7 +18,6 @@ from scico.linop import LinearOperator from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device from ._admmaux import ( FBlockCircularConvolveSolver, @@ -142,7 +141,7 @@ def __init__( input_shape = C_list[0].input_shape dtype = C_list[0].input_dtype x0 = snp.zeros(input_shape, dtype=dtype) - self.x = ensure_on_device(x0) + self.x = x0 self.z_list, self.z_list_old = self.z_init(self.x) self.u_list = self.u_init(self.x) diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 7a1c8710c..5e7808989 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -30,7 +30,7 @@ ) from scico.loss import SquaredL2Loss from scico.numpy import Array, BlockArray -from scico.numpy.util import ensure_on_device, is_real_dtype +from scico.numpy.util import is_real_dtype from scico.solver import ConvATADSolver, MatrixATADSolver from scico.solver import cg as scico_cg from scico.solver import minimize @@ -99,8 +99,6 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: Computed solution. """ - x0 = ensure_on_device(x0) - @jax.jit def obj(x): out = 0.0 @@ -261,7 +259,6 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: Returns: Computed solution. """ - x0 = ensure_on_device(x0) rhs = self.compute_rhs() x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs) # type: ignore return x diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index 3e7ef86e3..26b049311 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -18,7 +18,6 @@ from scico.linop import LinearOperator from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device from ._common import Optimizer @@ -115,7 +114,7 @@ def __init__( input_shape = C.input_shape dtype = C.input_dtype x0 = snp.zeros(input_shape, dtype=dtype) - self.x = ensure_on_device(x0) + self.x = x0 self.z, self.z_old = self.z_init(self.x) self.u = self.u_init(self.x) diff --git a/scico/optimize/_padmm.py b/scico/optimize/_padmm.py index 793abf844..ee3d5c516 100644 --- a/scico/optimize/_padmm.py +++ b/scico/optimize/_padmm.py @@ -20,7 +20,6 @@ from scico.linop import Identity, LinearOperator, operator_norm from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device from scico.typing import BlockShape, DType, PRNGKey, Shape from ._common import Optimizer @@ -104,14 +103,14 @@ def __init__( if x0 is None: x0 = snp.zeros(xshape, dtype=xdtype) - self.x = ensure_on_device(x0) + self.x = x0 if z0 is None: z0 = snp.zeros(zshape, dtype=zdtype) - self.z = ensure_on_device(z0) + self.z = z0 self.z_old = self.z if u0 is None: u0 = snp.zeros(ushape, dtype=udtype) - self.u = ensure_on_device(u0) + self.u = u0 self.u_old = self.u super().__init__(**kwargs) diff --git a/scico/optimize/_pgm.py b/scico/optimize/_pgm.py index 7e18ae691..8375c0b2d 100644 --- a/scico/optimize/_pgm.py +++ b/scico/optimize/_pgm.py @@ -19,7 +19,6 @@ from scico.functional import Functional from scico.loss import Loss from scico.numpy import Array, BlockArray -from scico.numpy.util import ensure_on_device from ._common import Optimizer from ._pgmaux import ( @@ -84,7 +83,7 @@ def x_step(v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]: self.x_step = jax.jit(x_step) - self.x: Union[Array, BlockArray] = ensure_on_device(x0) # current estimate of solution + self.x: Union[Array, BlockArray] = x0 # current estimate of solution super().__init__(**kwargs) @@ -183,7 +182,6 @@ def __init__( **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ - x0 = ensure_on_device(x0) super().__init__(f=f, g=g, L0=L0, x0=x0, step_size=step_size, **kwargs) self.v = x0 diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 60782d021..ba36331c0 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -18,7 +18,6 @@ from scico.linop import LinearOperator, jacobian, operator_norm from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device from scico.operator import Operator from scico.typing import PRNGKey @@ -132,13 +131,13 @@ def __init__( input_shape = C.input_shape dtype = C.input_dtype x0 = snp.zeros(input_shape, dtype=dtype) - self.x = ensure_on_device(x0) + self.x = x0 self.x_old = self.x if z0 is None: input_shape = C.output_shape dtype = C.output_dtype z0 = snp.zeros(input_shape, dtype=dtype) - self.z = ensure_on_device(z0) + self.z = z0 self.z_old = self.z super().__init__(**kwargs) diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 9b4f43f8a..a48fa632b 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -89,9 +89,6 @@ def test_separable_prox(test_separable_obj): def test_separable_grad(test_separable_obj): # Used to restore the warnings after the context is used with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - # Verifies that there is a warning on f.grad and fg.grad np.testing.assert_warns(test_separable_obj.f.grad(test_separable_obj.v1)) np.testing.assert_warns(test_separable_obj.fg.grad(test_separable_obj.vb)) @@ -115,7 +112,6 @@ def __init__(self, delta=1.0): class TestNormProx: - normlist = [ functional.L0Norm, functional.L1Norm, @@ -179,7 +175,6 @@ def test_scaled_attrs(self, norm, test_prox_obj): @pytest.mark.parametrize("norm", normlist) def test_scaled_eval(self, norm, alpha, test_prox_obj): - unscaled = norm() scaled = test_prox_obj.scalar * norm() @@ -242,7 +237,6 @@ def test_proj_obj(request): class TestProj: - cnstrlist = [functional.NonNegativeIndicator, functional.L2BallIndicator] sdistlist = [functional.SetDistance, functional.SquaredSetDistance] diff --git a/scico/test/functional/test_separable.py b/scico/test/functional/test_separable.py index 0d4473ce0..0af94d2ca 100644 --- a/scico/test/functional/test_separable.py +++ b/scico/test/functional/test_separable.py @@ -54,9 +54,6 @@ def test_separable_prox(test_separable_obj): def test_separable_grad(test_separable_obj): # Used to restore the warnings after the context is used with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - # Verifies that there is a warning on f.grad and fg.grad np.testing.assert_warns(test_separable_obj.f.grad(test_separable_obj.v1)) np.testing.assert_warns(test_separable_obj.fg.grad(test_separable_obj.vb)) diff --git a/scico/test/linop/test_convolve.py b/scico/test/linop/test_convolve.py index 1979a4626..31714a9a6 100644 --- a/scico/test/linop/test_convolve.py +++ b/scico/test/linop/test_convolve.py @@ -1,10 +1,8 @@ import operator as op -import warnings import numpy as np import jax -import jax.numpy as jnp import jax.scipy.signal as signal import pytest @@ -39,7 +37,6 @@ def test_eval(self, input_shape, input_dtype, mode, jit): @pytest.mark.parametrize("mode", ["full", "valid", "same"]) @pytest.mark.parametrize("jit", [False, True]) def test_adjoint(self, input_shape, mode, jit, input_dtype): - ndim = len(input_shape) filter_shape = (3, 4)[:ndim] x, key = randn(input_shape, dtype=input_dtype, key=self.key) @@ -167,17 +164,6 @@ def test_dimension_mismatch(testobj): Convolve(input_shape=(16, 16), h=testobj.psf_A) -def test_ndarray_h(): - # Used to restore the warnings after the context is used - with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - - h = np.random.randn(3, 3).astype(np.float32) - A = Convolve(input_shape=(16, 16), h=h) - assert isinstance(A.h, jnp.ndarray) - - class TestConvolveByX: def setup_method(self, method): self.key = jax.random.PRNGKey(12345) @@ -204,7 +190,6 @@ def test_eval(self, input_shape, input_dtype, mode, jit): @pytest.mark.parametrize("mode", ["full", "valid", "same"]) @pytest.mark.parametrize("jit", [False, True]) def test_adjoint(self, input_shape, mode, jit, input_dtype): - ndim = len(input_shape) x_shape = (3, 4)[:ndim] x, key = randn(input_shape, dtype=input_dtype, key=self.key) @@ -330,14 +315,3 @@ def test_dimension_mismatch(cbx_testobj): with pytest.raises(ValueError): # 2-dim input shape, 1-dim xer ConvolveByX(input_shape=(16, 16), x=cbx_testobj.x_A) - - -def test_ndarray_x(): - # Used to restore the warnings after the context is used - with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - - x = np.random.randn(3, 3).astype(np.float32) - A = ConvolveByX(input_shape=(16, 16), x=x) - assert isinstance(A.x, jnp.ndarray) diff --git a/scico/test/linop/test_matrix.py b/scico/test/linop/test_matrix.py index 178c1fce5..ec7a24296 100644 --- a/scico/test/linop/test_matrix.py +++ b/scico/test/linop/test_matrix.py @@ -239,10 +239,14 @@ def test_matmul_identity(self): I = linop.Identity(input_shape=(6,)) assert Ao == Ao @ I - def test_init_devicearray(self): - A = np.random.randn(4, 6) - Ao = MatrixOperator(A) - assert isinstance(Ao.A, jnp.ndarray) + def test_init_array(self): + Am = np.random.randn(4, 6) + A = MatrixOperator(Am) + assert isinstance(A.A, np.ndarray) + + A = MatrixOperator(jnp.array(Am)) + assert isinstance(A.A, jnp.ndarray) + np.testing.assert_array_equal(A.A, jnp.array(A)) with pytest.raises(TypeError): MatrixOperator([1.0, 3.0]) diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index 608cc82d0..a41629d6a 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -249,7 +249,7 @@ def test_prox_cg( mask = np.ones(im.shape) > 0 W = svmbir.calc_weights(y, weight_type=weight_type).astype("float32") - W = jax.device_put(W) + W = snp.array(W) λ = 0.01 if is_masked: @@ -297,7 +297,7 @@ def test_approx_prox( y = A @ im W = svmbir.calc_weights(y, weight_type=weight_type).astype("float32") - W = jax.device_put(W) + W = snp.array(W) λ = 0.01 v, _ = scico.random.normal(im.shape, dtype=im.dtype) diff --git a/scico/test/linop/test_stack.py b/scico/test/linop/test_stack.py index e410bd9ac..cd59b73ba 100644 --- a/scico/test/linop/test_stack.py +++ b/scico/test/linop/test_stack.py @@ -27,8 +27,8 @@ def test_construct(self, jit): H = VerticalStack([A, B], jit=jit) # in general, returns a BlockArray - A = Convolve(jax.device_put(np.ones((3, 3))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((3, 3)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], jit=jit) x = np.ones((7, 11)) y = H @ x @@ -39,8 +39,8 @@ def test_construct(self, jit): assert np.allclose(y[1], B @ x) # by default, collapse to jax array when possible - A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((2, 2)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], jit=jit) x = np.ones((7, 11)) y = H @ x @@ -51,8 +51,8 @@ def test_construct(self, jit): assert np.allclose(y[1], B @ x) # let user turn off collapsing - A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((2, 2)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse=False, jit=jit) x = np.ones((7, 11)) y = H @ x @@ -62,14 +62,14 @@ def test_construct(self, jit): @pytest.mark.parametrize("jit", [False, True]) def test_adjoint(self, collapse, jit): # general case - A = Convolve(jax.device_put(np.ones((3, 3))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((3, 3)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse=collapse, jit=jit) adjoint_test(H, self.key) # collapsable case - A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((2, 2)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse=collapse, jit=jit) adjoint_test(H, self.key) @@ -77,12 +77,12 @@ def test_adjoint(self, collapse, jit): @pytest.mark.parametrize("jit", [False, True]) def test_algebra(self, collapse, jit): # adding - A = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) - B = Convolve(jax.device_put(np.ones((2, 2))), (7, 11)) + A = Convolve(snp.ones((2, 2)), (7, 11)) + B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse=collapse, jit=jit) - A = Convolve(jax.device_put(np.random.rand(2, 2)), (7, 11)) - B = Convolve(jax.device_put(np.random.rand(2, 2)), (7, 11)) + A = Convolve(snp.array(np.random.rand(2, 2)), (7, 11)) + B = Convolve(snp.array(np.random.rand(2, 2)), (7, 11)) G = VerticalStack([A, B], collapse=collapse, jit=jit) x = np.ones((7, 11)) diff --git a/scico/test/optimize/test_admm.py b/scico/test/optimize/test_admm.py index b246c7416..8795efb9a 100644 --- a/scico/test/optimize/test_admm.py +++ b/scico/test/optimize/test_admm.py @@ -1,7 +1,5 @@ import numpy as np -import jax - import pytest import scico.numpy as snp @@ -20,7 +18,7 @@ class TestMisc: def setup_method(self, method): np.random.seed(12345) - self.y = jax.device_put(np.random.randn(16, 17).astype(np.float32)) + self.y = snp.array(np.random.randn(16, 17).astype(np.float32)) def test_admm(self): maxiter = 2 @@ -112,14 +110,14 @@ def setup_method(self, method): MB = 5 N = 6 # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2 - Amx = np.random.randn(MA, N) - Bmx = np.random.randn(MB, N) - y = np.random.randn(MA) + Amx = np.random.randn(MA, N).astype(np.float32) + Bmx = np.random.randn(MB, N).astype(np.float32) + y = np.random.randn(MA).astype(np.float32) 𝛼 = np.pi # sort of random number chosen to test non-default scale factor λ = 1e0 self.Amx = Amx self.Bmx = Bmx - self.y = jax.device_put(y) + self.y = snp.array(y) self.𝛼 = 𝛼 self.λ = λ # Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = 𝛼 A^T y @@ -219,16 +217,16 @@ def setup_method(self, method): MB = 5 N = 6 # Set up arrays for problem argmin (𝛼/2) ||A x - y||_W^2 + (λ/2) ||B x||_2^2 - Amx = np.random.randn(MA, N) - W = np.abs(np.random.randn(MA, 1)) - Bmx = np.random.randn(MB, N) - y = np.random.randn(MA) + Amx = np.random.randn(MA, N).astype(np.float32) + W = np.abs(np.random.randn(MA, 1).astype(np.float32)) + Bmx = np.random.randn(MB, N).astype(np.float32) + y = np.random.randn(MA).astype(np.float32) 𝛼 = np.pi # sort of random number chosen to test non-default scale factor λ = np.e self.Amx = Amx - self.W = jax.device_put(W) + self.W = snp.array(W) self.Bmx = Bmx - self.y = jax.device_put(y) + self.y = snp.array(y) self.𝛼 = 𝛼 self.λ = λ # Solution of problem is given by linear system diff --git a/scico/test/optimize/test_ladmm.py b/scico/test/optimize/test_ladmm.py index 6ef71b108..6c8a6c708 100644 --- a/scico/test/optimize/test_ladmm.py +++ b/scico/test/optimize/test_ladmm.py @@ -1,7 +1,5 @@ import numpy as np -import jax - import pytest import scico.numpy as snp @@ -13,7 +11,7 @@ class TestMisc: def setup_method(self, method): np.random.seed(12345) - self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32)) + self.y = snp.array(np.random.randn(32, 33).astype(np.float32)) self.maxiter = 2 self.μ = 1e-1 self.ν = 1e-1 @@ -122,7 +120,7 @@ def setup_method(self, method): λ = 1e0 self.Amx = Amx self.Bmx = Bmx - self.y = jax.device_put(y) + self.y = snp.array(y) self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x @@ -161,7 +159,7 @@ def setup_method(self, method): λ = 1e0 self.Amx = Amx self.Bmx = Bmx - self.y = jax.device_put(y) + self.y = snp.array(y) self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x diff --git a/scico/test/optimize/test_padmm.py b/scico/test/optimize/test_padmm.py index a54a0125f..bd8618a22 100644 --- a/scico/test/optimize/test_padmm.py +++ b/scico/test/optimize/test_padmm.py @@ -1,7 +1,5 @@ import numpy as np -import jax - import pytest import scico.numpy as snp @@ -13,7 +11,7 @@ class TestMisc: def setup_method(self, method): np.random.seed(12345) - self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32)) + self.y = snp.array(np.random.randn(32, 33).astype(np.float32)) self.maxiter = 2 self.ρ = 1e0 self.μ = 1e0 @@ -199,7 +197,7 @@ def setup_method(self, method): λ = 1e0 self.Amx = Amx self.Bmx = Bmx - self.y = jax.device_put(y) + self.y = snp.array(y) self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x @@ -267,7 +265,7 @@ def setup_method(self, method): λ = 1e0 self.Amx = Amx self.Bmx = Bmx - self.y = jax.device_put(y) + self.y = snp.array(y) self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x diff --git a/scico/test/optimize/test_pdhg.py b/scico/test/optimize/test_pdhg.py index 62391bd70..61a9353d0 100644 --- a/scico/test/optimize/test_pdhg.py +++ b/scico/test/optimize/test_pdhg.py @@ -1,7 +1,5 @@ import numpy as np -import jax - import pytest import scico.numpy as snp @@ -13,7 +11,7 @@ class TestMisc: def setup_method(self, method): np.random.seed(12345) - self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32)) + self.y = snp.array(np.random.randn(32, 33).astype(np.float32)) self.maxiter = 2 self.τ = 1e-1 self.σ = 1e-1 @@ -128,7 +126,7 @@ def setup_method(self, method): λ = 1e0 self.Amx = Amx self.Bmx = Bmx - self.y = jax.device_put(y) + self.y = snp.array(y) self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x @@ -189,7 +187,7 @@ def setup_method(self, method): λ = 1e0 self.Amx = Amx self.Bmx = Bmx - self.y = jax.device_put(y) + self.y = snp.array(y) self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x diff --git a/scico/test/optimize/test_pgm.py b/scico/test/optimize/test_pgm.py index 63116381f..1231a826f 100644 --- a/scico/test/optimize/test_pgm.py +++ b/scico/test/optimize/test_pgm.py @@ -4,6 +4,7 @@ import pytest +import scico.numpy as snp from scico import functional, linop, loss, random from scico.optimize import PGM, AcceleratedPGM from scico.optimize.pgm import ( @@ -20,9 +21,9 @@ def setup_method(self, method): M = 5 N = 4 # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2 - Amx = np.random.randn(M, N) + Amx = np.random.randn(M, N).astype(np.float32) Bmx = np.identity(N) - y = jax.device_put(np.random.randn(M)) + y = snp.array(np.random.randn(M).astype(np.float32)) λ = 1e0 self.Amx = Amx self.y = y @@ -196,7 +197,7 @@ def setup_method(self, method): # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||x||_2^2 Amx, key = random.randn((M, N), dtype=np.complex64, key=None) Bmx = np.identity(N) - y = jax.device_put(np.random.randn(M)) + y = snp.array(np.random.randn(M)) λ = 1e0 self.Amx = Amx self.Bmx = Bmx diff --git a/scico/test/test_numpy_util.py b/scico/test/test_numpy_util.py index be1c9f5ab..faab01dc5 100644 --- a/scico/test/test_numpy_util.py +++ b/scico/test/test_numpy_util.py @@ -1,16 +1,10 @@ -import warnings - import numpy as np -import jax.numpy as jnp - import pytest import scico.numpy as snp -from scico.numpy import BlockArray from scico.numpy.util import ( complex_dtype, - ensure_on_device, indexed_shape, is_complex_dtype, is_nested, @@ -24,31 +18,6 @@ from scico.random import randn -def test_ensure_on_device(): - # Used to restore the warnings after the context is used - with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - - NP = np.ones(2) - SNP = snp.ones(2) - BA = snp.blockarray([NP, SNP]) - NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA) - - assert isinstance(NP_, jnp.ndarray) - - assert isinstance(SNP_, jnp.ndarray) - assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer() - - assert isinstance(BA_, BlockArray) - assert isinstance(BA_[0], jnp.ndarray) - assert isinstance(BA_[1], jnp.ndarray) - assert BA[1].unsafe_buffer_pointer() == BA_[1].unsafe_buffer_pointer() - - NP_ = ensure_on_device(NP) - assert isinstance(NP_, jnp.ndarray) - - def test_no_nan_divide_array(): x, key = randn((4,), dtype=np.float32) y, key = randn(x.shape, dtype=np.float32, key=key) diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index f220482df..ebbcbf9c5 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -15,8 +15,8 @@ def setup_method(self, method): def test_wrap_func_and_grad(self): N = 8 - A = jax.device_put(np.random.randn(N, N)) - x = jax.device_put(np.random.randn(N)) + A = snp.array(np.random.randn(N, N)) + x = snp.array(np.random.randn(N)) f = lambda x: 0.5 * snp.linalg.norm(A @ x) ** 2 @@ -117,10 +117,10 @@ def test_preconditioned_cg(self): def test_lstsq_func(self): N = 24 M = 32 - Ac = jax.device_put(np.random.randn(N, M).astype(np.float32)) + Ac = snp.array(np.random.randn(N, M).astype(np.float32)) Am = Ac.dot(Ac.T) A = Am.dot - x = jax.device_put(np.random.randn(N).astype(np.float32)) + x = snp.array(np.random.randn(N).astype(np.float32)) b = Am.dot(x) x0 = snp.zeros((N,), dtype=np.float32) tol = 1e-6 @@ -134,9 +134,9 @@ def test_lstsq_func(self): def test_lstsq_op(self): N = 32 M = 24 - Ac = jax.device_put(np.random.randn(N, M).astype(np.float32)) + Ac = snp.array(np.random.randn(N, M).astype(np.float32)) A = linop.MatrixOperator(Ac) - x = jax.device_put(np.random.randn(M).astype(np.float32)) + x = snp.array(np.random.randn(M).astype(np.float32)) b = Ac.dot(x) tol = 1e-7 try: