Skip to content

Commit

Permalink
Add termination condition to gaussnewton and other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh530 committed Oct 10, 2023
1 parent 20cdb37 commit eefbca3
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 21 deletions.
61 changes: 49 additions & 12 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,30 @@ for large-scale and numerically-difficult nonlinear least squares problems.
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
construction. This will be fixed in the near future.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
@concrete struct GaussNewton{CJ, AD, TC} <: AbstractNewtonAlgorithm{CJ, AD, TC}
ad::AD
linsolve
precs
termination_condition::TC
end

function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
precs = DEFAULT_PRECS,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm;
abstol = nothing,
reltol = nothing), adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
return GaussNewton{_unwrap_val(concrete_jac)}(ad,
linsolve,
precs,
termination_condition)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
u_prev
fu1
fu2
fu_new
Expand All @@ -67,12 +75,15 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
prob
stats::NLStats
tc_storage
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::GaussNewton,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
Expand All @@ -85,14 +96,28 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::G
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_with_JᵀJ = Val(true))

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
prob, NLStats(1, 0, 0, 0, 0))
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

atol = _get_tolerance(abstol, tc.abstol, eltype(u))
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol,
prob, NLStats(1, 0, 0, 0, 0), storage)
end

function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
jacobian!!(J, cache)

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

mul!(JᵀJ, J', J)
mul!(Jᵀf, J', fu1)

Expand All @@ -103,9 +128,15 @@ function perform_step!(cache::GaussNewtonCache{true})
@. u = u - du
f(cache.fu_new, u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(termination_condition(cache.fu_new .- cache.fu1,
cache.u,
u_prev,
cache.abstol,
cache.reltol) ||
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) &&
(cache.force_stop = true)

@. u_prev = u
cache.fu1 .= cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
Expand All @@ -115,7 +146,10 @@ function perform_step!(cache::GaussNewtonCache{true})
end

function perform_step!(cache::GaussNewtonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

cache.J = jacobian!!(cache.J, cache)

Expand All @@ -132,7 +166,10 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

(cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true)
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

cache.u_prev = @. cache.u
cache.fu1 = cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
Expand Down
14 changes: 5 additions & 9 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0,
damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm;
abstol = nothing,
reltol = nothing),
adkwargs...)
Expand Down Expand Up @@ -149,7 +149,8 @@ end
tc_storage
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
Expand Down Expand Up @@ -200,7 +201,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), storage)

end

function perform_step!(cache::LevenbergMarquardtCache{true})
Expand Down Expand Up @@ -261,11 +261,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
if (1 - β)^b_uphill * loss ≤ loss_old
# Accept step.
cache.u .+= δ
if termination_condition(cache.fu_tmp,
cache.u,
u_prev,
cache.abstol,
cache.reltol)
if loss < cache.abstol
cache.force_stop = true
return nothing
end
Expand Down Expand Up @@ -305,7 +301,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.make_new_J = false
cache.stats.njacs += 1
end

@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache

cache.mat_tmp = JᵀJ + λ * DᵀD
Expand Down

0 comments on commit eefbca3

Please sign in to comment.