Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gauss Newton with Line Search #268

Merged
merged 7 commits into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Reexport = "0.2, 1"
SciMLBase = "2.4"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "1.9"
SparseDiffTools = "2.9"
SparseDiffTools = "2.11"
StaticArraysCore = "1.4"
UnPack = "1.0"
Zygote = "0.6"
Expand Down
23 changes: 16 additions & 7 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
precs = DEFAULT_PRECS, adkwargs...)

An advanced GaussNewton implementation with support for efficient handling of sparse
Expand Down Expand Up @@ -30,6 +30,9 @@ for large-scale and numerically-difficult nonlinear least squares problems.
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.

!!! warning

Expand All @@ -40,16 +43,18 @@ for large-scale and numerically-difficult nonlinear least squares problems.
ad::AD
linsolve
precs
linesearch
end

function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
Expand Down Expand Up @@ -78,6 +83,7 @@ end
stats::NLStats
tc_cache_1
tc_cache_2
ls_cache
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
Expand Down Expand Up @@ -107,7 +113,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2)
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
end

function perform_step!(cache::GaussNewtonCache{true})
Expand All @@ -128,7 +135,8 @@ function perform_step!(cache::GaussNewtonCache{true})
linu = _vec(du), p, reltol = cache.abstol)
end
cache.linsolve = linres.cache
@. u = u - du
α = perform_linesearch!(cache.ls_cache, u, du)
_axpy!(-α, du, u)
f(cache.fu_new, u, p)

check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
Expand Down Expand Up @@ -169,7 +177,8 @@ function perform_step!(cache::GaussNewtonCache{false})
end
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
α = perform_linesearch!(cache.ls_cache, u, cache.du)
cache.u = @. u - α * cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
Expand Down
2 changes: 1 addition & 1 deletion src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
end

function g!(u, fu)
op = VecJac((args...) -> f(args..., p), u; autodiff)
op = VecJac(f, u, p; fu = fu1, autodiff)
if iip
mul!(g₀, op, fu)
return g₀
Expand Down
2 changes: 1 addition & 1 deletion src/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
precs = DEFAULT_PRECS, adkwargs...)

An advanced NewtonRaphson implementation with support for efficient handling of sparse
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ function __get_concrete_algorithm(alg, prob)
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
else
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(;
tag = NonlinearSolveTag())
tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0)))
end
return set_ad(alg, ad)
end
Expand Down
18 changes: 10 additions & 8 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]
solvers = [
GaussNewton(),
GaussNewton(; linsolve = LUFactorization()),
LevenbergMarquardt(),
LevenbergMarquardt(; linsolve = LUFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
]
solvers = vec(Any[GaussNewton(; linsolve, linesearch)
for linsolve in [nothing, LUFactorization()],
linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]])
append!(solvers,
[
LevenbergMarquardt(),
LevenbergMarquardt(; linsolve = LUFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
])

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
Expand Down
Loading