Skip to content

Commit

Permalink
ENH: simplify and optimize some proximal operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Nov 11, 2017
1 parent 74c6a05 commit 5ed19cc
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 83 deletions.
18 changes: 4 additions & 14 deletions odl/solvers/functional/default_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,20 +864,10 @@ def __init__(self, space, lower=None, upper=None):

def _call(self, x):
"""Apply the functional to the given point."""
# Compute the projection of x onto the box, if this is equal to x we
# know x is inside the box.
tmp = self.domain.element()
if self.lower is not None and self.upper is None:
x.ufuncs.maximum(self.lower, out=tmp)
elif self.lower is None and self.upper is not None:
x.ufuncs.minimum(self.upper, out=tmp)
elif self.lower is not None and self.upper is not None:
x.ufuncs.maximum(self.lower, out=tmp)
tmp.ufuncs.minimum(self.upper, out=tmp)
else:
tmp.assign(x)

return np.inf if x.dist(tmp) > 0 else 0
# Since the proximal projects onto our feasible set we can simply
# check if it changes anything
proj = self.proximal(1)(x)
return np.inf if x.dist(proj) > 0 else 0

@property
def proximal(self):
Expand Down
199 changes: 131 additions & 68 deletions odl/solvers/nonsmooth/proximal_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from __future__ import print_function, division, absolute_import
import numpy as np

from odl.operator import (Operator, IdentityOperator, ScalingOperator,
ConstantOperator, DiagonalOperator, PointwiseNorm)
from odl.operator import (
Operator, IdentityOperator, ScalingOperator, ConstantOperator,
DiagonalOperator, PointwiseNorm)
from odl.space import ProductSpace
from odl.set import LinearSpaceElement
from odl.util import cache_arguments
Expand Down Expand Up @@ -201,7 +202,6 @@ def translation_prox_factory(sigma):
The proximal operator of ``s * F( . - y)`` where ``s`` is the
step size
"""

return (ConstantOperator(y) + prox_factory(sigma) *
(IdentityOperator(y.space) - ConstantOperator(y)))

Expand Down Expand Up @@ -478,7 +478,6 @@ def identity_factory(sigma):
The proximal operator instance of G = 0 which is the
identity operator
"""

return IdentityOperator(space)

return identity_factory
Expand Down Expand Up @@ -571,7 +570,6 @@ def __init__(self, sigma):

def _call(self, x, out):
"""Apply the operator to ``x`` and store the result in ``out``."""

if lower is not None and upper is None:
x.ufuncs.maximum(lower, out=out)
elif lower is None and upper is not None:
Expand Down Expand Up @@ -605,7 +603,6 @@ def proximal_nonnegativity(space):
--------
proximal_box_constraint
"""

return proximal_box_constraint(space, lower=0)


Expand Down Expand Up @@ -747,7 +744,6 @@ def __init__(self, sigma):

def _call(self, x, out):
"""Apply the operator to ``x`` and stores the result in ``out``."""

dtype = getattr(self.domain, 'dtype', float)
eps = np.finfo(dtype).resolution * 10

Expand Down Expand Up @@ -850,10 +846,8 @@ def __init__(self, sigma):
self.sigma = float(sigma)

def _call(self, x, out):
"""Apply the operator to ``x`` and stores the result in ``out``"""

"""Apply the operator to ``x`` and store the result in ``out``"""
# (x - sig*g) / (1 + sig/(2 lam))

sig = self.sigma
if g is None:
out.lincomb(1.0 / (1 + 0.5 * sig / lam), x)
Expand Down Expand Up @@ -908,9 +902,34 @@ def proximal_l2_squared(space, lam=1, g=None):
proximal_l2 : proximal without square
proximal_convex_conj_l2_squared : proximal for convex conjugate
"""
# TODO: optimize
prox_cc_l2_squared = proximal_convex_conj_l2_squared(space, lam=lam, g=g)
return proximal_convex_conj(prox_cc_l2_squared)

class ProximalL2Squared(Operator):

"""Proximal operator of the squared l2-norm/dist."""

def __init__(self, sigma):
"""Initialize a new instance.
Parameters
----------
sigma : positive float
Step size parameter
"""
super(ProximalL2Squared, self).__init__(
domain=space, range=space, linear=g is None)
self.sigma = float(sigma)

def _call(self, x, out):
"""Apply the operator to ``x`` and store the result in ``out``"""
# (x + 2*sig*lam*g) / (1 + 2*sig*lam))
sig = self.sigma
if g is None:
out.lincomb(1.0 / (1 + 2 * sig * lam), x)
else:
out.lincomb(1.0 / (1 + 2 * sig * lam), x,
2 * sig * lam / (1 + 2 * sig * lam), g)

return ProximalL2Squared


def proximal_convex_conj_l1(space, lam=1, g=None, isotropic=False):
Expand Down Expand Up @@ -1027,11 +1046,12 @@ def __init__(self, sigma):

def _call(self, x, out):
"""Apply the operator to ``x`` and store the result in ``out``."""
# lam * (x - sig * g) / max(lam, |x - sig * g|)

# lam * (x - sigma * g) / max(lam, |x - sigma * g|)

# diff = x - sig * g
if g is not None:
diff = x - self.sigma * g
diff = self.domain.element()
diff.lincomb(1, x, -self.sigma, g)
else:
if x is out:
# Handle aliased data properly
Expand All @@ -1040,36 +1060,23 @@ def _call(self, x, out):
diff = x

if isotropic:
# Calculate |x| = pointwise 2-norm of x

tmp = diff[0] ** 2
sq_tmp = x[0].space.element()
for x_i in diff[1:]:
x_i.multiply(x_i, out=sq_tmp)
tmp += sq_tmp
tmp.ufuncs.sqrt(out=tmp)

# Pointwise maximum of |x| and lambda
tmp.ufuncs.maximum(lam, out=tmp)

# Global scaling
tmp /= lam
# denom = max( |x-sig*g|_2, lam ) / lam (|.|_2 pointwise)
pwnorm = PointwiseNorm(self.domain, exponent=2)
denom = pwnorm(diff)
denom.ufuncs.maximum(lam, out=denom)
denom /= lam

# Pointwise division
for out_i, x_i in zip(out, diff):
x_i.divide(tmp, out=out_i)
for out_i, diff_i in zip(out, diff):
diff_i.divide(denom, out=out_i)

else:
# Calculate |x| = pointwise 2-norm of x
# out = max( |x-sig*g|, lam ) / lam
diff.ufuncs.absolute(out=out)

# Pointwise maximum of |x| and lambda
out.ufuncs.maximum(lam, out=out)

# Global scaling
out /= lam

# Pointwise division
# out = diff / ...
diff.divide(out, out=out)

return ProximalConvexConjL1
Expand Down Expand Up @@ -1111,27 +1118,32 @@ def proximal_l1(space, lam=1, g=None, isotropic=False):
F(x) = \\lambda \|x - g\|_1.
For a step size :math:`\\sigma`, the proximal operator of :math:`\\sigma F`
is
is the "soft-shrinkage" operator
.. math::
\mathrm{prox}_{\\sigma F}(y) = \\begin{cases}
y - \\sigma \\lambda
& \\text{if } y > g + \\sigma \\lambda, \\\\
0
& \\text{if } g - \\sigma \\lambda \\leq y \\leq g +
\\sigma \\lambda \\\\
y + \\sigma \\lambda
& \\text{if } y < g - \\sigma \\lambda,
\mathrm{prox}_{\\sigma F}(x) =
\\begin{cases}
g, & \\text{where } |x - g| \\leq \sigma\\lambda, \\\\
x - \sigma\\lambda \mathrm{sign}(x - g), & \\text{elsewhere.}
\\end{cases}
Here, all operations are to be read pointwise.
An alternative formulation is available for `ProductSpace`'s, where the
the ``isotropic`` parameter can be used, giving
the ``isotropic`` parameter can be used, i.e.,
.. math::
F(x) = \\lambda \| \|x - g\|_2 \|_1
F(x) = \\lambda \| |x - g|_2 \|_1
The proximal can be calculated using the Moreau equality (also known as
Moreau decomposition or Moreau identity). See for example [BC2011].
with the pointwise Euclidean norm :math:`|\cdot|_2`. For this case, one
gets
.. math::
\mathrm{prox}_{\\sigma F}(x) =
\\begin{cases}
g, & \\text{where } |x - g|_2 \\leq \sigma\\lambda, \\\\
x - \sigma\\lambda \\frac{x - g}{|x - g|_2}, & \\text{elsewhere.}
\\end{cases}
See Also
--------
Expand All @@ -1142,10 +1154,68 @@ def proximal_l1(space, lam=1, g=None, isotropic=False):
[BC2011] Bauschke, H H, and Combettes, P L. *Convex analysis and
monotone operator theory in Hilbert spaces*. Springer, 2011.
"""
# TODO: optimize
prox_cc_l1 = proximal_convex_conj_l1(space, lam=lam, g=g,
isotropic=isotropic)
return proximal_convex_conj(prox_cc_l1)
lam = float(lam)

if g is not None and g not in space:
raise TypeError('{!r} is not an element of {!r}'.format(g, space))

class ProximalL1(Operator):

"""Proximal operator of the l1-norm/distance."""

def __init__(self, sigma):
"""Initialize a new instance.
Parameters
----------
sigma : positive float
Step size parameter
"""
super(ProximalL1, self).__init__(
domain=space, range=space, linear=False)
self.sigma = float(sigma)

def _call(self, x, out):
"""Apply the operator to ``x`` and stores the result in ``out``."""
# diff = x - g
if g is not None:
diff = x - g
else:
if x is out:
# Handle aliased data properly
diff = x.copy()
else:
diff = x

if isotropic:
# We write the operator as
# x - (x - g) / max(|x - g|_2 / sig*lam, 1)
pwnorm = PointwiseNorm(self.domain, exponent=2)
denom = pwnorm(diff)
denom /= self.sigma * lam
denom.ufuncs.maximum(1, out=denom)

# out = (x - g) / denom
for out_i, diff_i in zip(out, diff):
diff_i.divide(denom, out=out_i)

# out = x - ...
out.lincomb(1, x, -1, out)

else:
# We write the operator as
# x - (x - g) / max(|x - g| / sig*lam, 1)
denom = diff.ufuncs.absolute()
denom /= self.sigma * lam
denom.ufuncs.maximum(1, out=denom)

# out = (x - g) / denom
diff.ufuncs.divide(denom, out=out)

# out = x - ...
out.lincomb(1, x, -1, out)

return ProximalL1


def proximal_convex_conj_kl(space, lam=1, g=None):
Expand Down Expand Up @@ -1251,33 +1321,26 @@ def __init__(self, sigma):

def _call(self, x, out):
"""Apply the operator to ``x`` and stores the result in ``out``."""
# (x + lam - sqrt((x - lam)^2 + 4*lam*sig*g)) / 2

# 1 / 2 (lam_X + x - sqrt((x - lam_X) ^ 2 + 4; lam sigma g)

# out = x - lam_X
# out = (x - lam)^2
out.assign(x)
out -= lam

# (out)^2
out.ufuncs.square(out=out)

# out = out + 4 lam sigma g
# out = ... + 4*lam*sigma*g
# If g is None, it is taken as the one element
if g is None:
out += 4.0 * lam * self.sigma
else:
out.lincomb(1, out, 4.0 * lam * self.sigma, g)

# out = sqrt(out)
# out = x - sqrt(...) + lam
out.ufuncs.sqrt(out=out)

# out = x - out
out.lincomb(1, x, -1, out)
out += lam

# out = lam_X + out
out.lincomb(lam, space.one(), 1, out)

# out = 1/2 * out
# out = 1/2 * ...
out /= 2

return ProximalConvexConjKL
Expand Down
1 change: 0 additions & 1 deletion odl/test/solvers/nonsmooth/proximal_operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from __future__ import division
import numpy as np
import pytest
import scipy.special

import odl
Expand Down

0 comments on commit 5ed19cc

Please sign in to comment.