Add ADMM #1198

Merged · 9 commits · Oct 21, 2017
85 changes: 85 additions & 0 deletions examples/solvers/admm_tomography.py
@@ -0,0 +1,85 @@
"""Total variation tomography using linearized ADMM.

In this example we solve the optimization problem

min_x ||A(x) - y||_2^2 + lam * ||grad(x)||_1

where ``A`` is a parallel beam ray transform, ``grad`` the spatial
gradient and ``y`` given noisy data.

The problem is rewritten in decoupled form as

min_x g(L(x))

with a separable sum ``g`` of functionals and the stacked operator ``L``:

    g(z) = ||z_1 - y||_2^2 + lam * ||z_2||_1,

               (  A(x)   )
    z = L(x) = ( grad(x) ).

See the documentation of the `admm_linearized` solver for further details.
"""

import numpy as np
import odl

# --- Set up the forward operator (ray transform) --- #

# Reconstruction space: functions on the rectangle [-20, 20]^2
# discretized with 300 samples per dimension
reco_space = odl.uniform_discr(
min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300], dtype='float32')

# Make a parallel beam geometry with flat detector, using 180 angles
geometry = odl.tomo.parallel_beam_geometry(reco_space, num_angles=180)

# Create the forward operator
ray_trafo = odl.tomo.RayTransform(reco_space, geometry)

# --- Generate artificial data --- #

# Create phantom and noisy projection data
phantom = odl.phantom.shepp_logan(reco_space, modified=True)
data = ray_trafo(phantom)
data += odl.phantom.white_noise(ray_trafo.range) * np.mean(data) * 0.1

# --- Set up the inverse problem --- #

# Gradient operator for the TV part
grad = odl.Gradient(reco_space)

# Stacking of the two operators
L = odl.BroadcastOperator(ray_trafo, grad)

# Data matching and regularization functionals
data_fit = odl.solvers.L2NormSquared(ray_trafo.range).translated(data)
reg_func = 0.015 * odl.solvers.L1Norm(grad.range)
g = odl.solvers.SeparableSum(data_fit, reg_func)

# We don't use the f functional, so we set it to zero
f = odl.solvers.ZeroFunctional(L.domain)

# --- Select parameters and solve using ADMM --- #

# Estimate the operator norm, adding 10 percent as a safety margin
op_norm = 1.1 * odl.power_method_opnorm(L, maxiter=20)

niter = 200 # Number of iterations
sigma = 2.0 # Step size for g.proximal
tau = sigma / op_norm ** 2 # Step size for f.proximal

# Optionally pass a callback to the solver to display intermediate results
callback = (odl.solvers.CallbackPrintIteration(step=10) &
odl.solvers.CallbackShow(step=10))

# Choose a starting point
x = L.domain.zero()

# Run the algorithm
odl.solvers.admm_linearized(x, f, g, L, tau, sigma, niter, callback=callback)

# Display images
phantom.show(title='Phantom')
data.show(title='Simulated data (Sinogram)')
x.show(title='TV reconstruction', force_show=True)
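
As a quick sanity check (a hedged sketch, not part of this diff), the final
objective value can be evaluated once the script above has run; all names
(f, g, L, x, niter) refer to the example:

# Evaluate f(x) + g(L(x)) for the reconstruction computed above, e.g. to
# compare the result with the PDHG example on the same problem
obj_value = f(x) + g(L(x))
print('objective after {} iterations: {}'.format(niter, obj_value))
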
4 changes: 2 additions & 2 deletions examples/solvers/lbfgs_tomography.py
@@ -56,8 +56,8 @@
hessinv_estimate = odl.ScalingOperator(reco_space, 1 / opnorm ** 2)

 # Optionally pass callback to the solver to display intermediate results
-callback = (odl.solvers.CallbackPrintIteration() &
-            odl.solvers.CallbackShow())
+callback = (odl.solvers.CallbackPrintIteration(step=10) &
+            odl.solvers.CallbackShow(step=10))

# Pick parameters
maxiter = 20
4 changes: 2 additions & 2 deletions examples/solvers/pdhg_tomography.py
@@ -72,8 +72,8 @@
sigma = 1.0 / op_norm # Step size for the dual variable

 # Optionally pass callback to the solver to display intermediate results
-callback = (odl.solvers.CallbackPrintIteration() &
-            odl.solvers.CallbackShow())
+callback = (odl.solvers.CallbackPrintIteration(step=10) &
+            odl.solvers.CallbackShow(step=10))

# Choose a starting point
x = op.domain.zero()
3 changes: 3 additions & 0 deletions odl/solvers/nonsmooth/__init__.py
@@ -15,6 +15,9 @@
from .proximal_operators import *
__all__ += proximal_operators.__all__

+from .admm import *
+__all__ += admm.__all__
+
from .primal_dual_hybrid_gradient import *
__all__ += primal_dual_hybrid_gradient.__all__

161 changes: 161 additions & 0 deletions odl/solvers/nonsmooth/admm.py
@@ -0,0 +1,161 @@
"""Alternating Direction method of Multipliers (ADMM) method variants."""

from __future__ import division
from odl.operator import Operator, OpDomainError


__all__ = ('admm_linearized',)


def admm_linearized(x, f, g, L, tau, sigma, niter, **kwargs):
"""Generic linearized ADMM method for convex problems.

ADMM stands for "Alternating Direction Method of Multipliers" and
is a popular convex optimization method. This variant solves problems
of the form ::

min_x [ f(x) + g(Lx) ]

with convex ``f`` and ``g``, and a linear operator ``L``. See Section
4.4 of `[PB2014] <http://web.stanford.edu/~boyd/papers/prox_algs.html>`_
and the Notes for more mathematical details.

Parameters
----------
x : ``L.domain`` element
Starting point of the iteration, updated in-place.
f, g : `Functional`
The functions ``f`` and ``g`` in the problem definition. They
need to implement the ``proximal`` method.
L : linear `Operator`
The linear operator that is composed with ``g`` in the problem
definition. It must fulfill ``L.domain == f.domain`` and
``L.range == g.domain``.
tau, sigma : positive float
Step size parameters for the update of the variables.
niter : non-negative int
Number of iterations.

Other Parameters
----------------
callback : callable, optional
Function called with the current iterate after each iteration.

Notes
-----
Given :math:`x^{(0)}` (the provided ``x``) and
:math:`u^{(0)} = z^{(0)} = 0`, linearized ADMM applies the following
iteration:

.. math::
x^{(k+1)} &= \mathrm{prox}_{\\tau f} \\left[
x^{(k)} - \sigma^{-1}\\tau L^*\\big(
L x^{(k)} - z^{(k)} + u^{(k)}
\\big)
\\right]

z^{(k+1)} &= \mathrm{prox}_{\sigma g}\\left(
L x^{(k+1)} + u^{(k)}
\\right)

u^{(k+1)} &= u^{(k)} + L x^{(k+1)} - z^{(k+1)}

The step size parameters :math:`\\tau` and :math:`\sigma` must satisfy

.. math::
0 < \\tau < \\frac{\sigma}{\|L\|^2}

to guarantee convergence.

The name "linearized ADMM" comes from the fact that in the
minimization subproblem for the :math:`x` variable, this variant
uses a linearization of a quadratic term in the augmented Lagrangian
of the generic ADMM, in order to make the step expressible with
the proximal operator of :math:`f`.

Another name for this algorithm is *split inexact Uzawa method*.

References
----------
[PB2014] Parikh, N and Boyd, S. *Proximal Algorithms*. Foundations and
Trends in Optimization, 1(3) (2014), pp 123-231.
"""
    if not isinstance(L, Operator):
        raise TypeError('`L` {!r} is not an `Operator` instance'
                        ''.format(L))

    if x not in L.domain:
        raise OpDomainError('`x` {!r} is not in the domain of `L` {!r}'
                            ''.format(x, L))

tau, tau_in = float(tau), tau
if tau <= 0:
raise ValueError('`tau` must be positive, got {}'.format(tau_in))

sigma, sigma_in = float(sigma), sigma
if sigma <= 0:
raise ValueError('`sigma` must be positive, got {}'.format(sigma_in))

niter, niter_in = int(niter), niter
if niter < 0 or niter != niter_in:
raise ValueError('`niter` must be a non-negative integer, got {}'
''.format(niter_in))

# Callback object
callback = kwargs.pop('callback', None)
if callback is not None and not callable(callback):
raise TypeError('`callback` {} is not callable'.format(callback))

# Initialize range variables
z = L.range.zero()
u = L.range.zero()

# Temporary for Lx + u [- z]
tmp_ran = L(x)
# Temporary for L^*(Lx + u - z)
tmp_dom = L.domain.element()

# Store proximals since their initialization may involve computation
prox_tau_f = f.proximal(tau)
prox_sigma_g = g.proximal(sigma)

for _ in range(niter):
# tmp_ran has value Lx^k here
# tmp_dom <- L^*(Lx^k + u^k - z^k)
tmp_ran += u
tmp_ran -= z
L.adjoint(tmp_ran, out=tmp_dom)

# x <- x^k - (tau/sigma) L^*(Lx^k + u^k - z^k)
x.lincomb(1, x, -tau / sigma, tmp_dom)
# x^(k+1) <- prox[tau*f](x)
prox_tau_f(x, out=x)

# tmp_ran <- Lx^(k+1)
L(x, out=tmp_ran)
# z^(k+1) <- prox[sigma*g](Lx^(k+1) + u^k)
prox_sigma_g(tmp_ran + u, out=z) # 1 copy here

# u^(k+1) = u^k + Lx^(k+1) - z^(k+1)
u += tmp_ran
u -= z

if callback is not None:
callback(x)


def admm_linearized_simple(x, f, g, L, tau, sigma, niter, **kwargs):
"""Non-optimized version of ``admm_linearized``.

This function is intended for debugging. It makes a lot of copies and
performs no error checking.
"""
callback = kwargs.pop('callback', None)
z = L.range.zero()
u = L.range.zero()
    for _ in range(niter):
        # x update: linearized proximal step on f
        x[:] = f.proximal(tau)(x - tau / sigma * L.adjoint(L(x) + u - z))
        # z update: proximal step on g, evaluated at L(x) + u
        z = g.proximal(sigma)(L(x) + u)
        # u update: dual ascent using the residual L(x) - z
        u = L(x) + u - z
if callback is not None:
callback(x)
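
For orientation, here is a minimal usage sketch of admm_linearized on a small
TV denoising problem; the setup and parameter values are illustrative
assumptions, not part of this pull request:

import odl

# Denoising problem: min_x ||x - y||_2^2 + 0.1 * ||grad(x)||_1
space = odl.uniform_discr(0, 1, 100)
grad = odl.Gradient(space)
L = odl.BroadcastOperator(odl.IdentityOperator(space), grad)

# Piecewise-constant "data" (illustrative choice)
y = odl.phantom.cuboid(space)

f = odl.solvers.ZeroFunctional(space)
g = odl.solvers.SeparableSum(
    odl.solvers.L2NormSquared(space).translated(y),
    0.1 * odl.solvers.L1Norm(grad.range))

# Step sizes chosen to satisfy the convergence condition
# 0 < tau < sigma / ||L||^2 from the docstring above
op_norm = 1.1 * odl.power_method_opnorm(L, maxiter=20)
sigma = 2.0
tau = sigma / op_norm ** 2

x = space.zero()
odl.solvers.admm_linearized(x, f, g, L, tau, sigma, niter=100)
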
76 changes: 76 additions & 0 deletions odl/test/solvers/nonsmooth/admm_test.py
@@ -0,0 +1,76 @@
# Copyright 2014-2017 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Unit tests for ADMM."""

from __future__ import division
import odl
from odl.solvers import admm_linearized, Callback

from odl.util.testutils import all_almost_equal, noise_element


def test_admm_lin_input_handling():
"""Test to see that input is handled correctly."""

space = odl.uniform_discr(0, 1, 10)

L = odl.ZeroOperator(space)
f = g = odl.solvers.ZeroFunctional(space)

# Check that the algorithm runs. With the above operators and functionals,
# the algorithm should not modify the initial value.
x0 = noise_element(space)
x = x0.copy()
niter = 3

admm_linearized(x, f, g, L, tau=1.0, sigma=1.0, niter=niter)

assert x == x0

# Check that a provided callback is actually called
class CallbackTest(Callback):
was_called = False

def __call__(self, *args, **kwargs):
self.was_called = True

callback = CallbackTest()
assert not callback.was_called
admm_linearized(x, f, g, L, tau=1.0, sigma=1.0, niter=niter,
callback=callback)
assert callback.was_called


def test_admm_lin_l1():
"""Verify that the correct value is returned for l1 dist optimization.

Solves the optimization problem

min_x ||x - data_1||_1 + 0.5 ||x - data_2||_1

    whose minimizer is data_1, since the first term carries the larger
    weight and hence dominates: at x = data_1, the subdifferential of the
    first term contains [-1, 1] in every entry, which can absorb the second
    term's subgradients of magnitude at most 0.5.
"""
space = odl.rn(5)

L = odl.IdentityOperator(space)

    data_1 = noise_element(space)
    data_2 = noise_element(space)

f = odl.solvers.L1Norm(space).translated(data_1)
g = 0.5 * odl.solvers.L1Norm(space).translated(data_2)

x = space.zero()
admm_linearized(x, f, g, L, tau=1.0, sigma=2.0, niter=10)

assert all_almost_equal(x, data_1, places=2)


if __name__ == '__main__':
odl.util.test_file(__file__)
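
As a side note on test_admm_lin_l1 (a hedged sketch, not part of the test
suite): the claim that data_1 minimizes the objective can also be probed
numerically, since ODL functionals are callable and support addition and
scalar multiplication:

import odl
from odl.util.testutils import noise_element

space = odl.rn(5)
data_1 = noise_element(space)
data_2 = noise_element(space)
obj = (odl.solvers.L1Norm(space).translated(data_1) +
       0.5 * odl.solvers.L1Norm(space).translated(data_2))

# The objective at data_1 should not exceed its value at nearby points
val = obj(data_1)
for _ in range(100):
    x = data_1 + 0.1 * noise_element(space)
    assert obj(x) >= val - 1e-10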