From 4375ddfe04b0c42619a2d4540385f776b46ab17b Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 16:34:36 -0700 Subject: [PATCH] some unsaved changes from last commit --- scico/functional/_norm.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 61443f6b9..352377d26 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -501,7 +501,7 @@ class TV2DNorm(Functional): has_eval = True has_prox = True - def __init__(self, dims: Tuple[int, int] = (1,1), tau: float = 1.0): + def __init__(self, dims: Tuple[int, int] = (1, 1), tau: float = 1.0): r""" Args: tau: Parameter :math:`\tau` in the norm definition. @@ -543,15 +543,14 @@ def prox( y = y.at[:].add( self.iht2( self.shrink(self.ht2(v, axis=ax, shift=False), thresh), axis=ax, shift=False - ) ) + ) y = y.at[:].add( self.iht2( self.shrink(self.ht2(v, axis=ax, shift=True), thresh), axis=ax, shift=True - ) ) + ) y = y.at[:].divide(K) - return y def ht2(self, x, axis, shift): @@ -569,7 +568,6 @@ def ht2(self, x, axis, shift): else: w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2])) w = w.at[:, m:].set(C * (x[:, 1::2] - x[:, ::2])) - return w def iht2(self, w, axis, shift): @@ -587,12 +585,10 @@ def iht2(self, w, axis, shift): if shift: y = snp.roll(y, 1, axis) - return y def shrink(self, x, tau): r"""Wavelet shrinkage operator""" threshed = snp.maximum(snp.abs(x) - tau, 0) threshed = threshed.at[:].multiply(snp.sign(x)) - return threshed - \ No newline at end of file + return threshed \ No newline at end of file