diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index 09a7cbed5..aa530a6ef 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -191,7 +191,6 @@ def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array: `__. """ - # Just use the CPU FBP alg for now; hitting memory issues with GPU one. def f(sino): sino = _ensure_writeable(sino) sino_id = astra.data2d.create("-sino", self.proj_geom, sino) @@ -200,7 +199,7 @@ def f(sino): rec_id = astra.data2d.create("-vol", self.vol_geom) # start to populate config - cfg = astra.astra_dict("FBP") + cfg = astra.astra_dict("FBP_CUDA" if self.device == "gpu" else "FBP") cfg["ReconstructionDataId"] = rec_id cfg["ProjectorId"] = self.proj_id cfg["ProjectionDataId"] = sino_id diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 096e27186..d352eedfe 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -371,15 +371,15 @@ class CircularConvolveSolver(LinearSubproblemSolver): r"""Solver for linear operators diagonalized in the DFT domain. Specialization of :class:`.LinearSubproblemSolver` for the case - where :code:`f` is an instance of :class:`.SquaredL2Loss`, the - forward operator :code:`f.A` is either an instance of - :class:`.Identity` or :class:`.CircularConvolve`, and the - :code:`C_i` are all shift invariant linear operators, examples of - which include instances of :class:`.Identity` as well as some - instances (depending on initializer parameters) of - :class:`.CircularConvolve` and :class:`.FiniteDifference`. - None of the instances of :class:`.CircularConvolve` may sum over any - of their axes. + where :code:`f` is ``None``, or an instance of + :class:`.SquaredL2Loss` with a forward operator :code:`f.A` that is + either an instance of :class:`.Identity` or + :class:`.CircularConvolve`, and the :code:`C_i` are all shift + invariant linear operators, examples of which include instances of + :class:`.Identity` as well as some instances (depending on + initializer parameters) of :class:`.CircularConvolve` and + :class:`.FiniteDifference`. None of the instances of + :class:`.CircularConvolve` may sum over any of their axes. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is @@ -388,11 +388,29 @@ class CircularConvolveSolver(LinearSubproblemSolver): equation to be solved. """ - def __init__(self): - """Initialize a :class:`CircularConvolveSolver` object.""" + def __init__(self, ndims: Optional[int] = None): + """Initialize a :class:`CircularConvolveSolver` object. + + Args: + ndims: Number of trailing dimensions of the input and kernel + involved in the :class:`.CircularConvolve` convolutions. + In most cases this value is automatically determined from + the optimization problem specification, but this is not + possible when :code:`f` is ``None`` and none of the + :code:`C_i` are of type :class:`.CircularConvolve`. When + not ``None``, this parameter overrides the automatic + mechanism. + """ + self.ndims = ndims def internal_init(self, admm: soa.ADMM): - if admm.f is not None: + if admm.f is None: + is_cc = [isinstance(C, CircularConvolve) for C in admm.C_list] + if any(is_cc): + auto_ndims = admm.C_list[is_cc.index(True)].ndims + else: + auto_ndims = None + else: if not isinstance(admm.f, SquaredL2Loss): raise TypeError( "CircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; " @@ -403,7 +421,10 @@ def internal_init(self, admm: soa.ADMM): "CircularConvolveSolver requires f.A to be a scico.linop.CircularConvolve " f"or scico.linop.Identity; got {type(admm.f.A)}." ) + auto_ndims = admm.f.A.ndims if isinstance(admm.f.A, CircularConvolve) else None + if self.ndims is None: + self.ndims = auto_ndims super().internal_init(admm) self.real_result = is_real_dtype(admm.C_list[0].input_dtype) @@ -411,12 +432,16 @@ def internal_init(self, admm: soa.ADMM): # All of the C operators are assumed to be linear and shift invariant # but this is not checked. lhs_op_list = [ - rho * CircularConvolve.from_operator(C.gram_op) + rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims) for rho, C in zip(admm.rho_list, admm.C_list) ] A_lhs = reduce(lambda a, b: a + b, lhs_op_list) if self.admm.f is not None: - A_lhs += 2.0 * admm.f.scale * CircularConvolve.from_operator(admm.f.A.gram_op) + A_lhs += ( + 2.0 + * admm.f.scale + * CircularConvolve.from_operator(admm.f.A.gram_op, ndims=self.ndims) + ) self.A_lhs = A_lhs diff --git a/scico/test/linop/xray/test_astra.py b/scico/test/linop/xray/test_astra.py index 6ed0b1efd..de8302331 100644 --- a/scico/test/linop/xray/test_astra.py +++ b/scico/test/linop/xray/test_astra.py @@ -122,6 +122,13 @@ def test_adjoint_typical_input(testobj): adjoint_test(A, x=x, rtol=get_tol()) +def test_fbp(testobj): + x = testobj.A.fbp(testobj.y) + # Test for a bug (related to calling the Astra CPU FBP implementation + # when using a FPU device) that resulted in a constant zero output. + assert np.sum(np.abs(x)) > 0.0 + + def test_jit_in_DiagonalStack(): """See https://github.com/lanl/scico/issues/331""" N = 10 diff --git a/scico/test/optimize/test_admm.py b/scico/test/optimize/test_admm.py index 8795efb9a..6d61fd864 100644 --- a/scico/test/optimize/test_admm.py +++ b/scico/test/optimize/test_admm.py @@ -363,23 +363,29 @@ def test_admm_quadratic_matrix(self): assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5 +@pytest.mark.parametrize("extra_axis", (False, True)) +@pytest.mark.parametrize("center", (None, [-1.0, 2.5])) class TestCircularConvolveSolve: - def setup_method(self, method): + + @pytest.fixture(scope="function", autouse=True) + def setup_and_teardown(self, extra_axis, center): np.random.seed(12345) Nx = 8 - x = np.pad(np.ones((Nx, Nx), dtype=np.float32), Nx) + x = snp.pad(snp.ones((Nx, Nx), dtype=np.float32), Nx) Npsf = 3 psf = snp.ones((Npsf, Npsf), dtype=np.float32) / (Npsf**2) + if extra_axis: + x = x[np.newaxis] + psf = psf[np.newaxis] self.A = linop.CircularConvolve( - h=psf, - input_shape=x.shape, - input_dtype=np.float32, + h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32, h_center=center ) self.y = self.A(x) λ = 1e-2 self.f = loss.SquaredL2Loss(y=self.y, A=self.A) self.g_list = [λ * functional.L1Norm()] self.C_list = [linop.FiniteDifference(input_shape=x.shape, circular=True)] + yield def test_admm(self): maxiter = 50 @@ -406,6 +412,60 @@ def test_admm(self): x0=self.A.adj(self.y), subproblem_solver=CircularConvolveSolver(), ) + assert admm_dft.subproblem_solver.A_lhs.ndims == 2 + x_dft = admm_dft.solve() + np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0) + assert metric.mse(x_lin, x_dft) < 1e-9 + + +@pytest.mark.parametrize("with_cconv", (False, True)) +class TestSpecialCaseCircularConvolveSolve: + + @pytest.fixture(scope="function", autouse=True) + def setup_and_teardown(self, with_cconv): + np.random.seed(12345) + Nx = 8 + x = snp.pad(snp.ones((1, Nx, Nx), dtype=np.float32), Nx) + if with_cconv: + Npsf = 3 + psf = snp.ones((1, Npsf, Npsf), dtype=np.float32) / (Npsf**2) + C0 = linop.CircularConvolve(h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32) + else: + C0 = linop.FiniteDifference(input_shape=x.shape, axes=(1, 2), circular=True) + C1 = linop.Identity(input_shape=x.shape) + self.y = C0(x) + self.g_list = [loss.SquaredL2Loss(y=self.y), functional.L2Norm()] + self.C_list = [C0, C1] + self.with_cconv = with_cconv + yield + + def test_admm(self): + maxiter = 50 + ρ = 1e-1 + rho_list = [ρ, ρ] + admm_lin = ADMM( + f=None, + g_list=self.g_list, + C_list=self.C_list, + rho_list=rho_list, + maxiter=maxiter, + itstat_options={"display": False}, + x0=self.C_list[0].adj(self.y), + subproblem_solver=LinearSubproblemSolver(), + ) + x_lin = admm_lin.solve() + ndims = None if self.with_cconv else 2 + admm_dft = ADMM( + f=None, + g_list=self.g_list, + C_list=self.C_list, + rho_list=rho_list, + maxiter=maxiter, + itstat_options={"display": False}, + x0=self.C_list[0].adj(self.y), + subproblem_solver=CircularConvolveSolver(ndims=ndims), + ) + assert admm_dft.subproblem_solver.A_lhs.ndims == 2 x_dft = admm_dft.solve() np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0) assert metric.mse(x_lin, x_dft) < 1e-9