From c06bdb53b3ce6d138f97ea28623bb24f154f040e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 25 Jun 2024 13:20:01 -0600 Subject: [PATCH] Add diagonal operator mapping base operator over an array axis (#521) * Add diagonal operators constructed by replication of a base operator * Add documentation * Improve error checking * Fix allowed range for input_axis and allow negative values * Rename class aliases as per PR review comments --- scico/linop/__init__.py | 3 +- scico/linop/_stack.py | 86 ++++++++++++++++++++- scico/operator/__init__.py | 5 +- scico/operator/_stack.py | 108 ++++++++++++++++++++++++++- scico/test/linop/test_linop_stack.py | 23 +++++- scico/test/operator/test_op_stack.py | 54 +++++++++++++- 6 files changed, 269 insertions(+), 10 deletions(-) diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index 2e91e0096..f88422a00 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -17,7 +17,7 @@ from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function from ._linop import ComposedLinearOperator, LinearOperator from ._matrix import MatrixOperator -from ._stack import DiagonalStack, VerticalStack, linop_over_axes +from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes from ._util import jacobian, operator_norm, power_iteration, valid_adjoint from .xray import Parallel2dProjector, XRayTransform @@ -29,6 +29,7 @@ "FiniteDifference", "SingleAxisFiniteDifference", "Identity", + "DiagonalReplicated", "VerticalStack", "DiagonalStack", "MatrixOperator", diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 514c208b3..1f443073d 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -14,14 +14,15 @@ import scico.numpy as snp from scico.numpy import Array, BlockArray from scico.numpy.util import normalize_axes -from scico.operator._stack import DiagonalStack as DStack -from scico.operator._stack import VerticalStack as VStack +from scico.operator._stack import DiagonalReplicated as DiagonalReplicatedOperator +from scico.operator._stack import DiagonalStack as DiagonalStackOperator +from scico.operator._stack import VerticalStack as VerticalStackOperator from scico.typing import Axes, Shape from ._linop import LinearOperator -class VerticalStack(VStack, LinearOperator): +class VerticalStack(VerticalStackOperator, LinearOperator): r"""A vertical stack of linear operators. Given linear operators :math:`A_1, A_2, \dots, A_N`, create the @@ -71,7 +72,7 @@ def _adj(self, y: Union[Array, BlockArray]) -> Array: # type: ignore return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) # type: ignore -class DiagonalStack(DStack, LinearOperator): +class DiagonalStack(DiagonalStackOperator, LinearOperator): r"""A diagonal stack of linear operators. Given linear operators :math:`A_1, A_2, \dots, A_N`, create the @@ -146,6 +147,83 @@ def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type return snp.blockarray(result) +class DiagonalReplicated(DiagonalReplicatedOperator, LinearOperator): + r"""A diagonal stack constructed from a single linear operator. + + Given linear operator :math:`A`, create the linear operator + + .. math:: + H = + \begin{pmatrix} + A & 0 & \ldots & 0\\ + 0 & A & \ldots & 0\\ + \vdots & \vdots & \ddots & \vdots\\ + 0 & 0 & \ldots & A \\ + \end{pmatrix} \qquad + \text{such that} \qquad + H + \begin{pmatrix} + \mb{x}_1 \\ + \mb{x}_2 \\ + \vdots \\ + \mb{x}_N \\ + \end{pmatrix} + = + \begin{pmatrix} + A(\mb{x}_1) \\ + A(\mb{x}_2) \\ + \vdots \\ + A(\mb{x}_N) \\ + \end{pmatrix} \;. + + The application of :math:`A` to each component :math:`\mb{x}_k` is + computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape + for linear operator :math:`A` should exclude the array axis on which + :math:`A` is replicated to form :math:`H`. For example, if :math:`A` + has input shape `(3, 4)` and :math:`H` is constructed to replicate + on axis 0 with 2 replicates, the input shape of :math:`H` will be + `(2, 3, 4)`. + + Linear operators taking :class:`.BlockArray` input are not supported. + """ + + def __init__( + self, + op: LinearOperator, + replicates: int, + input_axis: int = 0, + output_axis: Optional[int] = None, + map_type: str = "auto", + **kwargs, + ): + """ + Args: + op: Linear operator to replicate. + replicates: Number of replicates of `op`. + input_axis: Input axis over which `op` should be replicated. + output_axis: Index of replication axis in output array. + If ``None``, the input replication axis is used. + map_type: If "pmap" or "vmap", apply replicated mapping using + :func:`jax.pmap` or :func:`jax.vmap` respectively. If + "auto", use :func:`jax.pmap` if sufficient devices are + available for the number of replicates, otherwise use + :func:`jax.vmap`. + """ + if not isinstance(op, LinearOperator): + raise TypeError("Argument op must be of type LinearOperator.") + + super().__init__( + op, + replicates, + input_axis=input_axis, + output_axis=output_axis, + map_type=map_type, + **kwargs, + ) + + self._adj = self.jaxmap(op.adj, in_axes=self.input_axis, out_axes=self.output_axis) + + def linop_over_axes( linop: type[LinearOperator], input_shape: Shape, diff --git a/scico/operator/__init__.py b/scico/operator/__init__.py index fee512369..8d3b01928 100644 --- a/scico/operator/__init__.py +++ b/scico/operator/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021-2023 by SCICO Developers +# Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -13,11 +13,12 @@ from ._operator import Operator from .biconvolve import BiConvolve from ._func import operator_from_function, Abs, Angle, Exp -from ._stack import DiagonalStack, VerticalStack +from ._stack import DiagonalStack, VerticalStack, DiagonalReplicated __all__ = [ "Operator", "BiConvolve", + "DiagonalReplicated", "DiagonalStack", "VerticalStack", "operator_from_function", diff --git a/scico/operator/_stack.py b/scico/operator/_stack.py index 9e16f05ea..8d20cd94b 100644 --- a/scico/operator/_stack.py +++ b/scico/operator/_stack.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2023 by SCICO Developers +# Copyright (C) 2023-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -13,6 +13,8 @@ import numpy as np +import jax + from typing_extensions import TypeGuard import scico.numpy as snp @@ -234,3 +236,107 @@ def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: if self.collapse_output: return snp.stack(result) return snp.blockarray(result) + + +class DiagonalReplicated(Operator): + r"""A diagonal stack constructed from a single operator. + + Given operator :math:`A`, create the operator :math:`H` such that + + .. math:: + H \left( + \begin{pmatrix} + \mb{x}_1 \\ + \mb{x}_2 \\ + \vdots \\ + \mb{x}_N \\ + \end{pmatrix} \right) + = + \begin{pmatrix} + A(\mb{x}_1) \\ + A(\mb{x}_2) \\ + \vdots \\ + A(\mb{x}_N) \\ + \end{pmatrix} \;. + + The application of :math:`A` to each component :math:`\mb{x}_k` is + computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape + for operator :math:`A` should exclude the array axis on which + :math:`A` is replicated to form :math:`H`. For example, if :math:`A` + has input shape `(3, 4)` and :math:`H` is constructed to replicate + on axis 0 with 2 replicates, the input shape of :math:`H` will be + `(2, 3, 4)`. + + Operators taking :class:`.BlockArray` input are not supported. + """ + + def __init__( + self, + op: Operator, + replicates: int, + input_axis: int = 0, + output_axis: Optional[int] = None, + map_type: str = "auto", + **kwargs, + ): + """ + Args: + op: Operator to replicate. + replicates: Number of replicates of `op`. + input_axis: Input axis over which `op` should be replicated. + output_axis: Index of replication axis in output array. + If ``None``, the input replication axis is used. + map_type: If "pmap" or "vmap", apply replicated mapping using + :func:`jax.pmap` or :func:`jax.vmap` respectively. If + "auto", use :func:`jax.pmap` if sufficient devices are + available for the number of replicates, otherwise use + :func:`jax.vmap`. + """ + if map_type not in ["auto", "pmap", "vmap"]: + raise ValueError("Argument map_type must be one of 'auto', 'pmap, or 'vmap'.") + if input_axis < 0: + input_axis = len(op.input_shape) + 1 + input_axis + if input_axis < 0 or input_axis > len(op.input_shape): + raise ValueError( + "Argument input_axis must be positive and less than the number of axes " + "in the input shape of op." + ) + if is_nested(op.input_shape): + raise ValueError("Argument op may not be an Operator taking BlockArray input.") + if is_nested(op.output_shape): + raise ValueError("Argument op may not be an Operator with BlockArray output.") + self.op = op + self.replicates = replicates + self.input_axis = input_axis + self.output_axis = self.input_axis if output_axis is None else output_axis + + if map_type == "auto": + self.jaxmap = jax.pmap if replicates <= jax.device_count() else jax.vmap + else: + if map_type == "pmap" and replicates > jax.device_count(): + raise ValueError( + "Requested pmap mapping but number of replicates exceeds device count." + ) + else: + self.jaxmap = jax.pmap if map_type == "pmap" else jax.vmap + + eval_fn = self.jaxmap(op.__call__, in_axes=self.input_axis, out_axes=self.output_axis) + + input_shape = ( + op.input_shape[0 : self.input_axis] + (replicates,) + op.input_shape[self.input_axis :] + ) + output_shape = ( + op.output_shape[0 : self.output_axis] + + (replicates,) + + op.output_shape[self.output_axis :] + ) + + super().__init__( + input_shape=input_shape, # type: ignore + output_shape=output_shape, # type: ignore + eval_fn=eval_fn, + input_dtype=op.input_dtype, + output_dtype=op.output_dtype, + jit=False, + **kwargs, + ) diff --git a/scico/test/linop/test_linop_stack.py b/scico/test/linop/test_linop_stack.py index 37f77d4c8..0a2589d85 100644 --- a/scico/test/linop/test_linop_stack.py +++ b/scico/test/linop/test_linop_stack.py @@ -5,8 +5,16 @@ import pytest import scico.numpy as snp -from scico.linop import Convolve, DiagonalStack, Identity, Sum, VerticalStack +from scico.linop import ( + Convolve, + DiagonalReplicated, + DiagonalStack, + Identity, + Sum, + VerticalStack, +) from scico.operator import Abs +from scico.random import randn from scico.test.linop.test_linop import adjoint_test @@ -166,3 +174,16 @@ def test_output_collapse(self): H = DiagonalStack((A1, A2), collapse_output=False) assert H.output_shape == (S1, S1) + + +class TestDiagonalReplicated: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + def test_adjoint(self): + x, key = randn((2, 3, 4), key=self.key) + A = Sum(x.shape[1:], axis=-1) + D = DiagonalReplicated(A, x.shape[0]) + y = D.T(D(x)) + np.testing.assert_allclose(y[0], A.T(A(x[0]))) + np.testing.assert_allclose(y[1], A.T(A(x[1]))) diff --git a/scico/test/operator/test_op_stack.py b/scico/test/operator/test_op_stack.py index c981cdf26..258fa83f0 100644 --- a/scico/test/operator/test_op_stack.py +++ b/scico/test/operator/test_op_stack.py @@ -5,7 +5,14 @@ import pytest import scico.numpy as snp -from scico.operator import Abs, DiagonalStack, Operator, VerticalStack +from scico.operator import ( + Abs, + DiagonalReplicated, + DiagonalStack, + Operator, + VerticalStack, +) +from scico.random import randn TestOpA = Operator(input_shape=(3, 4), output_shape=(2, 3, 4), eval_fn=lambda x: snp.stack((x, x))) TestOpB = Operator( @@ -140,3 +147,48 @@ def test_output_collapse(self): H = DiagonalStack((A1, A2), collapse_output=False) assert H.output_shape == (A1.output_shape, A1.output_shape) + + +class TestDiagonalReplicated: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("map_type", ["auto", "vmap"]) + @pytest.mark.parametrize("input_axis", [0, 1]) + def test_map_auto_vmap(self, input_axis, map_type): + x, key = randn((2, 3, 4), key=self.key) + mapshape = (3, 4) if input_axis == 0 else (2, 4) + replicates = x.shape[input_axis] + A = Abs(mapshape) + D = DiagonalReplicated(A, replicates, input_axis=input_axis, map_type=map_type) + y = D(x) + assert y.shape[input_axis] == replicates + + @pytest.mark.skipif(jax.device_count() < 2, reason="multiple devices required for test") + def test_map_auto_pmap(self): + x, key = randn((2, 3, 4), key=self.key) + A = Abs(x.shape[1:]) + replicates = x.shape[0] + D = DiagonalReplicated(A, replicates, map_type="pmap") + y = D(x) + assert y.shape[0] == replicates + + def test_input_axis(self): + # Ensure that operators can be stacked on final axis + x, key = randn((2, 3, 4), key=self.key) + A = Abs(x.shape[0:2]) + replicates = x.shape[2] + D = DiagonalReplicated(A, replicates, input_axis=2) + y = D(x) + assert y.shape == (2, 3, 4) + D = DiagonalReplicated(A, replicates, input_axis=-1) + y = D(x) + assert y.shape == (2, 3, 4) + + def test_output_axis(self): + x, key = randn((2, 3, 4), key=self.key) + A = Abs(x.shape[1:]) + replicates = x.shape[0] + D = DiagonalReplicated(A, replicates, output_axis=1) + y = D(x) + assert y.shape == (3, 2, 4)