diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 7c94ec4f3..41e4f1775 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -35,22 +35,30 @@ for large-scale and numerically-difficult nonlinear least squares problems. Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian construction. This will be fixed in the near future. """ -@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD} +@concrete struct GaussNewton{CJ, AD, TC} <: AbstractNewtonAlgorithm{CJ, AD, TC} ad::AD linsolve precs + termination_condition::TC end function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(), - precs = DEFAULT_PRECS, adkwargs...) + precs = DEFAULT_PRECS, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm; + abstol = nothing, + reltol = nothing), adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) - return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs) + return GaussNewton{_unwrap_val(concrete_jac)}(ad, + linsolve, + precs, + termination_condition) end @concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip} f alg u + u_prev fu1 fu2 fu_new @@ -67,12 +75,15 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob stats::NLStats + tc_storage end function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::GaussNewton, - args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + internalnorm = DEFAULT_NORM, kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) @@ -85,14 +96,28 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::G uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip); linsolve_with_JᵀJ = Val(true)) - return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J, - JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, - prob, NLStats(1, 0, 0, 0, 0)) + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + atol = _get_tolerance(abstol, tc.abstol, eltype(u)) + rtol = _get_tolerance(reltol, tc.reltol, eltype(u)) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf, + linsolve, J, + JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, + prob, NLStats(1, 0, 0, 0, 0), storage) end function perform_step!(cache::GaussNewtonCache{true}) - @unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache + @unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache jacobian!!(J, cache) + + tc_storage = cache.tc_storage + termination_condition = cache.alg.termination_condition(tc_storage) + mul!(JᵀJ, J', J) mul!(Jᵀf, J', fu1) @@ -103,9 +128,15 @@ function perform_step!(cache::GaussNewtonCache{true}) @. u = u - du f(cache.fu_new, u, p) - (cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol || - cache.internalnorm(cache.fu_new) < cache.abstol) && + (termination_condition(cache.fu_new .- cache.fu1, + cache.u, + u_prev, + cache.abstol, + cache.reltol) || + termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) && (cache.force_stop = true) + + @. u_prev = u cache.fu1 .= cache.fu_new cache.stats.nf += 1 cache.stats.njacs += 1 @@ -115,7 +146,10 @@ function perform_step!(cache::GaussNewtonCache{true}) end function perform_step!(cache::GaussNewtonCache{false}) - @unpack u, fu1, f, p, alg, linsolve = cache + @unpack u, u_prev, fu1, f, p, alg, linsolve = cache + + tc_storage = cache.tc_storage + termination_condition = cache.alg.termination_condition(tc_storage) cache.J = jacobian!!(cache.J, cache) @@ -132,7 +166,10 @@ function perform_step!(cache::GaussNewtonCache{false}) cache.u = @. u - cache.du # `u` might not support mutation cache.fu_new = f(cache.u, p) - (cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true) + termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) + + cache.u_prev = @. cache.u cache.fu1 = cache.fu_new cache.stats.nf += 1 cache.stats.njacs += 1 diff --git a/src/levenberg.jl b/src/levenberg.jl index b2cd6c6d0..e19f67f09 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -92,7 +92,7 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0, damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1, α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm; abstol = nothing, reltol = nothing), adkwargs...) @@ -149,7 +149,8 @@ end tc_storage end -function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt, +function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, + NonlinearLeastSquaresProblem{uType, iip}}, alg::LevenbergMarquardt, args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} @@ -200,7 +201,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic, b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp, zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), storage) - end function perform_step!(cache::LevenbergMarquardtCache{true}) @@ -261,11 +261,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true}) if (1 - β)^b_uphill * loss ≤ loss_old # Accept step. cache.u .+= δ - if termination_condition(cache.fu_tmp, - cache.u, - u_prev, - cache.abstol, - cache.reltol) + if loss < cache.abstol cache.force_stop = true return nothing end @@ -305,7 +301,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false}) cache.make_new_J = false cache.stats.njacs += 1 end - + @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache cache.mat_tmp = JᵀJ + λ * DᵀD