From 13b27bf9a6563b78ddf70b87c388990a2142808e Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 31 Oct 2023 22:38:04 -0400
Subject: [PATCH] Line Search for Gauss Newton

---
 src/gaussnewton.jl              | 23 ++++++++++++++++-------
 src/linesearch.jl               |  1 +
 src/raphson.jl                  |  2 +-
 test/nonlinear_least_squares.jl | 18 ++++++++++--------
 4 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl
index c857f2d23..012767dcf 100644
--- a/src/gaussnewton.jl
+++ b/src/gaussnewton.jl
@@ -1,5 +1,5 @@
 """
-    GaussNewton(; concrete_jac = nothing, linsolve = nothing,
+    GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
         precs = DEFAULT_PRECS, adkwargs...)
 
 An advanced GaussNewton implementation with support for efficient handling of sparse
@@ -30,6 +30,9 @@ for large-scale and numerically-difficult nonlinear least squares problems.
     preconditioners. For more information on specifying preconditioners for LinearSolve
     algorithms, consult the
     [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
+  - `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
+    which means that no line search is performed. Algorithms from `LineSearches.jl` can be
+    used here directly, and they will be converted to the correct `LineSearch`.
 
 !!! warning
 
@@ -40,16 +43,18 @@ for large-scale and numerically-difficult nonlinear least squares problems.
     ad::AD
     linsolve
     precs
+    linesearch
 end
 
 function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
-    return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
+    return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
 end
 
 function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
-    precs = DEFAULT_PRECS, adkwargs...)
+    linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
     ad = default_adargs_to_adtype(; adkwargs...)
-    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
+    linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
+    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
 end
 
 @concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -78,6 +83,7 @@ end
     stats::NLStats
     tc_cache_1
     tc_cache_2
+    ls_cache
 end
 
 function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
@@ -107,7 +113,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
 
     return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
         linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
-        abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2)
+        abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
+        init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
 end
 
 function perform_step!(cache::GaussNewtonCache{true})
@@ -128,7 +135,8 @@ function perform_step!(cache::GaussNewtonCache{true})
             linu = _vec(du), p, reltol = cache.abstol)
     end
     cache.linsolve = linres.cache
-    @. u = u - du
+    α = perform_linesearch!(cache.ls_cache, u, du)
+    _axpy!(-α, du, u)
     f(cache.fu_new, u, p)
 
     check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
@@ -169,7 +177,8 @@ function perform_step!(cache::GaussNewtonCache{false})
         end
         cache.linsolve = linres.cache
     end
-    cache.u = @. u - cache.du # `u` might not support mutation
+    α = perform_linesearch!(cache.ls_cache, u, cache.du)
+    cache.u = @. u - α * cache.du # `u` might not support mutation
     cache.fu_new = f(cache.u, p)
 
     check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
diff --git a/src/linesearch.jl b/src/linesearch.jl
index 760f67769..833361607 100644
--- a/src/linesearch.jl
+++ b/src/linesearch.jl
@@ -122,6 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) where {iip}
     end
 
     function g!(u, fu)
+        # FIXME: Upstream patch to allow non-square Jacobians
         op = VecJac((args...) -> f(args..., p), u; autodiff)
         if iip
             mul!(g₀, op, fu)
diff --git a/src/raphson.jl b/src/raphson.jl
index 1b75d231e..a28dec699 100644
--- a/src/raphson.jl
+++ b/src/raphson.jl
@@ -1,5 +1,5 @@
 """
-    NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
+    NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
         precs = DEFAULT_PRECS, adkwargs...)
 
 An advanced NewtonRaphson implementation with support for efficient handling of sparse
diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl
index c7a02dc58..f4c9f9f7d 100644
--- a/test/nonlinear_least_squares.jl
+++ b/test/nonlinear_least_squares.jl
@@ -27,14 +27,16 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
         resid_prototype = zero(y_target)), θ_init, x)
 
 nlls_problems = [prob_oop, prob_iip]
-solvers = [
-    GaussNewton(),
-    GaussNewton(; linsolve = LUFactorization()),
-    LevenbergMarquardt(),
-    LevenbergMarquardt(; linsolve = LUFactorization()),
-    LeastSquaresOptimJL(:lm),
-    LeastSquaresOptimJL(:dogleg),
-]
+solvers = vec(Any[GaussNewton(; linsolve, linesearch)
+                  for linsolve in [nothing, LUFactorization()],
+                      linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]])
+append!(solvers,
+    [
+        LevenbergMarquardt(),
+        LevenbergMarquardt(; linsolve = LUFactorization()),
+        LeastSquaresOptimJL(:lm),
+        LeastSquaresOptimJL(:dogleg),
+    ])
 
 for prob in nlls_problems, solver in solvers
     @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
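
For anyone wanting to try the new keyword on this branch: below is a minimal sketch, not part of the patch itself. It assumes NonlinearSolve.jl (with this change) and LineSearches.jl are available; the residual model, data, and parameter values are purely illustrative.

```julia
using NonlinearSolve, LineSearches

# Out-of-place vector residual: 50 residuals, 2 parameters, i.e. a
# non-square Jacobian -- the case this patch targets.
xdata = collect(range(0.0, 10.0; length = 50))
ydata = 2.0 .* exp.(-0.5 .* xdata)
loss(θ, p) = θ[1] .* exp.(-θ[2] .* p.x) .- p.y

prob = NonlinearLeastSquaresProblem(loss, [1.0, 1.0], (; x = xdata, y = ydata))

# Per the new constructor, a raw LineSearches.jl algorithm is wrapped as
# `LineSearch(; method = BackTracking())` automatically.
sol = solve(prob, GaussNewton(; linesearch = BackTracking()); abstol = 1e-8)
```

With the default `linesearch = LineSearch()` the step is the plain Gauss-Newton update `u - du`; with a line search the update becomes `u - α * du`, matching the `perform_linesearch!` calls added in both `perform_step!` methods above.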