Add ADMM #1198
Merged
9 commits
ae8e8a7
ENH: add linearized ADMM solver
08fad80
ENH: add ADMM tomography example
cd97c4a
ENH: make ADMM use only 1 operator evaluation
0c4117f
MAINT: use larger step for callbacks in tomography solver examples
bc80058
MAINT: improve input handling of ADMM
a5edbf0
TST: add some simple unit tests for ADMM
570d715
ENH: improve doc and comments of ADMM example
51dbab4
MAINT: minor fixes to admm_linearized solver and example
b3f3e81
TST: simplify CallbackTest in ADMM test
@@ -0,0 +1,85 @@
"""Total variation tomography using linearized ADMM.

In this example we solve the optimization problem

    min_x  ||A(x) - y||_2^2 + lam * ||grad(x)||_1

where ``A`` is a parallel beam ray transform, ``grad`` is the spatial
gradient and ``y`` is the given noisy data.

The problem is rewritten in decoupled form as

    min_x g(L(x))

with a separable sum ``g`` of functionals and the stacked operator ``L``:

    g(z) = ||z_1 - y||_2^2 + lam * ||z_2||_1,

               ( A(x)    )
    z = L(x) = ( grad(x) ).

See the documentation of the `admm_linearized` solver for further details.
"""

import numpy as np
import odl

# --- Set up the forward operator (ray transform) --- #

# Reconstruction space: functions on the rectangle [-20, 20]^2
# discretized with 300 samples per dimension
reco_space = odl.uniform_discr(
    min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300], dtype='float32')

# Make a parallel beam geometry with flat detector, using 180 angles
geometry = odl.tomo.parallel_beam_geometry(reco_space, num_angles=180)

# Create the forward operator
ray_trafo = odl.tomo.RayTransform(reco_space, geometry)

# --- Generate artificial data --- #

# Create phantom and noisy projection data
phantom = odl.phantom.shepp_logan(reco_space, modified=True)
data = ray_trafo(phantom)
data += odl.phantom.white_noise(ray_trafo.range) * np.mean(data) * 0.1

# --- Set up the inverse problem --- #

# Gradient operator for the TV part
grad = odl.Gradient(reco_space)

# Stacking of the two operators
L = odl.BroadcastOperator(ray_trafo, grad)

# Data matching and regularization functionals
data_fit = odl.solvers.L2NormSquared(ray_trafo.range).translated(data)
reg_func = 0.015 * odl.solvers.L1Norm(grad.range)
g = odl.solvers.SeparableSum(data_fit, reg_func)

# We don't use the f functional, setting it to zero
f = odl.solvers.ZeroFunctional(L.domain)

# --- Select parameters and solve using ADMM --- #

# Estimated operator norm, add 10 percent for some safety margin
op_norm = 1.1 * odl.power_method_opnorm(L, maxiter=20)

niter = 200  # Number of iterations
sigma = 2.0  # Step size for g.proximal
tau = sigma / op_norm ** 2  # Step size for f.proximal

# Optionally pass a callback to the solver to display intermediate results
callback = (odl.solvers.CallbackPrintIteration(step=10) &
            odl.solvers.CallbackShow(step=10))

# Choose a starting point
x = L.domain.zero()

# Run the algorithm
odl.solvers.admm_linearized(x, f, g, L, tau, sigma, niter, callback=callback)

# Display images
phantom.show(title='Phantom')
data.show(title='Simulated data (Sinogram)')
x.show(title='TV reconstruction', force_show=True)
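The step size choice in this example (estimate the operator norm, inflate it by 10 percent, then set `tau = sigma / op_norm ** 2`) can be illustrated without ODL. The following is a minimal pure-Python sketch of that idea on a toy 2x2 matrix. The helper names `matvec`, `rmatvec`, `norm` and `power_iteration_opnorm` are hypothetical and written for illustration only; this is not how `odl.power_method_opnorm` is implemented.

```python
import math


def matvec(A, x):
    """Return the product A x for a matrix given as a list of rows."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def rmatvec(A, y):
    """Return the adjoint product A^T y."""
    ncols = len(A[0])
    return [sum(A[i][j] * y[i] for i in range(len(A))) for j in range(ncols)]


def norm(x):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(xi * xi for xi in x))


def power_iteration_opnorm(A, x0, maxiter=50):
    """Estimate ||A|| by power iteration on A^T A (illustrative helper)."""
    n0 = norm(x0)
    x = [xi / n0 for xi in x0]
    for _ in range(maxiter):
        x = rmatvec(A, matvec(A, x))
        nx = norm(x)
        x = [xi / nx for xi in x]
    # For a unit vector x, ||A x|| approximates the largest singular value
    return norm(matvec(A, x))


# Toy operator (assumption for illustration): diagonal, so ||A|| = 2
A = [[2.0, 0.0], [0.0, 1.0]]

# Same recipe as in the example: 10 percent safety margin on the estimate
op_norm = 1.1 * power_iteration_opnorm(A, [1.0, 1.0])
sigma = 2.0
tau = sigma / op_norm ** 2
```

Because the estimate is inflated before squaring, `tau` ends up strictly below `sigma / ||A||^2`, which is the condition stated in the `admm_linearized` docstring.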
@@ -0,0 +1,161 @@
"""Alternating Direction Method of Multipliers (ADMM) variants."""

from __future__ import division
from odl.operator import Operator, OpDomainError


__all__ = ('admm_linearized',)


def admm_linearized(x, f, g, L, tau, sigma, niter, **kwargs):
    """Generic linearized ADMM method for convex problems.

    ADMM stands for "Alternating Direction Method of Multipliers" and
    is a popular convex optimization method. This variant solves problems
    of the form ::

        min_x [ f(x) + g(Lx) ]

    with convex ``f`` and ``g``, and a linear operator ``L``. See Section
    4.4 of `[PB2014] <http://web.stanford.edu/~boyd/papers/prox_algs.html>`_
    and the Notes for more mathematical details.

    Parameters
    ----------
    x : ``L.domain`` element
        Starting point of the iteration, updated in-place.
    f, g : `Functional`
        The functions ``f`` and ``g`` in the problem definition. They
        need to implement the ``proximal`` method.
    L : linear `Operator`
        The linear operator that is composed with ``g`` in the problem
        definition. It must fulfill ``L.domain == f.domain`` and
        ``L.range == g.domain``.
    tau, sigma : positive float
        Step size parameters for the update of the variables.
    niter : non-negative int
        Number of iterations.

    Other Parameters
    ----------------
    callback : callable, optional
        Function called with the current iterate after each iteration.

    Notes
    -----
    Given :math:`x^{(0)}` (the provided ``x``) and
    :math:`u^{(0)} = z^{(0)} = 0`, linearized ADMM applies the following
    iteration:

    .. math::
        x^{(k+1)} &= \\mathrm{prox}_{\\tau f} \\left[
            x^{(k)} - \\sigma^{-1}\\tau L^*\\big(
                L x^{(k)} - z^{(k)} + u^{(k)}
            \\big)
        \\right]

        z^{(k+1)} &= \\mathrm{prox}_{\\sigma g}\\left(
            L x^{(k+1)} + u^{(k)}
        \\right)

        u^{(k+1)} &= u^{(k)} + L x^{(k+1)} - z^{(k+1)}

    The step size parameters :math:`\\tau` and :math:`\\sigma` must satisfy

    .. math::
        0 < \\tau < \\frac{\\sigma}{\\|L\\|^2}

    to guarantee convergence.

    The name "linearized ADMM" comes from the fact that in the
    minimization subproblem for the :math:`x` variable, this variant
    uses a linearization of a quadratic term in the augmented Lagrangian
    of the generic ADMM, in order to make the step expressible with
    the proximal operator of :math:`f`.

    Another name for this algorithm is *split inexact Uzawa method*.

    References
    ----------
    [PB2014] Parikh, N and Boyd, S. *Proximal Algorithms*. Foundations and
    Trends in Optimization, 1(3) (2014), pp 123-231.
    """
    if not isinstance(L, Operator):
        raise TypeError('`L` {!r} is not an `Operator` instance'
                        ''.format(L))

    if x not in L.domain:
        raise OpDomainError('`x` {!r} is not in the domain of `L` {!r}'
                            ''.format(x, L.domain))

    tau, tau_in = float(tau), tau
    if tau <= 0:
        raise ValueError('`tau` must be positive, got {}'.format(tau_in))

    sigma, sigma_in = float(sigma), sigma
    if sigma <= 0:
        raise ValueError('`sigma` must be positive, got {}'.format(sigma_in))

    niter, niter_in = int(niter), niter
    if niter < 0 or niter != niter_in:
        raise ValueError('`niter` must be a non-negative integer, got {}'
                         ''.format(niter_in))

    # Callback object
    callback = kwargs.pop('callback', None)
    if callback is not None and not callable(callback):
        raise TypeError('`callback` {} is not callable'.format(callback))

    # Initialize range variables
    z = L.range.zero()
    u = L.range.zero()

    # Temporary for Lx + u [- z]
    tmp_ran = L(x)
    # Temporary for L^*(Lx + u - z)
    tmp_dom = L.domain.element()

    # Store proximals since their initialization may involve computation
    prox_tau_f = f.proximal(tau)
    prox_sigma_g = g.proximal(sigma)

    for _ in range(niter):
        # tmp_ran has value Lx^k here
        # tmp_dom <- L^*(Lx^k + u^k - z^k)
        tmp_ran += u
        tmp_ran -= z
        L.adjoint(tmp_ran, out=tmp_dom)

        # x <- x^k - (tau/sigma) L^*(Lx^k + u^k - z^k)
        x.lincomb(1, x, -tau / sigma, tmp_dom)
        # x^(k+1) <- prox[tau*f](x)
        prox_tau_f(x, out=x)

        # tmp_ran <- Lx^(k+1)
        L(x, out=tmp_ran)
        # z^(k+1) <- prox[sigma*g](Lx^(k+1) + u^k)
        prox_sigma_g(tmp_ran + u, out=z)  # 1 copy here

        # u^(k+1) = u^k + Lx^(k+1) - z^(k+1)
        u += tmp_ran
        u -= z

        if callback is not None:
            callback(x)


def admm_linearized_simple(x, f, g, L, tau, sigma, niter, **kwargs):
    """Non-optimized version of ``admm_linearized``.

    This function is intended for debugging. It makes a lot of copies and
    performs no error checking.
    """
    callback = kwargs.pop('callback', None)
    z = L.range.zero()
    u = L.range.zero()
    for _ in range(niter):
        x[:] = f.proximal(tau)(x - tau / sigma * L.adjoint(L(x) + u - z))
        z = g.proximal(sigma)(L(x) + u)
        u = L(x) + u - z
        if callback is not None:
            callback(x)
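To make the three updates from the Notes section concrete, here is a minimal pure-Python sketch of the same iteration on a toy problem with a closed-form solution. The problem choice is an assumption made for illustration: `f(x) = 0.5 * ||x - a||^2`, `g(z) = lam * ||z||_1` and `L = c * I` with `c > 0`, so everything decouples per component and the minimizer is the soft-thresholding of `a` with threshold `lam * c`. The names `soft_threshold` and `linearized_admm_toy` are hypothetical and not part of ODL.

```python
def soft_threshold(v, t):
    """Proximal operator of t * |.| applied to a scalar."""
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0


def linearized_admm_toy(a, c, lam, tau, sigma, niter):
    """Minimize 0.5 * ||x - a||^2 + lam * ||c * x||_1 with L = c * I."""
    n = len(a)
    x = [0.0] * n
    z = [0.0] * n
    u = [0.0] * n
    for _ in range(niter):
        for i in range(n):
            # x-update: step along L^*(L x - z + u), then prox of tau * f,
            # where prox_{tau f}(v) = (v + tau * a) / (1 + tau)
            residual = c * x[i] - z[i] + u[i]
            v = x[i] - (tau / sigma) * c * residual
            x[i] = (v + tau * a[i]) / (1.0 + tau)

            # z-update: prox of sigma * g evaluated at L x + u
            lx = c * x[i]
            z[i] = soft_threshold(lx + u[i], sigma * lam)

            # u-update: accumulate the constraint violation L x - z
            u[i] += lx - z[i]
    return x


a = [3.0, -0.5, 1.5]
c, lam = 2.0, 1.0
sigma = 1.0
tau = 0.2  # satisfies 0 < tau < sigma / c**2 = 0.25

x = linearized_admm_toy(a, c, lam, tau, sigma, niter=200)

# Closed-form minimizer of 0.5 * (x - a)^2 + lam * c * |x| per component
expected = [soft_threshold(ai, lam * c) for ai in a]
```

Comparing `x` against `expected` gives a quick sanity check of the update order and of the `tau / sigma` scaling in the x-step, independent of any operator or proximal implementation in ODL.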
@@ -0,0 +1,76 @@
# Copyright 2014-2017 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Unit tests for ADMM."""

from __future__ import division
import odl
from odl.solvers import admm_linearized, Callback

from odl.util.testutils import all_almost_equal, noise_element


def test_admm_lin_input_handling():
    """Test to see that input is handled correctly."""

    space = odl.uniform_discr(0, 1, 10)

    L = odl.ZeroOperator(space)
    f = g = odl.solvers.ZeroFunctional(space)

    # Check that the algorithm runs. With the above operators and functionals,
    # the algorithm should not modify the initial value.
    x0 = noise_element(space)
    x = x0.copy()
    niter = 3

    admm_linearized(x, f, g, L, tau=1.0, sigma=1.0, niter=niter)

    assert x == x0

    # Check that a provided callback is actually called
    class CallbackTest(Callback):
        was_called = False

        def __call__(self, *args, **kwargs):
            self.was_called = True

    callback = CallbackTest()
    assert not callback.was_called
    admm_linearized(x, f, g, L, tau=1.0, sigma=1.0, niter=niter,
                    callback=callback)
    assert callback.was_called


def test_admm_lin_l1():
    """Verify that the correct value is returned for l1 dist optimization.

    Solves the optimization problem

        min_x ||x - data_1||_1 + 0.5 ||x - data_2||_1

    which attains its minimum at ``x = data_1`` since the first term dominates.
    """
    space = odl.rn(5)

    L = odl.IdentityOperator(space)

    data_1 = noise_element(space)
    data_2 = noise_element(space)

    f = odl.solvers.L1Norm(space).translated(data_1)
    g = 0.5 * odl.solvers.L1Norm(space).translated(data_2)

    x = space.zero()
    admm_linearized(x, f, g, L, tau=1.0, sigma=2.0, niter=10)

    assert all_almost_equal(x, data_1, places=2)


if __name__ == '__main__':
    odl.util.test_file(__file__)
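The claim in `test_admm_lin_l1` that the minimizer is `data_1` follows from the triangle inequality: per component, `|x - d1| + 0.5 * |x - d2| >= 0.5 * |x - d1| + 0.5 * |d1 - d2|`, and the right-hand side is smallest at `x = d1`. A minimal sketch that checks this numerically for one scalar component is given below. The values of `d1` and `d2` and the grid are arbitrary choices for illustration, and `objective` is a hypothetical helper.

```python
def objective(x, d1, d2):
    """Scalar version of the test objective |x - d1| + 0.5 * |x - d2|."""
    return abs(x - d1) + 0.5 * abs(x - d2)


# Arbitrary sample values (assumption for illustration)
d1, d2 = 0.3, -1.2

# Brute-force search over a uniform grid on [-2, 2] that contains d1
grid = [i / 100 for i in range(-200, 201)]
best = min(grid, key=lambda x: objective(x, d1, d2))
```

The grid minimizer `best` coincides with `d1`, in line with the value the unit test asserts for the solver output.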