From 29602b2dce3d4e95eecb2f5cc25f3a20d854a597 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 5 Oct 2023 06:24:51 -0600 Subject: [PATCH 01/11] Rename ATADSolver to MatrixATADSolver --- CHANGES.rst | 3 ++- scico/optimize/_admmaux.py | 14 ++++++++------ scico/solver.py | 3 +-- scico/test/test_solver.py | 6 +++--- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index ebc467893..33adfae0a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,7 +6,8 @@ SCICO Release Notes Version 0.0.5 (unreleased) ---------------------------- -• No significant changes yet. +• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.16. diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 7a0ceb0b2..72c232690 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -31,7 +31,7 @@ from scico.loss import SquaredL2Loss from scico.numpy import Array, BlockArray from scico.numpy.util import ensure_on_device, is_real_dtype -from scico.solver import ATADSolver, ConvATADSolver +from scico.solver import ConvATADSolver, MatrixATADSolver from scico.solver import cg as scico_cg from scico.solver import minimize @@ -296,14 +296,14 @@ class MatrixSubproblemSolver(LinearSubproblemSolver): \mb{u}^{(k)}_i) \;, which is solved by factorization of the left hand side of the - equation, using :class:`.ATADSolver`. + equation, using :class:`.MatrixATADSolver`. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. solve_kwargs (dict): Dictionary of arguments for solver - :class:`.ATADSolver` initialization. + :class:`.MatrixATADSolver` initialization. """ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None): @@ -313,7 +313,7 @@ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, A check_solve: If ``True``, compute solver accuracy after each solve. solve_kwargs: Dictionary of arguments for solver - :class:`.ATADSolver` initialization. + :class:`.MatrixATADSolver` initialization. """ self.check_solve = check_solve default_solve_kwargs = {"cho_factor": False} @@ -352,7 +352,7 @@ def internal_init(self, admm: soa.ADMM): Csum = reduce( lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)] ) - self.solver = ATADSolver(A, Csum, W, **self.solve_kwargs) + self.solver = MatrixATADSolver(A, Csum, W, **self.solve_kwargs) def solve(self, x0: Array) -> Array: """Solve the ADMM step. @@ -775,7 +775,9 @@ def compute_rhs(self) -> Union[Array, BlockArray]: C0 = self.admm.C_list[0] rhs = snp.zeros(C0.input_shape, C0.input_dtype) omega = self.admm.g_list[0].scale - omega_list = [2.0 * omega,] + [ + omega_list = [ + 2.0 * omega, + ] + [ 1.0, ] * (len(self.admm.C_list) - 1) for omegai, rhoi, Ci, zi, ui in zip( diff --git a/scico/solver.py b/scico/solver.py index 5f0994246..e2f930cf0 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -267,7 +267,6 @@ def minimize_scalar( tol: Optional[float] = None, options: Optional[dict] = None, ) -> spopt.OptimizeResult: - """Minimization of scalar function of one variable. Wrapper around :func:`scipy.optimize.minimize_scalar`. @@ -579,7 +578,7 @@ def golden( return r -class ATADSolver: +class MatrixATADSolver: r"""Solver for linear system involving a symmetric product plus a diagonal. Solve a linear system of the form diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index d5b179b62..f220482df 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -319,7 +319,7 @@ def test_solve_atai(cho_factor, wide, weighted, alpha): D = alpha * snp.ones((A.shape[1],)) ATAD = A.T @ (Wa * A) + alpha * snp.identity(A.shape[1]) b = ATAD @ x0 - slv = solver.ATADSolver(A, D, W=W, cho_factor=cho_factor) + slv = solver.MatrixATADSolver(A, D, W=W, cho_factor=cho_factor) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 @@ -338,7 +338,7 @@ def test_solve_aati(cho_factor, wide, alpha): D = alpha * snp.ones((A.shape[0],)) AATD = A @ A.T + alpha * snp.identity(A.shape[0]) b = AATD @ x0 - slv = solver.ATADSolver(A.T, D) + slv = solver.MatrixATADSolver(A.T, D) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 @@ -365,7 +365,7 @@ def test_solve_atad(cho_factor, wide, vector): D = snp.abs(D) # only required for Cholesky, but improved accuracy for LU ATAD = A.T @ A + snp.diag(D) b = ATAD @ x0 - slv = solver.ATADSolver(A, D, cho_factor=cho_factor) + slv = solver.MatrixATADSolver(A, D, cho_factor=cho_factor) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 assert slv.accuracy(x1, b) < 5e-5 From b3ecc756e4ad21f4b2190adf4e330e9c059cf7dc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 5 Oct 2023 06:39:47 -0600 Subject: [PATCH 02/11] Improve docs --- scico/solver.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/scico/solver.py b/scico/solver.py index e2f930cf0..299cbcb10 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -595,12 +595,18 @@ class MatrixATADSolver: where :math:`A \in \mbb{R}^{M \times N}`, :math:`W \in \mbb{R}^{M \times M}` and - :math:`D \in \mbb{R}^{N \times N}`. The solution is computed by - factorization of matrix :math:`A^T W A + D` and solution via Gaussian - elimination. If :math:`D` is diagonal and :math:`N < M` (i.e. - :math:`A W A^T` is smaller than :math:`A^T W A`), then - :math:`A W A^T + D` is factorized and the original problem is solved - via the Woodbury matrix identity + :math:`D \in \mbb{R}^{N \times N}`. :math:`A` must be an instance of + :class:`.MatrixOperator` or an array; :math:`D` must be an instance + of :class:`.MatrixOperator`, :class:`.Diagonal`, or an array, and + :math:`W`, if specified, must be an instance of :class:`.Diagonal` + or an array. + + + The solution is computed by factorization of matrix + :math:`A^T W A + D` and solution via Gaussian elimination. If + :math:`D` is diagonal and :math:`N < M` (i.e. :math:`A W A^T` is + smaller than :math:`A^T W A`), then :math:`A W A^T + D` is factorized + and the original problem is solved via the Woodbury matrix identity .. math:: From cd4c14bfeb3d8466d34a875bc8cdfba4f19f688a Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 5 Oct 2023 06:51:33 -0600 Subject: [PATCH 03/11] Add error checking on input types --- scico/solver.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/scico/solver.py b/scico/solver.py index 299cbcb10..ab64f8c6c 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -715,14 +715,27 @@ def __init__( """ if isinstance(A, MatrixOperator): A = A.to_array() + elif not isinstance(A, Array): + raise TypeError( + f"Operator A is required to be a MatrixOperator or an array; got a {type(A)}." + ) if isinstance(D, MatrixOperator): D = D.to_array() elif isinstance(D, Diagonal): D = D.diagonal + elif not isinstance(D, Array): + raise TypeError( + "Operator D is required to be a MatrixOperator, Diagonal, or an array; " + f"got a {type(D)}." + ) if W is None: W = snp.ones(A.shape[0], dtype=A.dtype) elif isinstance(W, Diagonal): W = W.diagonal + elif not isinstance(W, Array): + raise TypeError( + f"Operator W is required to be None, a Diagonal, or an array; got a {type(W)}." + ) self.A = A self.D = D self.W = W From 7ecb6ce0b9da2d79cad45a709cd9b119f2e30dde Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 6 Oct 2023 05:32:36 -0600 Subject: [PATCH 04/11] Another jaxlb/jax version bump --- CHANGES.rst | 2 +- requirements.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 33adfae0a..ee55f3316 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,7 +7,7 @@ Version 0.0.5 (unreleased) ---------------------------- • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. -• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.16. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.17. diff --git a/requirements.txt b/requirements.txt index e62fcb7cc..f9cbc4426 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,8 @@ scipy>=1.6.0 tifffile imageio>=2.17 matplotlib -jaxlib>=0.4.3,<=0.4.16 -jax>=0.4.3,<=0.4.16 +jaxlib>=0.4.3,<=0.4.17 +jax>=0.4.3,<=0.4.17 flax>=0.6.1,<=0.6.9 bm3d>=4.0.0 bm4d>=4.2.2 From d0526e7634e0a8cf9431a4d19f57bb46fb23f3fb Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 9 Oct 2023 01:20:30 -0600 Subject: [PATCH 05/11] Fix black formatting --- scico/optimize/_admmaux.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 72c232690..7a1c8710c 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -775,9 +775,7 @@ def compute_rhs(self) -> Union[Array, BlockArray]: C0 = self.admm.C_list[0] rhs = snp.zeros(C0.input_shape, C0.input_dtype) omega = self.admm.g_list[0].scale - omega_list = [ - 2.0 * omega, - ] + [ + omega_list = [2.0 * omega,] + [ 1.0, ] * (len(self.admm.C_list) - 1) for omegai, rhoi, Ci, zi, ui in zip( From b7fb076a929d1941eece526bbeb553da99963af6 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 23 Oct 2023 09:22:48 -0600 Subject: [PATCH 06/11] Fix short docstring --- scico/solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/solver.py b/scico/solver.py index ab64f8c6c..13ae99163 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -814,7 +814,7 @@ def accuracy(self, x: Array, b: Array) -> float: class ConvATADSolver: - r"""Solver for sum of convolutions plus diagonal linear system. + r"""Solver for a linear system involving a sum of convolutions. Solve a linear system of the form From e4a33eb5c5fbbbc0ad872d0f7975f358260dd8fe Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 23 Oct 2023 09:26:16 -0600 Subject: [PATCH 07/11] Fix short docstring --- scico/solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/solver.py b/scico/solver.py index 13ae99163..21d49a0fb 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -579,7 +579,7 @@ def golden( class MatrixATADSolver: - r"""Solver for linear system involving a symmetric product plus a diagonal. + r"""Solver for linear system involving a symmetric product. Solve a linear system of the form From f45258573ba4a5fde0b14ead88cad64c0534cf97 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 24 Oct 2023 13:05:00 -0600 Subject: [PATCH 08/11] Implement __array__ method in MatrixOperator --- scico/linop/_matrix.py | 13 ++++++----- scico/numpy/util.py | 15 +++++++++++++ scico/solver.py | 38 ++++++++++++++++++--------------- scico/test/linop/test_matrix.py | 6 ++++-- 4 files changed, 46 insertions(+), 26 deletions(-) diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index 0f6f21daa..951c6957e 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -17,9 +17,9 @@ import numpy as np -import jax import jax.numpy as jnp from jax.dtypes import result_type +from jax.typing import ArrayLike import scico.numpy as snp @@ -65,7 +65,7 @@ def wrapper(a, b): class MatrixOperator(LinearOperator): """Linear operator implementing matrix multiplication.""" - def __init__(self, A: snp.Array, input_cols: int = 0): + def __init__(self, A: ArrayLike, input_cols: int = 0): """ Args: A: Dense array. The action of the created @@ -80,17 +80,16 @@ def __init__(self, A: snp.Array, input_cols: int = 0): self.A: snp.Array #: Dense array implementing this matrix # if A is an ndarray, make sure it gets converted to a jax array - if isinstance(A, jnp.ndarray): - self.A = A - elif isinstance(A, np.ndarray): - self.A = jax.device_put(A) # TODO: ensure_on_device? - else: + if not snp.util.is_arraylike(A): raise TypeError(f"Expected numpy or jax array, got {type(A)}.") + self.A = jnp.array(A) # Can only do rank-2 arrays if A.ndim != 2: raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}.") + self.__array__ = A.__array__ # enables jnp.array(H) + if input_cols == 0: input_shape = A.shape[1] output_shape = A.shape[0] diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 0e776038b..a9e31a00d 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -204,6 +204,21 @@ def shape_to_size(shape: Union[Shape, BlockShape]) -> int: return prod(shape) +def is_arraylike(x: Any) -> bool: + """Check if input is of type :class:`jax.ArrayLike`. + + `isinstance(x, jax.typing.ArrayLike)` does not work in Python < 3.10, + see https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices. + + Args: + x: Object to be tested. + + Returns: + ``True`` if `x` is an ArrayLike, ``False`` otherwise. + """ + return isinstance(x, (np.ndarray, jax.Array)) or np.isscalar(x) + + def is_nested(x: Any) -> bool: """Check if input is a list/tuple containing at least one list/tuple. diff --git a/scico/solver.py b/scico/solver.py index 21d49a0fb..59f0cf767 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -54,6 +54,7 @@ import jax import jax.experimental.host_callback as hcb +import jax.numpy as jnp import jax.scipy.linalg as jsl import scico.numpy as snp @@ -260,7 +261,7 @@ def fun(x0): def minimize_scalar( func: Callable, - bracket: Optional[Union[Sequence[float]]] = None, + bracket: Optional[Sequence[float]] = None, bounds: Optional[Sequence[float]] = None, args: Union[Tuple, Tuple[Any]] = (), method: str = "brent", @@ -703,8 +704,12 @@ def __init__( r""" Args: A: Matrix :math:`A`. - D: Matrix :math:`D`. - W: Matrix :math:`W`. + D: Matrix :math:`D`. If a 2D array or :class:`MatrixOperator`, + specifies the 2D matrix :math:`D`. If 1D array or + :class:`Diagonal`, specifies the diagonal elements + of :math:`D`. + W: Matrix :math:`W`. Specifies the diagonal elements of + :math:`W`. Defaults to ones. cho_factor: Flag indicating whether to use Cholesky (``True``) or LU (``False``) factorization. lower: Flag indicating whether lower (``True``) or upper @@ -713,29 +718,28 @@ def __init__( check_finite: Flag indicating whether the input array should be checked for ``Inf`` and ``NaN`` values. """ - if isinstance(A, MatrixOperator): - A = A.to_array() - elif not isinstance(A, Array): - raise TypeError( - f"Operator A is required to be a MatrixOperator or an array; got a {type(A)}." - ) - if isinstance(D, MatrixOperator): - D = D.to_array() - elif isinstance(D, Diagonal): + A = jnp.array(A) + + if isinstance(D, Diagonal): D = D.diagonal - elif not isinstance(D, Array): - raise TypeError( - "Operator D is required to be a MatrixOperator, Diagonal, or an array; " - f"got a {type(D)}." - ) + if not D.ndim == 1: + raise ValueError("If Diagonal, D should have a 1D diagonal.") + else: + D = jnp.array(D) + if not D.ndim in [1, 2]: + raise ValueError("If matrix, D should be 1D or 2D.") + if W is None: W = snp.ones(A.shape[0], dtype=A.dtype) elif isinstance(W, Diagonal): W = W.diagonal + if not W.ndim == 1: + raise ValueError("If Diagonal, W should have a 1D diagonal.") elif not isinstance(W, Array): raise TypeError( f"Operator W is required to be None, a Diagonal, or an array; got a {type(W)}." ) + self.A = A self.D = D self.W = W diff --git a/scico/test/linop/test_matrix.py b/scico/test/linop/test_matrix.py index 3f00b2c7a..178c1fce5 100644 --- a/scico/test/linop/test_matrix.py +++ b/scico/test/linop/test_matrix.py @@ -22,7 +22,6 @@ def setup_method(self, method): @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_eval(self, matrix_shape, input_dtype, input_cols): - A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) @@ -38,7 +37,6 @@ def test_eval(self, matrix_shape, input_dtype, input_cols): @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_adjoint(self, matrix_shape, input_dtype, input_cols): - A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) @@ -262,6 +260,10 @@ def test_to_array(self): assert isinstance(A_array, np.ndarray) np.testing.assert_allclose(A_array, A) + A_array = jnp.array(Ao) + assert isinstance(A_array, jax.Array) + np.testing.assert_allclose(A_array, A) + @pytest.mark.parametrize("ord", ["fro", 2]) @pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("keepdims", [True, False]) From fe977874063f30ff93cf6141209520769042e60f Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 25 Oct 2023 18:10:39 -0600 Subject: [PATCH 09/11] Docstring edits --- scico/solver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/solver.py b/scico/solver.py index 59f0cf767..a4c7e5875 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -709,7 +709,7 @@ def __init__( :class:`Diagonal`, specifies the diagonal elements of :math:`D`. W: Matrix :math:`W`. Specifies the diagonal elements of - :math:`W`. Defaults to ones. + :math:`W`. Defaults to an array with unit entries. cho_factor: Flag indicating whether to use Cholesky (``True``) or LU (``False``) factorization. lower: Flag indicating whether lower (``True``) or upper @@ -727,7 +727,7 @@ def __init__( else: D = jnp.array(D) if not D.ndim in [1, 2]: - raise ValueError("If matrix, D should be 1D or 2D.") + raise ValueError("If MatrixOperator, D should be 1D or 2D.") if W is None: W = snp.ones(A.shape[0], dtype=A.dtype) From 5e293c1f852f3c0fed8258e699008b9ff7f81d49 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 25 Oct 2023 18:14:57 -0600 Subject: [PATCH 10/11] Bump max jaxlib/jax versions --- CHANGES.rst | 2 +- requirements.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index ee55f3316..a413f53cb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,7 +7,7 @@ Version 0.0.5 (unreleased) ---------------------------- • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. -• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.17. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19. diff --git a/requirements.txt b/requirements.txt index 7c98f1fb8..68ab1ebac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,8 @@ scipy>=1.6.0 tifffile imageio>=2.17 matplotlib -jaxlib>=0.4.3,<=0.4.17 -jax>=0.4.3,<=0.4.17 +jaxlib>=0.4.3,<=0.4.19 +jax>=0.4.3,<=0.4.19 flax>=0.6.1,<=0.6.9 svmbir>=0.3.3 pyabel>=0.9.0 From f1ab9748e2ec9257bf6a70c6cee34cadea26e1b6 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 26 Oct 2023 14:04:40 -0600 Subject: [PATCH 11/11] Fix error message --- scico/solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/solver.py b/scico/solver.py index a4c7e5875..f93cd710e 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -727,7 +727,7 @@ def __init__( else: D = jnp.array(D) if not D.ndim in [1, 2]: - raise ValueError("If MatrixOperator, D should be 1D or 2D.") + raise ValueError("If array or MatrixOperator, D should be 1D or 2D.") if W is None: W = snp.ones(A.shape[0], dtype=A.dtype)