ENH: Add ADAM solver, see #984
adler-j committed May 22, 2017
1 parent 009e19b commit 77d83d0
Showing 3 changed files with 88 additions and 1 deletion.
73 changes: 72 additions & 1 deletion odl/solvers/smooth/gradient.py
@@ -27,7 +27,7 @@
from odl.solvers.util import ConstantLineSearch


__all__ = ('steepest_descent',)
__all__ = ('steepest_descent', 'adam')


# TODO: update all docs
@@ -110,6 +110,77 @@ def steepest_descent(f, x, line_search=1.0, maxiter=1000, tol=1e-16,
            callback(x)


def adam(f, x, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
         maxiter=1000, tol=1e-16, callback=None):
    r"""ADAM method to minimize an objective function.

    General implementation of ADAM for solving

    .. math::
        \min f(x)

    The algorithm is intended for unconstrained problems.

    The algorithm is described in
    `Adam: A Method for Stochastic Optimization
    <https://arxiv.org/abs/1412.6980>`_. All parameter names are taken from
    that article.

    Parameters
    ----------
    f : `Functional`
        Goal functional. Needs to have ``f.gradient``.
    x : ``f.domain`` element
        Starting point of the iteration.
    learning_rate : float, optional
        Step length of the method.
    beta1 : float, optional
        Update rate for first order moment estimate.
    beta2 : float, optional
        Update rate for second order moment estimate.
    eps : float, optional
        A small constant for numerical stability.
    maxiter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Tolerance that should be used for terminating the iteration.
    callback : callable, optional
        Object executing code per iteration, e.g. plotting each iterate.

    See Also
    --------
    odl.solvers.smooth.gradient.steepest_descent : simple steepest descent
    odl.solvers.iterative.iterative.landweber :
        Optimized solver for the case ``f(x) = ||Ax - b||_2^2``
    odl.solvers.iterative.iterative.conjugate_gradient :
        Optimized solver for the case ``f(x) = x^T Ax - 2 x^T b``
    """
    grad = f.gradient
    if x not in grad.domain:
        raise TypeError('`x` {!r} is not in the domain of `grad` {!r}'
                        ''.format(x, grad.domain))

    # Exponential moving averages of the gradient (m) and its square (v)
    m = grad.domain.zero()
    v = grad.domain.zero()

    grad_x = grad.range.element()
    for _ in range(maxiter):
        grad(x, out=grad_x)

        # Stop once the gradient norm drops below the tolerance
        if grad_x.norm() < tol:
            return

        m.lincomb(beta1, m, 1 - beta1, grad_x)
        v.lincomb(beta2, v, 1 - beta2, grad_x ** 2)

        # Constant bias-correction factor; the per-iteration powers of
        # beta1 and beta2 from the article are not applied here.
        step = learning_rate * np.sqrt(1 - beta2) / (1 - beta1)

        x.lincomb(1, x, -step, m / np.sqrt(v + eps))

        if callback is not None:
            callback(x)


if __name__ == '__main__':
    # pylint: disable=wrong-import-position
    from odl.util.testutils import run_doctests
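For reference, the update rule in the cited article applies per-iteration bias correction. With gradient g_t at step t, the article's efficient form is

    m_t = beta1 * m_(t-1) + (1 - beta1) * g_t
    v_t = beta2 * v_(t-1) + (1 - beta2) * g_t^2
    x_t = x_(t-1) - learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
                  * m_t / (sqrt(v_t) + eps)

whereas the code above keeps the correction factor sqrt(1 - beta2) / (1 - beta1) constant over iterations and folds eps under the square root. A minimal usage sketch of the new solver (not part of the commit; the space, data and parameter values are illustrative assumptions):

    import odl

    # Smooth objective f(x) = ||x - g||_2^2 on a discretized interval
    space = odl.uniform_discr(0, 1, 100)
    g = space.element(lambda t: t ** 2)
    f = odl.solvers.L2NormSquared(space).translated(g)

    # Run ADAM from the zero element, printing the iteration number
    x = space.zero()
    odl.solvers.adam(f, x, learning_rate=0.1, maxiter=2000, tol=1e-6,
                     callback=odl.solvers.CallbackPrintIteration())

    print((x - g).norm())  # distance to the minimizer after the run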
7 changes: 7 additions & 0 deletions odl/test/solvers/iterative/iterative_test.py
@@ -31,6 +31,7 @@
# Find the valid projectors
@pytest.fixture(scope="module",
                params=['steepest_descent',
                        'adam',
                        'landweber',
                        'conjugate_gradient',
                        'conjugate_gradient_normal',
@@ -47,6 +48,12 @@ def solver(op, x, rhs):
            func = odl.solvers.L2NormSquared(op.domain) * (op - rhs)

            odl.solvers.steepest_descent(func, x, line_search=0.5 / norm2)
    elif solver_name == 'adam':
        def solver(op, x, rhs):
            norm2 = op.adjoint(op(x)).norm() / x.norm()
            func = odl.solvers.L2NormSquared(op.domain) * (op - rhs)

            odl.solvers.adam(func, x, learning_rate=4.0 / norm2, maxiter=150)
    elif solver_name == 'landweber':
        def solver(op, x, rhs):
            norm2 = op.adjoint(op(x)).norm() / x.norm()
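For clarity, the new test branch exercises ``adam`` through the same least-squares pattern as the other iterative solvers: ``f(x) = ||op(x) - rhs||_2^2`` is built by composing ``L2NormSquared`` with ``op - rhs``, and the learning rate is scaled by ``norm2``, a rough one-step estimate of ``||op||^2``. A standalone sketch of that pattern (the operator, data and iteration count are assumptions for illustration, not the test's actual fixtures):

    import odl

    space = odl.uniform_discr(0, 1, 10)
    op = odl.IdentityOperator(space)   # stand-in for the operator under test
    rhs = space.element(lambda t: t)

    x = space.one()
    norm2 = op.adjoint(op(x)).norm() / x.norm()   # crude norm estimate
    func = odl.solvers.L2NormSquared(op.domain) * (op - rhs)

    odl.solvers.adam(func, x, learning_rate=4.0 / norm2, maxiter=150)
    print((op(x) - rhs).norm())   # residual after 150 ADAM iterations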
9 changes: 9 additions & 0 deletions odl/test/solvers/smooth/smooth_test.py
@@ -136,6 +136,15 @@ def test_steepest_descent(functional):
    assert functional(x) < 1e-3


def test_adam(functional):
    """Test the ``adam`` solver."""

    x = functional.domain.one()
    odl.solvers.adam(functional, x, tol=1e-2, learning_rate=0.5)

    assert functional(x) < 1e-3


def test_conjguate_gradient_nonlinear(functional, nonlinear_cg_beta):
    """Test the ``conjugate_gradient_nonlinear`` solver."""
    line_search = odl.solvers.BacktrackingLineSearch(functional)
