-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a new Functional for TV Norm (#456)
* Added a new Functional for TV Norm implementing its proximal operator using the fast subiteration free algorithm proposed by Kamilov, 2016 * added checks for input shape in TV2DNorm * fixed lint errors, changed required argument to default in TV2DNorm, fixed inconsistent signature for prox function, added more comments to the helper functions * some unsaved changes from last commit * newline at end of file error * sort imports lint error * removed the default shape parameter from TV2DNorm * Some docs edits * Disable BlockArray tests on TV2DNorm * Fix black formatting * updated the TV norm logic to apply shrinkage to only the difference operator of the haar transform as in Kamilov, 2016 * Implementation supporting arbitrary dimensional inputs * Add a test * Minor changes * New implementation of TV norm and approximage prox * Clean up * Typo fix * Minor change * Add change log entry * Resolve typing errors * Resolve some oversights and issues arising when 64 bit floats enabled * Apply skipped pre-commit --------- Co-authored-by: Salman Naqvi <[email protected]> Co-authored-by: Brendt Wohlberg <[email protected]>
- Loading branch information
1 parent
27e2aec
commit 08a5896
Showing
6 changed files
with
204 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (C) 2023 by SCICO Developers | ||
# All rights reserved. BSD 3-clause License. | ||
# This file is part of the SCICO package. Details of the copyright and | ||
# user license can be found in the 'LICENSE' file distributed with the | ||
# package. | ||
|
||
"""Anisotropic total variation norm.""" | ||
|
||
from typing import Optional, Tuple | ||
|
||
from scico import numpy as snp | ||
from scico.linop import ( | ||
CircularConvolve, | ||
FiniteDifference, | ||
LinearOperator, | ||
VerticalStack, | ||
) | ||
from scico.numpy import Array | ||
|
||
from ._functional import Functional | ||
from ._norm import L1Norm | ||
|
||
|
||
class AnisotropicTVNorm(Functional): | ||
r"""The anisotropic total variation (TV) norm. | ||
The anisotropic total variation (TV) norm computed by | ||
.. code-block:: python | ||
ATV = scico.functional.AnisotropicTVNorm() | ||
x_norm = ATV(x) | ||
is equivalent to | ||
.. code-block:: python | ||
C = linop.FiniteDifference(input_shape=x.shape, circular=True) | ||
L1 = functional.L1Norm() | ||
x_norm = L1(C @ x) | ||
The scaled proximal operator is computed using an approximation that | ||
holds for small scaling parameters :cite:`kamilov-2016-parallel`. | ||
This does not imply that it can only be applied to problems requiring | ||
a small regularization parameter since most proximal algorithms | ||
include an additional algorithm parameter that also plays a role in | ||
the parameter of the proximal operator. For example, in :class:`.PGM` | ||
and :class:`.AcceleratedPGM`, the scaled proximal operator parameter | ||
is the regularization parameter divided by the `L0` algorithm | ||
parameter, and for :class:`.ADMM`, the scaled proximal operator | ||
parameters are the regularization parameters divided by the entries | ||
in the `rho_list` algorithm parameter. | ||
""" | ||
|
||
has_eval = True | ||
has_prox = True | ||
|
||
def __init__(self, ndims: Optional[int] = None): | ||
r""" | ||
Args: | ||
ndims: Number of (trailing) dimensions of the input over | ||
which to apply the finite difference operator. If | ||
``None``, differences are evaluated along all axes. | ||
""" | ||
self.ndims = ndims | ||
self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter | ||
self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter | ||
self.l1norm = L1Norm() | ||
self.G: Optional[LinearOperator] = None | ||
self.W: Optional[LinearOperator] = None | ||
|
||
def __call__(self, x: Array) -> float: | ||
r"""Compute the anisotropic TV norm of an array.""" | ||
if self.G is None or self.G.shape[1] != x.shape: | ||
if self.ndims is None: | ||
ndims = x.ndim | ||
else: | ||
ndims = self.ndims | ||
axes = tuple(range(ndims)) | ||
self.G = FiniteDifference( | ||
x.shape, input_dtype=x.dtype, axes=axes, circular=True, jit=True | ||
) | ||
return snp.sum(snp.abs(self.G @ x)) | ||
|
||
@staticmethod | ||
def _shape(idx: int, ndims: int) -> Tuple: | ||
"""Construct a shape tuple. | ||
Construct a tuple of size `ndims` with all unit entries except | ||
for index `idx`, which has a -1 entry. | ||
""" | ||
return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1) | ||
|
||
def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: | ||
r"""Approximate proximal operator of the isotropic TV norm. | ||
Approximation of the proximal operator of the anisotropic TV norm, | ||
computed via the method described in :cite:`kamilov-2016-parallel`. | ||
Args: | ||
v: Input array :math:`\mb{v}`. | ||
lam: Proximal parameter :math:`\lam`. | ||
kwargs: Additional arguments that may be used by derived | ||
classes. | ||
""" | ||
if self.ndims is None: | ||
ndims = v.ndim | ||
else: | ||
ndims = self.ndims | ||
K = 2 * ndims | ||
|
||
if self.W is None or self.W.shape[1] != v.shape: | ||
h0 = self.h0.astype(v.dtype) | ||
h1 = self.h1.astype(v.dtype) | ||
C0 = VerticalStack( # Stack of lowpass filter operators for each axis | ||
[ | ||
CircularConvolve( | ||
h0.reshape(AnisotropicTVNorm._shape(k, ndims)), | ||
v.shape, | ||
ndims=self.ndims, | ||
) | ||
for k in range(ndims) | ||
] | ||
) | ||
C1 = VerticalStack( # Stack of highpass filter operators for each axis | ||
[ | ||
CircularConvolve( | ||
h1.reshape(AnisotropicTVNorm._shape(k, ndims)), | ||
v.shape, | ||
ndims=self.ndims, | ||
) | ||
for k in range(ndims) | ||
] | ||
) | ||
# single-level shift-invariant Haar transform | ||
self.W = VerticalStack([C0, C1], jit=True) | ||
|
||
Wv = self.W @ v | ||
# Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform | ||
Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam)) | ||
return (1.0 / K) * self.W.T @ Wv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import numpy as np | ||
|
||
import scico.random | ||
from scico import functional, linop, loss, metric | ||
from scico.optimize.admm import ADMM, LinearSubproblemSolver | ||
from scico.optimize.pgm import AcceleratedPGM | ||
|
||
|
||
def test_tvnorm(): | ||
|
||
N = 128 | ||
g = np.linspace(0, 2 * np.pi, N, dtype=np.float32) | ||
x_gt = np.sin(2 * g) | ||
x_gt[x_gt > 0.5] = 0.5 | ||
x_gt[x_gt < -0.5] = -0.5 | ||
σ = 0.02 | ||
noise, key = scico.random.randn(x_gt.shape, seed=0) | ||
y = x_gt + σ * noise | ||
|
||
λ = 5e-2 | ||
f = loss.SquaredL2Loss(y=y) | ||
|
||
C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True) | ||
g = λ * functional.L1Norm() | ||
solver = ADMM( | ||
f=f, | ||
g_list=[g], | ||
C_list=[C], | ||
rho_list=[1e1], | ||
x0=y, | ||
maxiter=50, | ||
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), | ||
) | ||
x_tvdn = solver.solve() | ||
|
||
h = λ * functional.AnisotropicTVNorm() | ||
solver = AcceleratedPGM(f=f, g=h, L0=2e2, x0=y, maxiter=50) | ||
x_approx = solver.solve() | ||
|
||
assert metric.snr(x_tvdn, x_approx) > 45 |