From 8826fc401cc277be72862c06824507d0d808e248 Mon Sep 17 00:00:00 2001 From: Holger Kohr Date: Sun, 22 Oct 2017 00:33:24 +0200 Subject: [PATCH] ENH: simplify and optimize some proximal operators --- odl/solvers/functional/default_functionals.py | 18 +- odl/solvers/nonsmooth/proximal_operators.py | 199 ++++++++++++------ .../nonsmooth/proximal_operator_test.py | 1 - 3 files changed, 135 insertions(+), 83 deletions(-) diff --git a/odl/solvers/functional/default_functionals.py b/odl/solvers/functional/default_functionals.py index d23b43a387d..3568f066e14 100644 --- a/odl/solvers/functional/default_functionals.py +++ b/odl/solvers/functional/default_functionals.py @@ -865,20 +865,10 @@ def __init__(self, space, lower=None, upper=None): def _call(self, x): """Apply the functional to the given point.""" - # Compute the projection of x onto the box, if this is equal to x we - # know x is inside the box. - tmp = self.domain.element() - if self.lower is not None and self.upper is None: - x.ufuncs.maximum(self.lower, out=tmp) - elif self.lower is None and self.upper is not None: - x.ufuncs.minimum(self.upper, out=tmp) - elif self.lower is not None and self.upper is not None: - x.ufuncs.maximum(self.lower, out=tmp) - tmp.ufuncs.minimum(self.upper, out=tmp) - else: - tmp.assign(x) - - return np.inf if x.dist(tmp) > 0 else 0 + # Since the proximal projects onto our feasible set we can simply + # check if it changes anything + proj = self.proximal(1)(x) + return np.inf if x.dist(proj) > 0 else 0 @property def proximal(self): diff --git a/odl/solvers/nonsmooth/proximal_operators.py b/odl/solvers/nonsmooth/proximal_operators.py index 0e1c6258950..b96ad17e855 100644 --- a/odl/solvers/nonsmooth/proximal_operators.py +++ b/odl/solvers/nonsmooth/proximal_operators.py @@ -26,8 +26,9 @@ import numpy as np -from odl.operator import (Operator, IdentityOperator, ScalingOperator, - ConstantOperator, DiagonalOperator) +from odl.operator import ( + Operator, IdentityOperator, ScalingOperator, 
ConstantOperator, + DiagonalOperator, PointwiseNorm) from odl.space import ProductSpace from odl.set import LinearSpaceElement from odl.util import cache_arguments @@ -202,7 +203,6 @@ def translation_prox_factory(sigma): The proximal operator of ``s * F( . - y)`` where ``s`` is the step size """ - return (ConstantOperator(y) + prox_factory(sigma) * (IdentityOperator(y.space) - ConstantOperator(y))) @@ -480,7 +480,6 @@ def identity_factory(sigma): The proximal operator instance of G = 0 which is the identity operator """ - return IdentityOperator(space) return identity_factory @@ -573,7 +572,6 @@ def __init__(self, sigma): def _call(self, x, out): """Apply the operator to ``x`` and store the result in ``out``.""" - if lower is not None and upper is None: x.ufuncs.maximum(lower, out=out) elif lower is None and upper is not None: @@ -607,7 +605,6 @@ def proximal_nonnegativity(space): -------- proximal_box_constraint """ - return proximal_box_constraint(space, lower=0) @@ -749,7 +746,6 @@ def __init__(self, sigma): def _call(self, x, out): """Apply the operator to ``x`` and stores the result in ``out``.""" - dtype = getattr(self.domain, 'dtype', float) eps = np.finfo(dtype).resolution * 10 @@ -852,10 +848,8 @@ def __init__(self, sigma): self.sigma = float(sigma) def _call(self, x, out): - """Apply the operator to ``x`` and stores the result in ``out``""" - + """Apply the operator to ``x`` and store the result in ``out``""" # (x - sig*g) / (1 + sig/(2 lam)) - sig = self.sigma if g is None: out.lincomb(1.0 / (1 + 0.5 * sig / lam), x) @@ -910,9 +904,34 @@ def proximal_l2_squared(space, lam=1, g=None): proximal_l2 : proximal without square proximal_convex_conj_l2_squared : proximal for convex conjugate """ - # TODO: optimize - prox_cc_l2_squared = proximal_convex_conj_l2_squared(space, lam=lam, g=g) - return proximal_convex_conj(prox_cc_l2_squared) + + class ProximalL2Squared(Operator): + + """Proximal operator of the squared l2-norm/dist.""" + + def __init__(self, sigma): 
+ """Initialize a new instance. + + Parameters + ---------- + sigma : positive float + Step size parameter + """ + super(ProximalL2Squared, self).__init__( + domain=space, range=space, linear=g is None) + self.sigma = float(sigma) + + def _call(self, x, out): + """Apply the operator to ``x`` and store the result in ``out``""" + # (x + 2*sig*lam*g) / (1 + 2*sig*lam) + sig = self.sigma + if g is None: + out.lincomb(1.0 / (1 + 2 * sig * lam), x) + else: + out.lincomb(1.0 / (1 + 2 * sig * lam), x, + 2 * sig * lam / (1 + 2 * sig * lam), g) + + return ProximalL2Squared def proximal_convex_conj_l1(space, lam=1, g=None, isotropic=False): @@ -1029,11 +1048,12 @@ def __init__(self, sigma): def _call(self, x, out): """Apply the operator to ``x`` and store the result in ``out``.""" + # lam * (x - sig * g) / max(lam, |x - sig * g|) - # lam * (x - sigma * g) / max(lam, |x - sigma * g|) - + # diff = x - sig * g if g is not None: - diff = x - self.sigma * g + diff = self.domain.element() + diff.lincomb(1, x, -self.sigma, g) else: if x is out: # Handle aliased data properly @@ -1042,36 +1062,23 @@ def _call(self, x, out): diff = x if isotropic: - # Calculate |x| = pointwise 2-norm of x - - tmp = diff[0] ** 2 - sq_tmp = x[0].space.element() - for x_i in diff[1:]: - x_i.multiply(x_i, out=sq_tmp) - tmp += sq_tmp - tmp.ufuncs.sqrt(out=tmp) - - # Pointwise maximum of |x| and lambda - tmp.ufuncs.maximum(lam, out=tmp) - - # Global scaling - tmp /= lam + # denom = max( |x-sig*g|_2, lam ) / lam (|.|_2 pointwise) + pwnorm = PointwiseNorm(self.domain, exponent=2) + denom = pwnorm(diff) + denom.ufuncs.maximum(lam, out=denom) + denom /= lam # Pointwise division - for out_i, x_i in zip(out, diff): - x_i.divide(tmp, out=out_i) + for out_i, diff_i in zip(out, diff): + diff_i.divide(denom, out=out_i) else: - # Calculate |x| = pointwise 2-norm of x + # out = max( |x-sig*g|, lam ) / lam diff.ufuncs.absolute(out=out) - - # Pointwise maximum of |x| and lambda out.ufuncs.maximum(lam, out=out) - - # 
Global scaling out /= lam - # Pointwise division + # out = diff / ... diff.divide(out, out=out) return ProximalConvexConjL1 @@ -1113,27 +1120,32 @@ def proximal_l1(space, lam=1, g=None, isotropic=False): F(x) = \\lambda \|x - g\|_1. For a step size :math:`\\sigma`, the proximal operator of :math:`\\sigma F` - is + is the "soft-shrinkage" operator .. math:: - \mathrm{prox}_{\\sigma F}(y) = \\begin{cases} - y - \\sigma \\lambda - & \\text{if } y > g + \\sigma \\lambda, \\\\ - 0 - & \\text{if } g - \\sigma \\lambda \\leq y \\leq g + - \\sigma \\lambda \\\\ - y + \\sigma \\lambda - & \\text{if } y < g - \\sigma \\lambda, + \mathrm{prox}_{\\sigma F}(x) = + \\begin{cases} + g, & \\text{where } |x - g| \\leq \\sigma\\lambda, \\\\ + x - \\sigma\\lambda \mathrm{sign}(x - g), & \\text{elsewhere.} \\end{cases} + Here, all operations are to be read pointwise. + An alternative formulation is available for `ProductSpace`'s, where the - the ``isotropic`` parameter can be used, giving + ``isotropic`` parameter can be used, i.e., .. math:: - F(x) = \\lambda \| \|x - g\|_2 \|_1 + F(x) = \\lambda \| |x - g|_2 \|_1 - The proximal can be calculated using the Moreau equality (also known as - Moreau decomposition or Moreau identity). See for example [BC2011]. + with the pointwise Euclidean norm :math:`|\cdot|_2`. For this case, one + gets + + .. math:: + \mathrm{prox}_{\\sigma F}(x) = + \\begin{cases} + g, & \\text{where } |x - g|_2 \\leq \\sigma\\lambda, \\\\ + x - \\sigma\\lambda \\frac{x - g}{|x - g|_2}, & \\text{elsewhere.} + \\end{cases} See Also -------- @@ -1144,10 +1156,68 @@ def proximal_l1(space, lam=1, g=None, isotropic=False): [BC2011] Bauschke, H H, and Combettes, P L. *Convex analysis and monotone operator theory in Hilbert spaces*. Springer, 2011. 
""" - # TODO: optimize - prox_cc_l1 = proximal_convex_conj_l1(space, lam=lam, g=g, - isotropic=isotropic) - return proximal_convex_conj(prox_cc_l1) + lam = float(lam) + + if g is not None and g not in space: + raise TypeError('{!r} is not an element of {!r}'.format(g, space)) + + class ProximalL1(Operator): + + """Proximal operator of the l1-norm/distance.""" + + def __init__(self, sigma): + """Initialize a new instance. + + Parameters + ---------- + sigma : positive float + Step size parameter + """ + super(ProximalL1, self).__init__( + domain=space, range=space, linear=False) + self.sigma = float(sigma) + + def _call(self, x, out): + """Apply the operator to ``x`` and store the result in ``out``.""" + # diff = x - g + if g is not None: + diff = x - g + else: + if x is out: + # Handle aliased data properly + diff = x.copy() + else: + diff = x + + if isotropic: + # We write the operator as + # x - (x - g) / max(|x - g|_2 / (sig*lam), 1) + pwnorm = PointwiseNorm(self.domain, exponent=2) + denom = pwnorm(diff) + denom /= self.sigma * lam + denom.ufuncs.maximum(1, out=denom) + + # out = (x - g) / denom + for out_i, diff_i in zip(out, diff): + diff_i.divide(denom, out=out_i) + + # out = x - ... + out.lincomb(1, x, -1, out) + + else: + # We write the operator as + # x - (x - g) / max(|x - g| / (sig*lam), 1) + denom = diff.ufuncs.absolute() + denom /= self.sigma * lam + denom.ufuncs.maximum(1, out=denom) + + # out = (x - g) / denom + diff.ufuncs.divide(denom, out=out) + + # out = x - ... 
+ out.lincomb(1, x, -1, out) + + return ProximalL1 def proximal_convex_conj_kl(space, lam=1, g=None): @@ -1253,33 +1323,26 @@ def __init__(self, sigma): def _call(self, x, out): """Apply the operator to ``x`` and stores the result in ``out``.""" + # (x + lam - sqrt((x - lam)^2 + 4*lam*sig*g)) / 2 - # 1 / 2 (lam_X + x - sqrt((x - lam_X) ^ 2 + 4; lam sigma g) - - # out = x - lam_X + # out = (x - lam)^2 out.assign(x) out -= lam - - # (out)^2 out.ufuncs.square(out=out) - # out = out + 4 lam sigma g + # out = ... + 4*lam*sigma*g # If g is None, it is taken as the one element if g is None: out += 4.0 * lam * self.sigma else: out.lincomb(1, out, 4.0 * lam * self.sigma, g) - # out = sqrt(out) + # out = x - sqrt(...) + lam out.ufuncs.sqrt(out=out) - - # out = x - out out.lincomb(1, x, -1, out) + out += lam - # out = lam_X + out - out.lincomb(lam, space.one(), 1, out) - - # out = 1/2 * out + # out = 1/2 * ... out /= 2 return ProximalConvexConjKL diff --git a/odl/test/solvers/nonsmooth/proximal_operator_test.py b/odl/test/solvers/nonsmooth/proximal_operator_test.py index 3feb0fdca05..a3db607d67c 100644 --- a/odl/test/solvers/nonsmooth/proximal_operator_test.py +++ b/odl/test/solvers/nonsmooth/proximal_operator_test.py @@ -10,7 +10,6 @@ from __future__ import division import numpy as np -import pytest import scipy.special import odl