diff --git a/odl/test/solvers/nonsmooth/admm_test.py b/odl/test/solvers/nonsmooth/admm_test.py index 5dc29e2ce55..dbab6f71921 100644 --- a/odl/test/solvers/nonsmooth/admm_test.py +++ b/odl/test/solvers/nonsmooth/admm_test.py @@ -11,6 +11,8 @@ from __future__ import division import odl from odl.solvers import admm_linearized, admm_precon_nonlinear, Callback +from odl.solvers.nonsmooth.admm import ( + admm_linearized_simple, admm_precon_nonlinear_simple) from odl.util.testutils import all_almost_equal, noise_element @@ -72,6 +74,22 @@ def test_admm_lin_l1(): assert all_almost_equal(x, data_1, places=2) +def test_admm_lin_vs_simple(): + """Check ``admm_linearized`` against the simple implementation.""" + space = odl.rn(5) + L = odl.ScalingOperator(space, 2) + y = L(odl.util.testutils.noise_element(space)) + f = odl.solvers.L2NormSquared(space).translated(y) + g = 0.5 * odl.solvers.L1Norm(space) + + x_simple = space.zero() + admm_linearized_simple(x_simple, f, g, L, tau=1.0, sigma=2.0, niter=10) + x_optim = space.zero() + admm_linearized(x_optim, f, g, L, tau=1.0, sigma=2.0, niter=10) + + assert all_almost_equal(x_optim, x_simple) + + def test_admm_nonlin_affine_l1(): """Verify that the correct value is returned for l1 dist optimization. @@ -100,5 +118,21 @@ def test_admm_nonlin_affine_l1(): assert all_almost_equal(x, x, places=2) +def test_admm_nonlin_vs_simple(): + """Check ``admm_precon_nonlinear`` against the simple implementation.""" + space = odl.rn(5) + L = odl.ufunc_ops.sin(space) * odl.ScalingOperator(space, 2) + y = L(odl.util.testutils.noise_element(space)) + f = odl.solvers.L2NormSquared(space).translated(y) + g = 0.5 * odl.solvers.L1Norm(space) + + x_simple = space.zero() + admm_precon_nonlinear_simple(x_simple, f, g, L, delta=1.0, niter=2) + x_optim = space.zero() + admm_precon_nonlinear(x_optim, f, g, L, delta=1.0, niter=2) + + assert all_almost_equal(x_optim, x_simple) + + if __name__ == '__main__': odl.util.test_file(__file__)