From a188aa874295b30e865d746192529fc245b142ea Mon Sep 17 00:00:00 2001 From: Holger Kohr Date: Mon, 11 Mar 2019 23:27:10 +0100 Subject: [PATCH] WIP: make functionals and prox ops work for product spaces --- odl/solvers/functional/default_functionals.py | 28 +++++--- odl/solvers/functional/functional.py | 65 +++++++------------ odl/solvers/nonsmooth/proximal_operators.py | 12 ++-- 3 files changed, 49 insertions(+), 56 deletions(-) diff --git a/odl/solvers/functional/default_functionals.py b/odl/solvers/functional/default_functionals.py index 1f596dbc1c9..e1e7e7644f6 100644 --- a/odl/solvers/functional/default_functionals.py +++ b/odl/solvers/functional/default_functionals.py @@ -148,7 +148,10 @@ def __init__(self): def _call(self, x): """Apply the gradient operator to the given point.""" - return np.sign(x) + if isinstance(self.domain, ProductSpace): + return self.domain.apply(np.sign, x) + else: + return np.sign(x) def derivative(self, x): """Derivative is a.e. zero.""" @@ -1127,14 +1130,23 @@ def _call(self, x): import scipy.special if self.prior is None: - tmp = self.domain.inner(self.domain.one(), x - 1 - np.log(x)) + if isinstance(self.domain, ProductSpace): + log_x = self.domain.apply(np.log, x) + else: + log_x = np.log(x) + tmp = self.domain.inner(self.domain.one(), x - 1 - log_x) + else: - tmp = self.domain.inner( - self.domain.one(), - x - self.prior + scipy.special.xlogy( - self.prior, self.prior / x - ), - ) + g = self.prior + if isinstance(self.domain, ProductSpace): + xlogy = self.domain.apply2( + lambda v, i: scipy.special.xlogy(g[i], g[i] / v), x + ) + else: + xlogy = scipy.special.xlogy(g, g / x) + + tmp = self.domain.inner(self.domain.one(), x - g + xlogy) + if np.isnan(tmp): # In this case, some element was less than or equal to zero return np.inf diff --git a/odl/solvers/functional/functional.py b/odl/solvers/functional/functional.py index bf02aed3c2d..c550bb10f8c 100644 --- a/odl/solvers/functional/functional.py +++ b/odl/solvers/functional/functional.py @@ -8,18 +8,19 @@ # v. 2.0. If a copy of the MPL was not distributed with this file, You can # obtain one at https://mozilla.org/MPL/2.0/. -from __future__ import print_function, division, absolute_import +from __future__ import absolute_import, division, print_function + import numpy as np +from odl.operator.default_ops import ( + ConstantOperator, IdentityOperator, InnerProductOperator) from odl.operator.operator import ( - Operator, OperatorComp, OperatorLeftScalarMult, OperatorRightScalarMult, - OperatorRightVectorMult, OperatorSum, OperatorPointwiseProduct) -from odl.operator.default_ops import (IdentityOperator, ConstantOperator) -from odl.solvers.nonsmooth import (proximal_arg_scaling, proximal_translation, - proximal_quadratic_perturbation, - proximal_const_func, proximal_convex_conj) -from odl.util import signature_string, indent - + Operator, OperatorComp, OperatorLeftScalarMult, OperatorPointwiseProduct, + OperatorRightScalarMult, OperatorRightVectorMult, OperatorSum) +from odl.solvers.nonsmooth import ( + proximal_arg_scaling, proximal_const_func, proximal_convex_conj, + proximal_quadratic_perturbation, proximal_translation) +from odl.util import indent, signature_string __all__ = ('Functional', 'FunctionalLeftScalarMult', 'FunctionalRightScalarMult', 'FunctionalComp', @@ -204,7 +205,7 @@ def derivative(self, point): ------- derivative : `Operator` """ - return self.gradient(point).T + return InnerProductOperator(self.domain, self.gradient(point)) def translated(self, shift): """Return a translation of the functional. @@ -1399,33 +1400,18 @@ def __init__(self, functional, point, subgrad): raise TypeError('`functional` {} not an instance of ``Functional``' ''.format(functional)) self.__functional = functional - - if point not in functional.domain: - raise ValueError('`point` {} is not in `functional.domain` {}' - ''.format(point, functional.domain)) - self.__point = point - - if subgrad not in functional.domain: - raise TypeError( - '`subgrad` must be an element in `functional.domain`, got ' - '{}'.format(subgrad)) - self.__subgrad = subgrad - - self.__constant = ( - -functional(point) - + functional.domain.inner(subgrad, point) - ) - + space = functional.domain + self.__point = space.element(point) + self.__subgrad = space.element(subgrad) + self.__constant = -functional(point) + space.inner(subgrad, point) self.__bregman_dist = FunctionalQuadraticPerturb( - functional, linear_term=-subgrad, constant=self.__constant) - - grad_lipschitz = ( - functional.grad_lipschitz + functional.domain.norm(subgrad) + functional, linear_term=-subgrad, constant=self.__constant ) + grad_lipschitz = functional.grad_lipschitz + space.norm(subgrad) super(BregmanDistance, self).__init__( - space=functional.domain, linear=False, - grad_lipschitz=grad_lipschitz) + space, linear=False, grad_lipschitz=grad_lipschitz + ) @property def functional(self): @@ -1459,15 +1445,10 @@ def proximal(self): @property def gradient(self): """Gradient operator of the functional.""" - try: - op_to_return = self.functional.gradient - except NotImplementedError: - raise NotImplementedError( - '`self.functional.gradient` is not implemented for ' - '`self.functional` {}'.format(self.functional)) - - op_to_return = op_to_return - ConstantOperator(self.subgrad) - return op_to_return + return ( + self.functional.gradient + - ConstantOperator(self.domain, self.subgrad) + ) def __repr__(self): '''Return ``repr(self)``.''' diff --git a/odl/solvers/nonsmooth/proximal_operators.py b/odl/solvers/nonsmooth/proximal_operators.py index 9b48ed7895a..b535b397d0f 100644 --- a/odl/solvers/nonsmooth/proximal_operators.py +++ b/odl/solvers/nonsmooth/proximal_operators.py @@ -21,15 +21,15 @@ Foundations and Trends in Optimization, 1 (2014), pp 127-239. """ -from __future__ import print_function, division, absolute_import +from __future__ import absolute_import, division, print_function + import numpy as np from odl.operator import ( - Operator, IdentityOperator, ConstantOperator, DiagonalOperator, - PointwiseNorm, MultiplyOperator) + ConstantOperator, DiagonalOperator, IdentityOperator, MultiplyOperator, + Operator, PointwiseNorm) from odl.space import ProductSpace - __all__ = ('combine_proximals', 'proximal_convex_conj', 'proximal_translation', 'proximal_arg_scaling', 'proximal_quadratic_perturbation', 'proximal_composition', 'proximal_const_func', @@ -799,7 +799,7 @@ def _call(self, x, out): if step < 1.0: self.range.lincomb(1 - step, x, out=out) else: - out[:] = 0 + self.range.lincomb(0, out, out=out) else: x_norm = self.domain.norm(x - g) * (1 + eps) @@ -811,7 +811,7 @@ def _call(self, x, out): if step < 1.0: self.range.lincomb(1 - step, x, step, g, out=out) else: - out[:] = g + self.range.lincomb(1, g, out=out) return ProximalL2