diff --git a/odl/solvers/smooth/gradient.py b/odl/solvers/smooth/gradient.py index 3866a732abb..6f2d44acf01 100644 --- a/odl/solvers/smooth/gradient.py +++ b/odl/solvers/smooth/gradient.py @@ -27,7 +27,7 @@ from odl.solvers.util import ConstantLineSearch -__all__ = ('steepest_descent',) +__all__ = ('steepest_descent', 'adam') # TODO: update all docs @@ -110,6 +110,77 @@ def steepest_descent(f, x, line_search=1.0, maxiter=1000, tol=1e-16, callback(x) +def adam(f, x, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, + maxiter=1000, tol=1e-16, callback=None): + """ADAM method to minimize an objective function. + + General implementation of ADAM for solving + + .. math:: + \min f(x) + + The algorithm is intended for unconstrained problems. + + The algorithm is described in + `Adam: A Method for Stochastic Optimization + `_. All parameter names are taken from + that article. + + Parameters + ---------- + f : `Functional` + Goal functional. Needs to have ``f.gradient``. + x : ``f.domain`` element + Starting point of the iteration + learning_rate : float, optional + Step length of the method. + beta1 : float, optional + Update rate for first order moment estimate. + beta2 : float, optional + Update rate for second order moment estimate. + eps : float, optional + A small constant for numerical stability. + maxiter : int, optional + Maximum number of iterations. + tol : float, optional + Tolerance that should be used for terminating the iteration. + callback : callable, optional + Object executing code per iteration, e.g. plotting each iterate + + See Also + -------- + odl.solvers.smooth.gradient.steepest_descent : simple steepest descent + odl.solvers.iterative.iterative.landweber : + Optimized solver for the case ``f(x) = ||Ax - b||_2^2`` + odl.solvers.iterative.iterative.conjugate_gradient : + Optimized solver for the case ``f(x) = x^T Ax - 2 x^T b`` + """ + grad = f.gradient + if x not in grad.domain: + raise TypeError('`x` {!r} is not in the domain of `grad` {!r}' + ''.format(x, grad.domain)) + + m = grad.domain.zero() + v = grad.domain.zero() + + grad_x = grad.range.element() + for _ in range(maxiter): + grad(x, out=grad_x) + + if grad_x.norm() < tol: + return + + m.lincomb(beta1, m, 1 - beta1, grad_x) + v.lincomb(beta2, v, 1 - beta2, grad_x ** 2) + + step = learning_rate * np.sqrt(1 - beta2) / (1 - beta1) + + x.lincomb(1, x, -step, m / np.sqrt(v + eps)) + + if callback is not None: + callback(x) + + if __name__ == '__main__': # pylint: disable=wrong-import-position from odl.util.testutils import run_doctests diff --git a/odl/test/solvers/iterative/iterative_test.py b/odl/test/solvers/iterative/iterative_test.py index 4d5b1934abd..f71d95fd8f3 100644 --- a/odl/test/solvers/iterative/iterative_test.py +++ b/odl/test/solvers/iterative/iterative_test.py @@ -31,6 +31,7 @@ # Find the valid projectors @pytest.fixture(scope="module", params=['steepest_descent', + 'adam', 'landweber', 'conjugate_gradient', 'conjugate_gradient_normal', @@ -47,6 +48,12 @@ def solver(op, x, rhs): func = odl.solvers.L2NormSquared(op.domain) * (op - rhs) odl.solvers.steepest_descent(func, x, line_search=0.5 / norm2) + elif solver_name == 'adam': + def solver(op, x, rhs): + norm2 = op.adjoint(op(x)).norm() / x.norm() + func = odl.solvers.L2NormSquared(op.domain) * (op - rhs) + + odl.solvers.adam(func, x, learning_rate=4.0 / norm2, maxiter=150) elif solver_name == 'landweber': def solver(op, x, rhs): norm2 = op.adjoint(op(x)).norm() / x.norm() diff --git a/odl/test/solvers/smooth/smooth_test.py b/odl/test/solvers/smooth/smooth_test.py index 5eb65fed666..6eb7b5972a8 100644 --- a/odl/test/solvers/smooth/smooth_test.py +++ b/odl/test/solvers/smooth/smooth_test.py @@ -136,6 +136,15 @@ def test_steepest_descent(functional): assert functional(x) < 1e-3 +def test_adam(functional): + """Test the ``adam`` solver.""" + + x = functional.domain.one() + odl.solvers.adam(functional, x, tol=1e-2, learning_rate=0.5) + + assert functional(x) < 1e-3 + + def test_conjguate_gradient_nonlinear(functional, nonlinear_cg_beta): """Test the ``conjugate_gradient_nonlinear`` solver.""" line_search = odl.solvers.BacktrackingLineSearch(functional)