diff --git a/CHANGES.rst b/CHANGES.rst index 40263c7e7..e9e3e0301 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -9,6 +9,8 @@ Version 0.0.6 (unreleased) • Significant changes to ``linop.xray.astra`` API. • New functional ``functional.IsotropicTVNorm`` and faster implementation of ``functional.AnisotropicTVNorm``. +• New linear operators ``linop.ProjectedGradient``, ``linop.PolarGradient``, + ``linop.CylindricalGradient``, and ``linop.SphericalGradient``. • Rename ``scico.numpy.util.parse_axes`` to ``scico.numpy.util.normalize_axes``. • Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to diff --git a/data b/data index cb97dd02c..3bc228c83 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit cb97dd02cfd96659804c74e7a5f9185440cec3ce +Subproject commit 3bc228c83d14b7bb5c3c64df7f94112c836ca8b7 diff --git a/docs/docs_requirements.txt b/docs/docs_requirements.txt index 1e55847bf..fbd4cafc4 100644 --- a/docs/docs_requirements.txt +++ b/docs/docs_requirements.txt @@ -3,13 +3,13 @@ sphinx>=5.0.0 sphinxcontrib-napoleon sphinxcontrib-bibtex sphinx-autodoc-typehints -furo +furo>=2024.5.6 jinja2<3.1.0 # temporary fix for jinja2/nbconvert bug traitlets!=5.2.2 # temporary fix for ipython/traitlets#741 nbsphinx ipython ipython_genutils py2jn -pygraphviz>=1.7 +pygraphviz>=1.9 pandoc docutils>=0.18 diff --git a/docs/source/conf/15-theme.py b/docs/source/conf/15-theme.py index 35ad6d669..8d20c3199 100644 --- a/docs/source/conf/15-theme.py +++ b/docs/source/conf/15-theme.py @@ -6,7 +6,7 @@ html_theme = "furo" html_theme_options = { - "top_of_page_button": None, + "top_of_page_buttons": [], # "sidebar_hide_name": True, } diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 41a3df9e1..13572f3d7 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -84,6 +84,7 @@ Miscellaneous examples/demosaic_ppp_bm3d_admm examples/superres_ppp_dncnn_admm examples/denoise_l1tv_admm + examples/denoise_ptv_pdhg examples/denoise_tv_admm examples/denoise_tv_apgm examples/denoise_tv_multi @@ -145,6 +146,7 @@ Total Variation examples/deconv_microscopy_tv_admm examples/deconv_microscopy_allchn_tv_admm examples/denoise_l1tv_admm + examples/denoise_ptv_pdhg examples/denoise_tv_admm examples/denoise_tv_apgm examples/denoise_tv_multi @@ -272,6 +274,7 @@ PDHG :maxdepth: 1 examples/ct_svmbir_tv_multi + examples/denoise_ptv_pdhg examples/denoise_tv_multi examples/denoise_cplx_tv_pdhg diff --git a/docs/source/references.bib b/docs/source/references.bib index a150a9b07..257f24287 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -362,6 +362,15 @@ @Book {goodman-2005-fourier edition = 3 } +@Misc {hossein-2024-total, + title = {Total Variation Regularization for Tomographic + Reconstruction of Cylindrically Symmetric Objects}, + author = {Maliha Hossain and Charles A. Bouman and Brendt + Wohlberg}, + year = 2024, + eprint = {2406.17928} +} + @Article {huber-1964-robust, doi = {10.1214/aoms/1177703732}, year = 1964, @@ -776,6 +785,7 @@ @Article {zhang-2021-plug pages = {6360--6376} } + @Article {zhou-2006-adaptive, author = {Bin Zhou and Li Gao and Yu-Hong Dai}, title = {Gradient Methods with Adaptive Step-Sizes}, diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index 8c664c649..481e75ed0 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -101,6 +101,8 @@ Miscellaneous PPP (with DnCNN) Image Superresolution `denoise_l1tv_admm.py `_ ℓ1 Total Variation Denoising + `denoise_ptv_pdhg.py `_ + Polar Total Variation Denoising (PDHG) `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) `denoise_tv_apgm.py `_ @@ -192,6 +194,8 @@ Total Variation Deconvolution Microscopy (All Channels) `denoise_l1tv_admm.py `_ ℓ1 Total Variation Denoising + `denoise_ptv_pdhg.py `_ + Polar Total Variation Denoising (PDHG) `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) `denoise_tv_apgm.py `_ @@ -359,6 +363,8 @@ PDHG `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) + `denoise_ptv_pdhg.py `_ + Polar Total Variation Denoising (PDHG) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising `denoise_cplx_tv_pdhg.py `_ diff --git a/examples/scripts/denoise_ptv_pdhg.py b/examples/scripts/denoise_ptv_pdhg.py new file mode 100644 index 000000000..ec5db49c9 --- /dev/null +++ b/examples/scripts/denoise_ptv_pdhg.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +Polar Total Variation Denoising (PDHG) +====================================== + +This example compares denoising via standard isotropic total +variation (TV) regularization :cite:`rudin-1992-nonlinear` +:cite:`goldstein-2009-split` and a variant based on local polar +coordinates, as described in :cite:`hossein-2024-total`. It solves the +denoising problem + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x} + \|_2^2 + \lambda R(\mathbf{x}) \;,$$ + +where $R$ is either the isotropic or polar TV regularizer, via the +primal–dual hybrid gradient (PDHG) algorithm. +""" + + +from xdesign import SiemensStar, discrete_phantom + +import scico.numpy as snp +import scico.random +from scico import functional, linop, loss, metric, plot +from scico.optimize import PDHG +from scico.util import device_info + +""" +Create a ground truth image. +""" +N = 256 # image size +phantom = SiemensStar(16) +x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) +x_gt = x_gt / x_gt.max() + + +""" +Add noise to create a noisy test image. +""" +σ = 0.75 # noise standard deviation +noise, key = scico.random.randn(x_gt.shape, seed=0) +y = x_gt + σ * noise + + +""" +Denoise with standard isotropic total variation. +""" +λ_std = 0.8e0 +f = loss.SquaredL2Loss(y=y) +g_std = λ_std * functional.L21Norm() + +# The append=0 option makes the results of horizontal and vertical finite +# differences the same shape, which is required for the L21Norm. +C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) +tau, sigma = PDHG.estimate_parameters(C, ratio=20.0) +solver = PDHG( + f=f, + g=g_std, + C=C, + tau=tau, + sigma=sigma, + maxiter=200, + itstat_options={"display": True, "period": 10}, +) +print(f"Solving on {device_info()}\n") +solver.solve() +hist_std = solver.itstat_object.history(transpose=True) +x_std = solver.x +print() + + +""" +Denoise with polar total variation for comparison. +""" +# Tune the weight to give the same data fidelty as the isotropic case. +λ_plr = 1.2e0 +g_plr = λ_plr * functional.L1Norm() + +G = linop.PolarGradient(input_shape=x_gt.shape) +D = linop.Diagonal(snp.blockarray([0.3, 1.0]), input_shape=G.shape[0]) +C = D @ G + +tau, sigma = PDHG.estimate_parameters(C, ratio=20.0) +solver = PDHG( + f=f, + g=g_plr, + C=C, + tau=tau, + sigma=sigma, + maxiter=200, + itstat_options={"display": True, "period": 10}, +) +solver.solve() +hist_plr = solver.itstat_object.history(transpose=True) +x_plr = solver.x +print() + + +""" +Compute and print the data fidelity. +""" +for x, name in zip((x_std, x_plr), ("Isotropic", "Polar")): + df = f(x) + print(f"Data fidelity for {(name + ' TV'):12}: {df:.2e} SNR: {metric.snr(x_gt, x):5.2f} dB") + + +""" +Plot results. +""" +plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5)) +fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10)) +plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args) +plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args) +plot.imview(x_std, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args) +plot.imview(x_plr, title="Polar TV denoising", fig=fig, ax=ax[1, 1], **plt_args) +fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01) +fig.colorbar( + ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units" +) +fig.suptitle("Denoising comparison") +fig.show() + +# zoomed version +fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10)) +plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args) +plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args) +plot.imview(x_std, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args) +plot.imview(x_plr, title="Polar TV denoising", fig=fig, ax=ax[1, 1], **plt_args) +ax[0, 0].set_xlim(N // 4, N // 4 + N // 2) +ax[0, 0].set_ylim(N // 4, N // 4 + N // 2) +fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01) +fig.colorbar( + ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units" +) +fig.suptitle("Denoising comparison (zoomed)") +fig.show() + + +fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(20, 5)) +plot.plot( + snp.vstack((hist_std.Objective, hist_plr.Objective)).T, + ptyp="semilogy", + title="Objective function", + xlbl="Iteration", + lgnd=("Standard", "Polar"), + fig=fig, + ax=ax[0], +) +plot.plot( + snp.vstack((hist_std.Prml_Rsdl, hist_plr.Prml_Rsdl)).T, + ptyp="semilogy", + title="Primal residual", + xlbl="Iteration", + lgnd=("Standard", "Polar"), + fig=fig, + ax=ax[1], +) +plot.plot( + snp.vstack((hist_std.Dual_Rsdl, hist_plr.Dual_Rsdl)).T, + ptyp="semilogy", + title="Dual residual", + xlbl="Iteration", + lgnd=("Standard", "Polar"), + fig=fig, + ax=ax[2], +) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index 8100f06e6..ee4973097 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -62,6 +62,7 @@ Miscellaneous - demosaic_ppp_bm3d_admm.py - superres_ppp_dncnn_admm.py - denoise_l1tv_admm.py + - denoise_ptv_pdhg.py - denoise_tv_admm.py - denoise_tv_apgm.py - denoise_tv_multi.py @@ -114,6 +115,7 @@ Total Variation - deconv_microscopy_tv_admm.py - deconv_microscopy_allchn_tv_admm.py - denoise_l1tv_admm.py + - denoise_ptv_pdhg.py - denoise_tv_admm.py - denoise_tv_apgm.py - denoise_tv_multi.py @@ -217,6 +219,7 @@ PDHG ^^^^ - ct_svmbir_tv_multi.py + - denoise_ptv_pdhg.py - denoise_tv_multi.py - denoise_cplx_tv_pdhg.py diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index f88422a00..d04104e81 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -15,6 +15,12 @@ from ._diag import Diagonal, Identity, ScaledIdentity from ._diff import FiniteDifference, SingleAxisFiniteDifference from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function +from ._grad import ( + CylindricalGradient, + PolarGradient, + ProjectedGradient, + SphericalGradient, +) from ._linop import ComposedLinearOperator, LinearOperator from ._matrix import MatrixOperator from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes @@ -27,6 +33,10 @@ "DFT", "Diagonal", "FiniteDifference", + "ProjectedGradient", + "PolarGradient", + "CylindricalGradient", + "SphericalGradient", "SingleAxisFiniteDifference", "Identity", "DiagonalReplicated", diff --git a/scico/linop/_grad.py b/scico/linop/_grad.py new file mode 100644 index 000000000..416cc125d --- /dev/null +++ b/scico/linop/_grad.py @@ -0,0 +1,526 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2021-2024 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Non-Cartesian gradient linear operators.""" + + +# Needed to annotate a class method that returns the encapsulating class +# see https://www.python.org/dev/peps/pep-0563/ +from __future__ import annotations + +from typing import Optional, Sequence, Tuple, Union + +import numpy as np + +import scico.numpy as snp +from scico.numpy import Array, BlockArray +from scico.typing import BlockShape, DType, Shape + +from ._linop import LinearOperator + + +def diffstack(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + """Compute the discrete difference along multiple axes. + + Apply :func:`snp.diff` along multiple axes, stacking the results on + a newly inserted axis at index 0. The `append` parameter of + :func:`snp.diff` is exploited to give output of the same length as + the input, which is achieved by zero-padding the output at the end + of each axis. + + + """ + if axis is None: + axis = tuple(range(x.ndim)) + elif isinstance(axis, int): + axis = (axis,) + dstack = [ + snp.diff( + x, + axis=ax, + append=x[tuple(slice(-1, None) if i == ax else slice(None) for i in range(x.ndim))], + ) + for ax in axis + ] + return snp.stack(dstack) + + +class ProjectedGradient(LinearOperator): + """Gradient projected onto local coordinate system. + + This class represents a linear operator that computes gradients of + arrays projected onto a local coordinate system that may differ at + every position in the array, as described in + :cite:`hossein-2024-total`. In the 2D illustration below :math:`x` + and :math:`y` represent the standard coordinate system defined by the + array axes, :math:`(g_x, g_y)` is the gradient vector within that + coordinate system, :math:`x'` and :math:`y'` are the local coordinate + axes, and :math:`(g_x', g_y')` is the gradient vector within the + local coordinate system. + + .. image:: /figures/projgrad.svg + :align: center + :alt: Figure illustrating projection of gradient onto local + coordinate system. + + Each of the local coordinate axes (e.g. :math:`x'` and :math:`y'` in + the illustration above) is represented by a separate array in the + `coord` tuple of arrays parameter of the class initializer. + + .. note:: + + This operator should not be confused with the Projected Gradient + optimization algorithm (a special case of Proximal Gradient), with + which it is unrelated. + """ + + def __init__( + self, + input_shape: Shape, + axes: Optional[Tuple[int, ...]] = None, + coord: Optional[Sequence[Union[Array, BlockArray]]] = None, + cdiff: bool = False, + input_dtype: DType = np.float32, + jit: bool = True, + ): + r""" + Args: + input_shape: Shape of input array. + axes: Axes over which to compute the gradient. Defaults to + ``None``, in which case the gradient is computed along + all axes. + coord: A tuple of arrays, each of which specifies a local + coordinate axis direction. Each member of the tuple + should either be a :class:`jax.Array` or a + :class:`.BlockArray`. If it is the former, it should have + shape :math:`N \times M_0 \times M_1 \times \ldots`, + where :math:`N` is the number of axes specified by + parameter `axes`, and :math:`M_i` is the size of the + :math:`i^{\mrm{th}}` axis. If it is the latter, it should + consist of :math:`N` blocks, each of which has a shape + that is suitable for multiplication with an array of + shape :math:`M_0 \times M_1 \times \ldots`. If `coord` is + a singleton tuple, the result of applying the operator is + a :class:`jax.Array`; otherwise it consists of the + gradients for each of the local coordinate axes in + `coord` stacked into a :class:`.BlockArray`. If `coord` + is ``None``, which is the default, gradients are computed + in the standard axis-aligned coordinate system, and the + return type depends on the number of axes on which the + gradient is calculated, as specified explicitly or + implicitly via the `axes` parameter. + cdiff: If ``True``, estimate gradients using the second order + central different returned by :func:`snp.gradient`, + otherwise use the first order asymmetric difference + returned by :func:`snp.diff`. + input_dtype: `dtype` for input argument. Default is + :attr:`~numpy.float32`. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. + """ + if axes is None: + # If axes is None, set it to all axes in input shape. + self.axes = tuple(range(len(input_shape))) + else: + # Ensure no invalid axis indices specified. + if snp.any(np.array(axes) >= len(input_shape)): + raise ValueError( + "Invalid axes specified; all elements of `axes` must be less than " + f"len(input_shape)={len(input_shape)}." + ) + self.axes = axes + output_shape: Union[Shape, BlockShape] + if coord is None: + # If coord is None, output shape is determined by number of axes. + if len(self.axes) == 1: + output_shape = input_shape + else: + output_shape = (input_shape,) * len(self.axes) + else: + # If coord is not None, output shape is determined by number of coord arrays. + if len(coord) == 1: + output_shape = input_shape + else: + output_shape = (input_shape,) * len(coord) + self.coord = coord + self.cdiff = cdiff + super().__init__( + input_shape=input_shape, + output_shape=output_shape, + input_dtype=input_dtype, + output_dtype=input_dtype, + jit=jit, + ) + + def _eval(self, x: Array) -> Union[Array, BlockArray]: + + if self.cdiff: + grad = snp.gradient(x, axis=self.axes) + else: + grad = diffstack(x, axis=self.axes) + if self.coord is None: + # If coord attribute is None, just return gradients on specified axes. + if len(self.axes) == 1: + return grad + else: + return snp.blockarray(grad) + else: + # If coord attribute is not None, return gradients projected onto specified local + # coordinate systems. + projgrad = [sum([c[m] * grad[m] for m in range(len(grad))]) for c in self.coord] + if len(self.coord) == 1: + return projgrad[0] + else: + return snp.blockarray(projgrad) + + +class PolarGradient(ProjectedGradient): + """Gradient projected into polar coordinates. + + Compute gradients projected onto angular and/or radial axis + directions, as described in :cite:`hossein-2024-total`. Local + coordinate axes are illustrated in the figure below. + + .. plot:: figures/polargrad.py + :align: center + :include-source: False + :show-source-link: False + + | + + If only one of `angular` and `radial` is ``True``, the operator + output is a :class:`jax.Array`, otherwise it is a + :class:`.BlockArray`. + """ + + def __init__( + self, + input_shape: Shape, + axes: Optional[Tuple[int, ...]] = None, + center: Optional[Union[Tuple[int, ...], Array]] = None, + angular: bool = True, + radial: bool = True, + cdiff: bool = False, + input_dtype: DType = np.float32, + jit: bool = True, + ): + r""" + Args: + input_shape: Shape of input array. + axes: Axes over which to compute the gradient. Should be a + tuple :math:`(i_x, i_y)`, where :math:`i_x` and + :math:`i_y` are input array axes assigned to :math:`x` + and :math:`y` coordinates respectively. Defaults to + ``None``, in which case the axes are taken to be `(0, 1)`. + center: Center of the polar coordinate system in array + indexing coordinates. Default is ``None``, which places + the center at the center of the input array. + angular: Flag indicating whether to compute gradients in the + angular (i.e. tangent to circles) direction. + radial: Flag indicating whether to compute gradients in the + radial (i.e. directed outwards from the origin) direction. + cdiff: If ``True``, estimate gradients using the second order + central different returned by :func:`snp.gradient`, + otherwise use the first order asymmetric difference + returned by :func:`snp.diff`. + input_dtype: `dtype` for input argument. Default is + :attr:`~numpy.float32`. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. + """ + + if len(input_shape) < 2: + raise ValueError("Invalid input shape; input must have at least two axes.") + if axes is not None and len(axes) != 2: + raise ValueError("Invalid axes specified; exactly two axes must be specified.") + if not angular and not radial: + raise ValueError("At least one of angular and radial must be True.") + + real_input_dtype = snp.util.real_dtype(input_dtype) + if axes is None: + axes = (0, 1) + axes_shape = [input_shape[ax] for ax in axes] + if center is None: + center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 + else: + if isinstance(center, (tuple, list)): + center = snp.array(center) + center = center.astype(real_input_dtype) + end = snp.array(axes_shape, dtype=real_input_dtype) - center + g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]] + theta = snp.arctan2(g0, g1) + # Re-order theta axes in case indices in axes parameter are not in increasing order. + axis_order = np.argsort(axes) + theta = snp.transpose(theta, axis_order) + if len(input_shape) > 2: + # Construct list of input axes that are not included in the gradient axes. + single = tuple(set(range(len(input_shape))) - set(axes)) + # Insert singleton axes to align theta for multiplication with gradients. + theta = snp.expand_dims(theta, single) + coord = [] + if angular: + coord.append(snp.blockarray([-snp.cos(theta), snp.sin(theta)])) + if radial: + coord.append(snp.blockarray([snp.sin(theta), snp.cos(theta)])) + super().__init__( + input_shape=input_shape, + input_dtype=input_dtype, + axes=axes, + coord=coord, + cdiff=cdiff, + jit=jit, + ) + + +class CylindricalGradient(ProjectedGradient): + """Gradient projected into cylindrical coordinates. + + Compute gradients projected onto cylindrical coordinate axes, as + described in :cite:`hossein-2024-total`. The local coordinate axes + are illustrated in the figure below. + + .. plot:: figures/cylindgrad.py + :align: center + :include-source: False + :show-source-link: False + + | + + If only one of `angular`, `radial`, and `axial` is ``True``, the + operator output is a :class:`jax.Array`, otherwise it is a + :class:`.BlockArray`. + """ + + def __init__( + self, + input_shape: Shape, + axes: Optional[Tuple[int, ...]] = None, + center: Optional[Union[Tuple[int, ...], Array]] = None, + angular: bool = True, + radial: bool = True, + axial: bool = True, + cdiff: bool = False, + input_dtype: DType = np.float32, + jit: bool = True, + ): + r""" + Args: + input_shape: Shape of input array. + axes: Axes over which to compute the gradient. Should be a + tuple :math:`(i_x, i_y, i_z)`, where :math:`i_x`, + :math:`i_y` and :math:`i_z` are input array axes assigned + to :math:`x`, :math:`y`, and :math:`z` coordinates + respectively. Defaults to ``None``, in which case the + axes are taken to be `(0, 1, 2)`. If an integer, this + operator returns a :class:`jax.Array`. If a tuple or + ``None``, the resulting arrays are stacked into a + :class:`.BlockArray`. + center: Center of the cylindrical coordinate system in array + indexing coordinates. Default is ``None``, which places + the center at the center of the two polar axes of the + input array and at the zero index of the axial axis. + angular: Flag indicating whether to compute gradients in the + angular (i.e. tangent to circles) direction. + radial: Flag indicating whether to compute gradients in the + radial (i.e. directed outwards from the origin) direction. + axial: Flag indicating whether to compute gradients in the + direction of the axis of the cylinder. + cdiff: If ``True``, estimate gradients using the second order + central different returned by :func:`snp.gradient`, + otherwise use the first order asymmetric difference + returned by :func:`snp.diff`. + input_dtype: `dtype` for input argument. Default is + :attr:`~numpy.float32`. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. + """ + + if len(input_shape) < 3: + raise ValueError("Invalid input shape; input must have at least three axes.") + if axes is not None and len(axes) != 3: + raise ValueError("Invalid axes specified; exactly three axes must be specified.") + if not angular and not radial and not axial: + raise ValueError("At least one of angular, radial, and axial must be True.") + + real_input_dtype = snp.util.real_dtype(input_dtype) + if axes is None: + axes = (0, 1, 2) + axes_shape = [input_shape[ax] for ax in axes] + if center is None: + center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 + center = center.at[-1].set(0) # type: ignore + else: + if isinstance(center, (tuple, list)): + center = snp.array(center) + center = center.astype(real_input_dtype) + end = snp.array(axes_shape, dtype=real_input_dtype) - center + g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]] + g0 = g0[..., np.newaxis] + g1 = g1[..., np.newaxis] + theta = snp.arctan2(g0, g1) + # Re-order theta axes in case indices in axes parameter are not in increasing order. + axis_order = np.argsort(axes) + theta = snp.transpose(theta, axis_order) + if len(input_shape) > 3: + # Construct list of input axes that are not included in the gradient axes. + single = tuple(set(range(len(input_shape))) - set(axes)) + # Insert singleton axes to align theta for multiplication with gradients. + theta = snp.expand_dims(theta, single) + coord = [] + if angular: + coord.append( + snp.blockarray( + [-snp.cos(theta), snp.sin(theta), snp.array([0.0], dtype=real_input_dtype)] + ) + ) + if radial: + coord.append( + snp.blockarray( + [snp.sin(theta), snp.cos(theta), snp.array([0.0], dtype=real_input_dtype)] + ) + ) + if axial: + coord.append( + snp.blockarray( + [ + snp.array([0.0], dtype=real_input_dtype), + snp.array([0.0], dtype=real_input_dtype), + snp.array([1.0], dtype=real_input_dtype), + ] + ) + ) + super().__init__( + input_shape=input_shape, + input_dtype=input_dtype, + axes=axes, + cdiff=cdiff, + coord=coord, + jit=jit, + ) + + +class SphericalGradient(ProjectedGradient): + """Gradient projected into spherical coordinates. + + Compute gradients projected onto spherical coordinate axes, based on + the approach described in :cite:`hossein-2024-total`. The local + coordinate axes are illustrated in the figure below. + + .. plot:: figures/spheregrad.py + :align: center + :include-source: False + :show-source-link: False + + | + + If only one of `azimuthal`, `polar`, and `radial` is ``True``, the + operator output is a :class:`jax.Array`, otherwise it is a + :class:`.BlockArray`. + """ + + def __init__( + self, + input_shape: Shape, + axes: Optional[Tuple[int, ...]] = None, + center: Optional[Union[Tuple[int, ...], Array]] = None, + azimuthal: bool = True, + polar: bool = True, + radial: bool = True, + cdiff: bool = False, + input_dtype: DType = np.float32, + jit: bool = True, + ): + r""" + Args: + input_shape: Shape of input array. + axes: Axes over which to compute the gradient. Should be a + tuple :math:`(i_x, i_y, i_z)`, where :math:`i_x`, + :math:`i_y` and :math:`i_z` are input array axes assigned + to :math:`x`, :math:`y`, and :math:`z` coordinates + respectively. Defaults to ``None``, in which case the + axes are taken to be `(0, 1, 2)`. If an integer, this + operator returns a :class:`jax.Array`. If a tuple or + ``None``, the resulting arrays are stacked into a + :class:`.BlockArray`. + center: Center of the spherical coordinate system in array + indexing coordinates. Default is ``None``, which places + the center at the center of the input array. + azimuthal: Flag indicating whether to compute gradients in + the azimuthal direction. + polar: Flag indicating whether to compute gradients in the + polar direction. + radial: Flag indicating whether to compute gradients in the + radial direction. + cdiff: If ``True``, estimate gradients using the second order + central different returned by :func:`snp.gradient`, + otherwise use the first order asymmetric difference + returned by :func:`snp.diff`. + input_dtype: `dtype` for input argument. Default is + :attr:`~numpy.float32`. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. + """ + + if len(input_shape) < 3: + raise ValueError("Invalid input shape; input must have at least three axes.") + if axes is not None and len(axes) != 3: + raise ValueError("Invalid axes specified; exactly three axes must be specified.") + if not azimuthal and not polar and not radial: + raise ValueError("At least one of azimuthal, polar, and radial must be True.") + + real_input_dtype = snp.util.real_dtype(input_dtype) + if axes is None: + axes = (0, 1, 2) + axes_shape = [input_shape[ax] for ax in axes] + if center is None: + center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 + else: + if isinstance(center, (tuple, list)): + center = snp.array(center) + center = center.astype(real_input_dtype) + end = snp.array(axes_shape, dtype=real_input_dtype) - center + g0, g1, g2 = snp.ogrid[-center[0] : end[0], -center[1] : end[1], -center[2] : end[2]] + theta = snp.arctan2(g1, g0) + phi = snp.arctan2(snp.sqrt(g0**2 + g1**2), g2) + # Re-order theta and phi axes in case indices in axes parameter are not in + # increasing order. + axis_order = np.argsort(axes) + theta = snp.transpose(theta, axis_order) + phi = snp.transpose(phi, axis_order) + if len(input_shape) > 3: + # Construct list of input axes that are not included in the gradient axes. + single = tuple(set(range(len(input_shape))) - set(axes)) + # Insert singleton axes to align theta for multiplication with gradients. + theta = snp.expand_dims(theta, single) + phi = snp.expand_dims(phi, single) + coord = [] + if azimuthal: + coord.append( + snp.blockarray( + [snp.sin(theta), -snp.cos(theta), snp.array([0.0], dtype=real_input_dtype)] + ) + ) + if polar: + coord.append( + snp.blockarray( + [snp.cos(phi) * snp.cos(theta), snp.cos(phi) * snp.sin(theta), -snp.sin(phi)] + ) + ) + if radial: + coord.append( + snp.blockarray( + [snp.sin(phi) * snp.cos(theta), snp.sin(phi) * snp.sin(theta), snp.cos(phi)] + ) + ) + super().__init__( + input_shape=input_shape, + input_dtype=input_dtype, + axes=axes, + coord=coord, + cdiff=cdiff, + jit=jit, + ) diff --git a/scico/test/linop/test_grad.py b/scico/test/linop/test_grad.py new file mode 100644 index 000000000..19687519c --- /dev/null +++ b/scico/test/linop/test_grad.py @@ -0,0 +1,218 @@ +from itertools import combinations + +import numpy as np + +import jax + +import pytest + +import scico.numpy as snp +from scico.linop import CylindricalGradient, PolarGradient, SphericalGradient +from scico.numpy import Array, BlockArray +from scico.random import randn + + +class TestPolarGradient: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("jit", [True, False]) + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize("outflags", [(True, True), (True, False), (False, True)]) + @pytest.mark.parametrize("center", [None, (-2, 3), (1.2, -3.5)]) + @pytest.mark.parametrize( + "shape_axes", + [ + ((20, 20), None), + ((20, 21), (0, 1)), + ((16, 17, 3), (0, 1)), + ((2, 17, 16), (1, 2)), + ((2, 17, 16, 3), (2, 1)), + ], + ) + @pytest.mark.parametrize("cdiff", [True, False]) + def test_eval(self, cdiff, shape_axes, center, outflags, input_dtype, jit): + + input_shape, axes = shape_axes + if axes is None: + testaxes = (0, 1) + else: + testaxes = axes + if center is not None: + axes_shape = [input_shape[ax] for ax in testaxes] + center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) + angular, radial = outflags + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + A = PolarGradient( + input_shape, + axes=axes, + center=center, + angular=angular, + radial=radial, + cdiff=cdiff, + input_dtype=input_dtype, + jit=jit, + ) + Ax = A @ x + if angular and radial: + assert isinstance(Ax, BlockArray) + assert len(Ax.shape) == 2 + assert Ax[0].shape == input_shape + assert Ax[1].shape == input_shape + else: + assert isinstance(Ax, Array) + assert Ax.shape == input_shape + assert Ax.dtype == input_dtype + + # Test orthogonality of coordinate axes + coord = A.coord + for n0, n1 in combinations(range(len(coord)), 2): + c0 = coord[n0] + c1 = coord[n1] + assert snp.abs(snp.sum(c0 * c1)) < 1e-5 + + +class TestCylindricalGradient: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("jit", [True, False]) + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize( + "outflags", + [ + (True, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, True, True), + (False, True, False), + (False, False, True), + ], + ) + @pytest.mark.parametrize("center", [None, (-2, 3, 0), (1.2, -3.5, 1.5)]) + @pytest.mark.parametrize( + "shape_axes", + [ + ((20, 20, 20), None), + ((17, 18, 19), (0, 1, 2)), + ((16, 17, 18, 3), (0, 1, 2)), + ((2, 17, 16, 15), (1, 2, 3)), + ((17, 2, 16, 15), (0, 2, 3)), + ((17, 2, 16, 15), (3, 2, 0)), + ], + ) + def test_eval(self, shape_axes, center, outflags, input_dtype, jit): + + input_shape, axes = shape_axes + if axes is None: + testaxes = (0, 1, 2) + else: + testaxes = axes + if center is not None: + axes_shape = [input_shape[ax] for ax in testaxes] + center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) + angular, radial, axial = outflags + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + A = CylindricalGradient( + input_shape, + axes=axes, + center=center, + angular=angular, + radial=radial, + axial=axial, + input_dtype=input_dtype, + jit=jit, + ) + Ax = A @ x + Nc = sum([angular, radial, axial]) + if Nc > 1: + assert isinstance(Ax, BlockArray) + assert len(Ax) == Nc + for n in range(Nc): + assert Ax[n].shape == input_shape + else: + assert isinstance(Ax, Array) + assert Ax.shape == input_shape + assert Ax.dtype == input_dtype + + # Test orthogonality of coordinate axes + coord = A.coord + for n0, n1 in combinations(range(len(coord)), 2): + c0 = coord[n0] + c1 = coord[n1] + s = sum([c0[m] * c1[m] for m in range(len(c0))]).sum() + assert snp.abs(s) < 1e-5 + + +class TestSphericalGradient: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("jit", [True, False]) + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize( + "outflags", + [ + (True, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, True, True), + (False, True, False), + (False, False, True), + ], + ) + @pytest.mark.parametrize("center", [None, (-2, 3, 0), (1.2, -3.5, 1.5)]) + @pytest.mark.parametrize( + "shape_axes", + [ + ((20, 20, 20), None), + ((17, 18, 19), (0, 1, 2)), + ((16, 17, 18, 3), (0, 1, 2)), + ((2, 17, 16, 15), (1, 2, 3)), + ((17, 2, 16, 15), (0, 2, 3)), + ((17, 2, 16, 15), (3, 2, 0)), + ], + ) + def test_eval(self, shape_axes, center, outflags, input_dtype, jit): + + input_shape, axes = shape_axes + if axes is None: + testaxes = (0, 1, 2) + else: + testaxes = axes + if center is not None: + axes_shape = [input_shape[ax] for ax in testaxes] + center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) + azimuthal, polar, radial = outflags + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + A = SphericalGradient( + input_shape, + axes=axes, + center=center, + azimuthal=azimuthal, + polar=polar, + radial=radial, + input_dtype=input_dtype, + jit=jit, + ) + Ax = A @ x + Nc = sum([azimuthal, polar, radial]) + if Nc > 1: + assert isinstance(Ax, BlockArray) + assert len(Ax) == Nc + for n in range(Nc): + assert Ax[n].shape == input_shape + else: + assert isinstance(Ax, Array) + assert Ax.shape == input_shape + assert Ax.dtype == input_dtype + + # Test orthogonality of coordinate axes + coord = A.coord + for n0, n1 in combinations(range(len(coord)), 2): + c0 = coord[n0] + c1 = coord[n1] + s = sum([c0[m] * c1[m] for m in range(len(c0))]).sum() + assert snp.abs(s) < 1e-5