diff --git a/examples/solvers/admm_tomography.py b/examples/solvers/admm_tomography.py
new file mode 100644
index 00000000000..47a99631e79
--- /dev/null
+++ b/examples/solvers/admm_tomography.py
@@ -0,0 +1,85 @@
+"""Total variation tomography using linearized ADMM.
+
+In this example we solve the optimization problem
+
+    min_x ||A(x) - y||_2^2 + lam * ||grad(x)||_1
+
+where ``A`` is a parallel beam ray transform, ``grad`` the spatial
+gradient and ``y`` given noisy data.
+
+The problem is rewritten in decoupled form as
+
+    min_x g(L(x))
+
+with a separable sum ``g`` of functionals and the stacked operator ``L``:
+
+    g(z) = ||z_1 - y||_2^2 + lam * ||z_2||_1,
+
+               ( A(x)    )
+    z = L(x) = ( grad(x) ).
+
+See the documentation of the `admm_linearized` solver for further details.
+"""
+
+import numpy as np
+import odl
+
+# --- Set up the forward operator (ray transform) --- #
+
+# Reconstruction space: functions on the rectangle [-20, 20]^2
+# discretized with 300 samples per dimension
+reco_space = odl.uniform_discr(
+    min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300], dtype='float32')
+
+# Make a parallel beam geometry with flat detector, using 180 angles
+geometry = odl.tomo.parallel_beam_geometry(reco_space, num_angles=180)
+
+# Create the forward operator
+ray_trafo = odl.tomo.RayTransform(reco_space, geometry)
+
+# --- Generate artificial data --- #
+
+# Create phantom and noisy projection data
+phantom = odl.phantom.shepp_logan(reco_space, modified=True)
+data = ray_trafo(phantom)
+data += odl.phantom.white_noise(ray_trafo.range) * np.mean(data) * 0.1
+
+# --- Set up the inverse problem --- #
+
+# Gradient operator for the TV part
+grad = odl.Gradient(reco_space)
+
+# Stacking of the two operators
+L = odl.BroadcastOperator(ray_trafo, grad)
+
+# Data matching and regularization functionals
+data_fit = odl.solvers.L2NormSquared(ray_trafo.range).translated(data)
+reg_func = 0.015 * odl.solvers.L1Norm(grad.range)
+g = odl.solvers.SeparableSum(data_fit, reg_func)
+
+# We don't use the f functional, setting it to zero
+f = odl.solvers.ZeroFunctional(L.domain)
+
+# --- Select parameters and solve using ADMM --- #
+
+# Estimated operator norm, add 10 percent for some safety margin
+op_norm = 1.1 * odl.power_method_opnorm(L, maxiter=20)
+
+niter = 200  # Number of iterations
+sigma = 2.0  # Step size for g.proximal
+tau = sigma / op_norm ** 2  # Step size for f.proximal
+
+# Optionally pass a callback to the solver to display intermediate results
+callback = (odl.solvers.CallbackPrintIteration(step=10) &
+            odl.solvers.CallbackShow(step=10))
+
+# Choose a starting point
+x = L.domain.zero()
+
+# Run the algorithm
+odl.solvers.admm_linearized(x, f, g, L, tau, sigma, niter, callback=callback)
+
+# Display images
+phantom.show(title='Phantom')
+data.show(title='Simulated data (Sinogram)')
+x.show(title='TV reconstruction', force_show=True)
diff --git a/examples/solvers/lbfgs_tomography.py b/examples/solvers/lbfgs_tomography.py
index 8826742cc8b..5928d5c6f52 100644
--- a/examples/solvers/lbfgs_tomography.py
+++ b/examples/solvers/lbfgs_tomography.py
@@ -56,8 +56,8 @@
 hessinv_estimate = odl.ScalingOperator(reco_space, 1 / opnorm ** 2)
 
 # Optionally pass callback to the solver to display intermediate results
-callback = (odl.solvers.CallbackPrintIteration() &
-            odl.solvers.CallbackShow())
+callback = (odl.solvers.CallbackPrintIteration(step=10) &
+            odl.solvers.CallbackShow(step=10))
 
 # Pick parameters
 maxiter = 20
diff --git a/examples/solvers/pdhg_tomography.py b/examples/solvers/pdhg_tomography.py
index 5beeeb3ed73..610cdff6169 100644
--- a/examples/solvers/pdhg_tomography.py
+++ b/examples/solvers/pdhg_tomography.py
@@ -72,8 +72,8 @@
 sigma = 1.0 / op_norm  # Step size for the dual variable
 
 # Optionally pass callback to the solver to display intermediate results
-callback = (odl.solvers.CallbackPrintIteration() &
-            odl.solvers.CallbackShow())
+callback = (odl.solvers.CallbackPrintIteration(step=10) &
+            odl.solvers.CallbackShow(step=10))
 
 # Choose a starting point
 x = op.domain.zero()
diff --git a/odl/solvers/nonsmooth/__init__.py b/odl/solvers/nonsmooth/__init__.py
index 811163508ea..8b13c331289 100644
--- a/odl/solvers/nonsmooth/__init__.py
+++ b/odl/solvers/nonsmooth/__init__.py
@@ -15,6 +15,9 @@
 from .proximal_operators import *
 __all__ += proximal_operators.__all__
 
+from .admm import *
+__all__ += admm.__all__
+
 from .primal_dual_hybrid_gradient import *
 __all__ += primal_dual_hybrid_gradient.__all__
 
diff --git a/odl/solvers/nonsmooth/admm.py b/odl/solvers/nonsmooth/admm.py
new file mode 100644
index 00000000000..b172b54238b
--- /dev/null
+++ b/odl/solvers/nonsmooth/admm.py
@@ -0,0 +1,161 @@
+"""Alternating Direction Method of Multipliers (ADMM) variants."""
+
+from __future__ import division
+from odl.operator import Operator, OpDomainError
+
+
+__all__ = ('admm_linearized',)
+
+
+def admm_linearized(x, f, g, L, tau, sigma, niter, **kwargs):
+    r"""Generic linearized ADMM method for convex problems.
+
+    ADMM stands for "Alternating Direction Method of Multipliers" and
+    is a popular convex optimization method. This variant solves problems
+    of the form ::
+
+        min_x [ f(x) + g(Lx) ]
+
+    with convex ``f`` and ``g``, and a linear operator ``L``. See Section
+    4.4 of [PB2014]
+    and the Notes for more mathematical details.
+
+    Parameters
+    ----------
+    x : ``L.domain`` element
+        Starting point of the iteration, updated in-place.
+    f, g : `Functional`
+        The functions ``f`` and ``g`` in the problem definition. They
+        need to implement the ``proximal`` method.
+    L : linear `Operator`
+        The linear operator that is composed with ``g`` in the problem
+        definition. It must fulfill ``L.domain == f.domain`` and
+        ``L.range == g.domain``.
+    tau, sigma : positive float
+        Step size parameters for the update of the variables.
+    niter : non-negative int
+        Number of iterations.
+
+    Other Parameters
+    ----------------
+    callback : callable, optional
+        Function called with the current iterate after each iteration.
+
+    Notes
+    -----
+    Given :math:`x^{(0)}` (the provided ``x``) and
+    :math:`u^{(0)} = z^{(0)} = 0`, linearized ADMM applies the following
+    iteration:
+
+    .. math::
+        x^{(k+1)} &= \mathrm{prox}_{\tau f} \left[
+            x^{(k)} - \sigma^{-1}\tau L^*\big(
+                L x^{(k)} - z^{(k)} + u^{(k)}
+            \big)
+        \right]
+
+        z^{(k+1)} &= \mathrm{prox}_{\sigma g}\left(
+            L x^{(k+1)} + u^{(k)}
+        \right)
+
+        u^{(k+1)} &= u^{(k)} + L x^{(k+1)} - z^{(k+1)}
+
+    The step size parameters :math:`\tau` and :math:`\sigma` must satisfy
+
+    .. math::
+        0 < \tau < \frac{\sigma}{\|L\|^2}
+
+    to guarantee convergence.
+
+    The name "linearized ADMM" comes from the fact that in the
+    minimization subproblem for the :math:`x` variable, this variant
+    uses a linearization of a quadratic term in the augmented Lagrangian
+    of the generic ADMM, in order to make the step expressible with
+    the proximal operator of :math:`f`.
+
+    Another name for this algorithm is *split inexact Uzawa method*.
+
+    References
+    ----------
+    [PB2014] Parikh, N and Boyd, S. *Proximal Algorithms*. Foundations and
+    Trends in Optimization, 1(3) (2014), pp 123-231.
+    """
+    if not isinstance(L, Operator):
+        raise TypeError('`L` {!r} is not an `Operator` instance'
+                        ''.format(L))
+
+    if x not in L.domain:
+        raise OpDomainError('`x` {!r} is not in the domain of `L` {!r}'
+                            ''.format(x, L.domain))
+
+    tau, tau_in = float(tau), tau
+    if tau <= 0:
+        raise ValueError('`tau` must be positive, got {}'.format(tau_in))
+
+    sigma, sigma_in = float(sigma), sigma
+    if sigma <= 0:
+        raise ValueError('`sigma` must be positive, got {}'.format(sigma_in))
+
+    niter, niter_in = int(niter), niter
+    if niter < 0 or niter != niter_in:
+        raise ValueError('`niter` must be a non-negative integer, got {}'
+                         ''.format(niter_in))
+
+    # Callback object
+    callback = kwargs.pop('callback', None)
+    if callback is not None and not callable(callback):
+        raise TypeError('`callback` {} is not callable'.format(callback))
+
+    # Initialize range variables
+    z = L.range.zero()
+    u = L.range.zero()
+
+    # Temporary for Lx + u [- z]
+    tmp_ran = L(x)
+    # Temporary for L^*(Lx + u - z)
+    tmp_dom = L.domain.element()
+
+    # Store proximals since their initialization may involve computation
+    prox_tau_f = f.proximal(tau)
+    prox_sigma_g = g.proximal(sigma)
+
+    for _ in range(niter):
+        # tmp_ran has value Lx^k here
+        # tmp_dom <- L^*(Lx^k + u^k - z^k)
+        tmp_ran += u
+        tmp_ran -= z
+        L.adjoint(tmp_ran, out=tmp_dom)
+
+        # x <- x^k - (tau/sigma) L^*(Lx^k + u^k - z^k)
+        x.lincomb(1, x, -tau / sigma, tmp_dom)
+        # x^(k+1) <- prox[tau*f](x)
+        prox_tau_f(x, out=x)
+
+        # tmp_ran <- Lx^(k+1)
+        L(x, out=tmp_ran)
+        # z^(k+1) <- prox[sigma*g](Lx^(k+1) + u^k)
+        prox_sigma_g(tmp_ran + u, out=z)  # 1 copy here
+
+        # u^(k+1) = u^k + Lx^(k+1) - z^(k+1)
+        u += tmp_ran
+        u -= z
+
+        if callback is not None:
+            callback(x)
+
+
+def admm_linearized_simple(x, f, g, L, tau, sigma, niter, **kwargs):
+    """Non-optimized version of ``admm_linearized``.
+
+    This function is intended for debugging. It makes a lot of copies and
+    performs no error checking.
+    """
+    callback = kwargs.pop('callback', None)
+    z = L.range.zero()
+    u = L.range.zero()
+    for _ in range(niter):
+        x[:] = f.proximal(tau)(x - tau / sigma * L.adjoint(L(x) + u - z))
+        z = g.proximal(sigma)(L(x) + u)
+        u = L(x) + u - z
+        if callback is not None:
+            callback(x)
diff --git a/odl/test/solvers/nonsmooth/admm_test.py b/odl/test/solvers/nonsmooth/admm_test.py
new file mode 100644
index 00000000000..f7b37b401b9
--- /dev/null
+++ b/odl/test/solvers/nonsmooth/admm_test.py
@@ -0,0 +1,76 @@
+# Copyright 2014-2017 The ODL contributors
+#
+# This file is part of ODL.
+#
+# This Source Code Form is subject to the terms of the Mozilla Public License,
+# v. 2.0. If a copy of the MPL was not distributed with this file, You can
+# obtain one at https://mozilla.org/MPL/2.0/.
+
+"""Unit tests for ADMM."""
+
+from __future__ import division
+import odl
+from odl.solvers import admm_linearized, Callback
+
+from odl.util.testutils import all_almost_equal, noise_element
+
+
+def test_admm_lin_input_handling():
+    """Test to see that input is handled correctly."""
+
+    space = odl.uniform_discr(0, 1, 10)
+
+    L = odl.ZeroOperator(space)
+    f = g = odl.solvers.ZeroFunctional(space)
+
+    # Check that the algorithm runs. With the above operators and functionals,
+    # the algorithm should not modify the initial value.
+    x0 = noise_element(space)
+    x = x0.copy()
+    niter = 3
+
+    admm_linearized(x, f, g, L, tau=1.0, sigma=1.0, niter=niter)
+
+    assert x == x0
+
+    # Check that a provided callback is actually called
+    class CallbackTest(Callback):
+        was_called = False
+
+        def __call__(self, *args, **kwargs):
+            self.was_called = True
+
+    callback = CallbackTest()
+    assert not callback.was_called
+    admm_linearized(x, f, g, L, tau=1.0, sigma=1.0, niter=niter,
+                    callback=callback)
+    assert callback.was_called
+
+
+def test_admm_lin_l1():
+    """Verify that the correct value is returned for l1 dist optimization.
+
+    Solves the optimization problem
+
+        min_x ||x - data_1||_1 + 0.5 ||x - data_2||_1
+
+    which has the minimizer data_1 since the first term dominates.
+    """
+    space = odl.rn(5)
+
+    L = odl.IdentityOperator(space)
+
+    data_1 = noise_element(space)
+    data_2 = noise_element(space)
+
+    f = odl.solvers.L1Norm(space).translated(data_1)
+    g = 0.5 * odl.solvers.L1Norm(space).translated(data_2)
+
+    x = space.zero()
+    admm_linearized(x, f, g, L, tau=1.0, sigma=2.0, niter=10)
+
+    assert all_almost_equal(x, data_1, places=2)
+
+
+if __name__ == '__main__':
+    odl.util.test_file(__file__)