Skip to content

Commit

Permalink
some unsaved changes from last commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Salman Naqvi committed Oct 5, 2023
1 parent 4947382 commit 4375ddf
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ class TV2DNorm(Functional):
has_eval = True
has_prox = True

def __init__(self, dims: Tuple[int, int] = (1,1), tau: float = 1.0):
def __init__(self, dims: Tuple[int, int] = (1, 1), tau: float = 1.0):
r"""
Args:
tau: Parameter :math:`\tau` in the norm definition.
Expand Down Expand Up @@ -543,15 +543,14 @@ def prox(
y = y.at[:].add(
self.iht2(
self.shrink(self.ht2(v, axis=ax, shift=False), thresh), axis=ax, shift=False
)
)
)
y = y.at[:].add(
self.iht2(
self.shrink(self.ht2(v, axis=ax, shift=True), thresh), axis=ax, shift=True
)
)
)
y = y.at[:].divide(K)

return y

def ht2(self, x, axis, shift):
Expand All @@ -569,7 +568,6 @@ def ht2(self, x, axis, shift):
else:
w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2]))
w = w.at[:, m:].set(C * (x[:, 1::2] - x[:, ::2]))

return w

def iht2(self, w, axis, shift):
Expand All @@ -587,12 +585,10 @@ def iht2(self, w, axis, shift):

if shift:
y = snp.roll(y, 1, axis)

return y

def shrink(self, x, tau):
r"""Wavelet shrinkage operator"""
threshed = snp.maximum(snp.abs(x) - tau, 0)
threshed = threshed.at[:].multiply(snp.sign(x))
return threshed

return threshed

0 comments on commit 4375ddf

Please sign in to comment.