Skip to content

Commit

Permalink
Allow Levenberg to work with NonlinearLeastSquaresProblem
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 21, 2023
1 parent d0d3db4 commit f0fdcc6
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ LineSearches = "7"
PrecompileTools = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "1.97"
SciMLBase = "2"
SimpleNonlinearSolve = "0.1"
SparseDiffTools = "2.6"
StaticArraysCore = "1.4"
Expand Down
13 changes: 9 additions & 4 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,19 @@ abstract type AbstractNonlinearSolveCache{iip} end

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
args...; kwargs...)
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end

"""
    not_terminated(cache::AbstractNonlinearSolveCache)

Return `true` while iteration should continue: `cache.force_stop` is not set and
`cache.stats.nsteps` is still below `cache.maxiters`.
"""
function not_terminated(cache::AbstractNonlinearSolveCache)
    # A raised stop flag ends iteration regardless of the remaining budget.
    cache.force_stop && return false
    return cache.stats.nsteps < cache.maxiters
end
# Accessor for the residual stored on the cache; the generic fallback reads `fu1`.
function get_fu(cache::AbstractNonlinearSolveCache)
    return cache.fu1
end

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
while not_terminated(cache)
perform_step!(cache)
cache.stats.nsteps += 1
end
Expand All @@ -50,7 +55,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache);
cache.retcode, cache.stats)
end

Expand Down
26 changes: 17 additions & 9 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ numerically-difficult nonlinear systems.
where `J` is the Jacobian. It is suggested by
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
!!! warning
`linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
Support for the OOP version is planned!
"""
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
Expand Down Expand Up @@ -135,11 +140,14 @@ end
loss_old::lossType
make_new_J::Bool
fu_tmp
u_tmp
Jv
mat_tmp::jType
stats::NLStats
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
Expand All @@ -166,21 +174,21 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
end

loss = internalnorm(fu1)
JᵀJ = zero(J)
JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
v = zero(u)
a = zero(u)
tmp_vec = zero(u)
v_old = zero(u)
δ = zero(u)
make_new_J = true
fu_tmp = zero(fu1)
mat_tmp = zero(J)
mat_tmp = zero(JᵀJ)

return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
mat_tmp, NLStats(1, 0, 0, 0, 0))
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::LevenbergMarquardtCache{true})
Expand All @@ -200,10 +208,10 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache

# Usual Levenberg-Marquardt step ("velocity").
# The following lines do: cache.v = -cache.mat_tmp \ cache.fu_tmp
mul!(cache.fu_tmp, J', fu1)
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
mul!(cache.u_tmp, J', fu1)
@. cache.mat_tmp = JᵀJ + λ * DᵀD
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp),
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.v = -cache.du
Expand All @@ -213,8 +221,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
f(cache.fu_tmp, u .+ h .* v, p)

# The following lines do: cache.a = -J \ cache.fu_tmp
mul!(cache.du, J, v)
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.du)
mul!(cache.Jv, J, v)
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(cache.fu_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
Expand Down
28 changes: 16 additions & 12 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,22 @@ function perform_step!(cache::NewtonRaphsonCache{false})
return nothing
end

function SciMLBase.solve!(cache::NewtonRaphsonCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu1, cache.u, p)
else
cache.retcode = ReturnCode.Success
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end
24 changes: 9 additions & 15 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ for large-scale and numerically-difficult nonlinear systems.
`expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`.
- `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
row. If `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
!!! warning
`linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
Support for the OOP version is planned!
"""
@concrete struct TrustRegion{CJ, AD, MTR} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
Expand Down Expand Up @@ -552,22 +557,11 @@ function jvp!(cache::TrustRegionCache{true})
return g
end

function SciMLBase.solve!(cache::TrustRegionCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters &&
cache.shrink_counter < cache.alg.max_shrink_times
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu; cache.retcode,
cache.stats)
# Continuation test for `TrustRegionCache`: on top of the stop flag and the iteration
# budget, iteration also ends once `shrink_counter` reaches `alg.max_shrink_times`.
function not_terminated(cache::TrustRegionCache)
    cache.force_stop && return false
    cache.stats.nsteps < cache.maxiters || return false
    return cache.shrink_counter < cache.alg.max_shrink_times
end
# `TrustRegionCache` keeps its residual in the `fu` field, so it needs its own accessor.
function get_fu(cache::TrustRegionCache)
    return cache.fu
end

function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
Expand Down
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ _maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x)
_maybe_mutable(x, _) = x

# Helper function to get value of `f(u, p)`
function evaluate_f(prob::NonlinearProblem{uType, iip}, u) where {uType, iip}
function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, u) where {uType, iip}
@unpack f, u0, p = prob
if iip
fu = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
Expand Down

0 comments on commit f0fdcc6

Please sign in to comment.