Skip to content

Commit

Permalink
Added a new Functional for TV Norm (#456)
Browse files Browse the repository at this point in the history
* Added a new Functional for TV Norm implementing its proximal operator using the fast subiteration free algorithm proposed by Kamilov, 2016

* added checks for input shape in TV2DNorm

* fixed lint errors, changed required argument to default in TV2DNorm, fixed inconsistent signature for prox function, added more comments to the helper functions

* some unsaved changes from last commit

* newline at end of file error

* sort imports lint error

* removed the default shape parameter from TV2DNorm

* Some docs edits

* Disable BlockArray tests on TV2DNorm

* Fix black formatting

* updated the TV norm logic to apply shrinkage to only the difference operator of the haar transform as in Kamilov, 2016

* Implementation supporting arbitrary dimensional inputs

* Add a test

* Minor changes

* New implementation of TV norm and approximage prox

* Clean up

* Typo fix

* Minor change

* Add change log entry

* Resolve typing errors

* Resolve some oversights and issues arising when 64 bit floats enabled

* Apply skipped pre-commit

---------

Co-authored-by: Salman Naqvi <[email protected]>
Co-authored-by: Brendt Wohlberg <[email protected]>
  • Loading branch information
3 people authored Nov 5, 2023
1 parent 27e2aec commit 08a5896
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ SCICO Release Notes
Version 0.0.5 (unreleased)
----------------------------

• New functional ``functional.AnisotropicTVNorm`` with proximal operator
approximation.
• New integrated Radon/X-ray transform ``linop.XRayTransform``.
• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
Expand Down
12 changes: 12 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,18 @@ @Article {jin-2017-unet
doi = {10.1109/TIP.2017.2713099}
}

@Article {kamilov-2016-parallel,
title = {A parallel proximal algorithm for anisotropic total
variation minimization},
author = {Ulugbek S. Kamilov},
journal = {IEEE Transactions on Image Processing},
volume = 26,
number = 2,
pages = {539--548},
year = 2016,
doi = {10.1109/tip.2016.2629449 }
}

@Article {kamilov-2017-plugandplay,
author = {Ulugbek Kamilov and Hassan Mansour and Brendt
Wohlberg},
Expand Down
2 changes: 2 additions & 0 deletions scico/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
NuclearNorm,
L1MinusL2Norm,
)
from ._tvnorm import AnisotropicTVNorm
from ._indicator import NonNegativeIndicator, L2BallIndicator
from ._denoiser import BM3D, BM4D, DnCNN
from ._dist import SetDistance, SquaredSetDistance


__all__ = [
"AnisotropicTVNorm",
"Functional",
"ScaledFunctional",
"SeparableFunctional",
Expand Down
142 changes: 142 additions & 0 deletions scico/functional/_tvnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Anisotropic total variation norm."""

from typing import Optional, Tuple

from scico import numpy as snp
from scico.linop import (
CircularConvolve,
FiniteDifference,
LinearOperator,
VerticalStack,
)
from scico.numpy import Array

from ._functional import Functional
from ._norm import L1Norm


class AnisotropicTVNorm(Functional):
r"""The anisotropic total variation (TV) norm.
The anisotropic total variation (TV) norm computed by
.. code-block:: python
ATV = scico.functional.AnisotropicTVNorm()
x_norm = ATV(x)
is equivalent to
.. code-block:: python
C = linop.FiniteDifference(input_shape=x.shape, circular=True)
L1 = functional.L1Norm()
x_norm = L1(C @ x)
The scaled proximal operator is computed using an approximation that
holds for small scaling parameters :cite:`kamilov-2016-parallel`.
This does not imply that it can only be applied to problems requiring
a small regularization parameter since most proximal algorithms
include an additional algorithm parameter that also plays a role in
the parameter of the proximal operator. For example, in :class:`.PGM`
and :class:`.AcceleratedPGM`, the scaled proximal operator parameter
is the regularization parameter divided by the `L0` algorithm
parameter, and for :class:`.ADMM`, the scaled proximal operator
parameters are the regularization parameters divided by the entries
in the `rho_list` algorithm parameter.
"""

has_eval = True
has_prox = True

def __init__(self, ndims: Optional[int] = None):
r"""
Args:
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
"""
self.ndims = ndims
self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter
self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter
self.l1norm = L1Norm()
self.G: Optional[LinearOperator] = None
self.W: Optional[LinearOperator] = None

def __call__(self, x: Array) -> float:
r"""Compute the anisotropic TV norm of an array."""
if self.G is None or self.G.shape[1] != x.shape:
if self.ndims is None:
ndims = x.ndim
else:
ndims = self.ndims
axes = tuple(range(ndims))
self.G = FiniteDifference(
x.shape, input_dtype=x.dtype, axes=axes, circular=True, jit=True
)
return snp.sum(snp.abs(self.G @ x))

@staticmethod
def _shape(idx: int, ndims: int) -> Tuple:
"""Construct a shape tuple.
Construct a tuple of size `ndims` with all unit entries except
for index `idx`, which has a -1 entry.
"""
return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1)

def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
r"""Approximate proximal operator of the isotropic TV norm.
Approximation of the proximal operator of the anisotropic TV norm,
computed via the method described in :cite:`kamilov-2016-parallel`.
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lam`.
kwargs: Additional arguments that may be used by derived
classes.
"""
if self.ndims is None:
ndims = v.ndim
else:
ndims = self.ndims
K = 2 * ndims

if self.W is None or self.W.shape[1] != v.shape:
h0 = self.h0.astype(v.dtype)
h1 = self.h1.astype(v.dtype)
C0 = VerticalStack( # Stack of lowpass filter operators for each axis
[
CircularConvolve(
h0.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
C1 = VerticalStack( # Stack of highpass filter operators for each axis
[
CircularConvolve(
h1.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
# single-level shift-invariant Haar transform
self.W = VerticalStack([C0, C1], jit=True)

Wv = self.W @ v
# Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform
Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam))
return (1.0 / K) * self.W.T @ Wv
7 changes: 6 additions & 1 deletion scico/test/functional/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from scico import functional
from scico.random import randn

NO_BLOCK_ARRAY = [functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm]
NO_BLOCK_ARRAY = [
functional.L21Norm,
functional.L1MinusL2Norm,
functional.NuclearNorm,
functional.AnisotropicTVNorm,
]
NO_COMPLEX = [functional.NonNegativeIndicator]


Expand Down
40 changes: 40 additions & 0 deletions scico/test/functional/test_tvnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import numpy as np

import scico.random
from scico import functional, linop, loss, metric
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.optimize.pgm import AcceleratedPGM


def test_tvnorm():

N = 128
g = np.linspace(0, 2 * np.pi, N, dtype=np.float32)
x_gt = np.sin(2 * g)
x_gt[x_gt > 0.5] = 0.5
x_gt[x_gt < -0.5] = -0.5
σ = 0.02
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise

λ = 5e-2
f = loss.SquaredL2Loss(y=y)

C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True)
g = λ * functional.L1Norm()
solver = ADMM(
f=f,
g_list=[g],
C_list=[C],
rho_list=[1e1],
x0=y,
maxiter=50,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
)
x_tvdn = solver.solve()

h = λ * functional.AnisotropicTVNorm()
solver = AcceleratedPGM(f=f, g=h, L0=2e2, x0=y, maxiter=50)
x_approx = solver.solve()

assert metric.snr(x_tvdn, x_approx) > 45

0 comments on commit 08a5896

Please sign in to comment.