diff --git a/Project.toml b/Project.toml index 726dedac7..3080b7178 100644 --- a/Project.toml +++ b/Project.toml @@ -79,6 +79,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" [targets] -test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices"] +test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase"] diff --git a/src/broyden.jl b/src/broyden.jl index 6be29a77b..22a1a9cc8 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -31,6 +31,7 @@ end f alg u + u_prev du fu fu2 @@ -46,17 +47,21 @@ end internalnorm retcode::ReturnCode.T abstol + reltol reset_tolerance reset_check prob stats::NLStats lscache + termination_condition + tc_storage end get_fu(cache::GeneralBroydenCache) = cache.fu function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, internalnorm = DEFAULT_NORM, kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) @@ -65,15 +70,29 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) : alg.reset_tolerance reset_check = x -> abs(x) ≤ reset_tolerance - return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu), + + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu), zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0, - alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance, + alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, + reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), - init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))) + init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition, + storage) end function perform_step!(cache::GeneralBroydenCache{true}) - @unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache + @unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) T = eltype(u) mul!(_vec(du), J⁻¹, -_vec(fu)) @@ -81,7 +100,8 @@ function perform_step!(cache::GeneralBroydenCache{true}) _axpy!(α, du, u) f(fu2, u, p) - cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.force_stop && return nothing @@ -106,12 +126,16 @@ function perform_step!(cache::GeneralBroydenCache{true}) mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1) end fu .= fu2 + @. u_prev = u return nothing end function perform_step!(cache::GeneralBroydenCache{false}) - @unpack f, p = cache + @unpack f, p, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + T = eltype(cache.u) cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu)) @@ -119,7 +143,8 @@ function perform_step!(cache::GeneralBroydenCache{false}) cache.u = cache.u .+ α * cache.du cache.fu2 = f(cache.u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.force_stop && return nothing @@ -142,12 +167,15 @@ function perform_step!(cache::GeneralBroydenCache{false}) cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂ end cache.fu = cache.fu2 + cache.u_prev = @. cache.u return nothing end function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.u, u0) @@ -157,7 +185,14 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca cache.u = u0 cache.fu = cache.f(cache.u, p) end + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) + cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/dfsane.jl b/src/dfsane.jl index f5bf69eca..d88f18db0 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -88,12 +88,16 @@ end internalnorm retcode::SciMLBase.ReturnCode.T abstol + reltol prob stats::NLStats + termination_condition + tc_storage end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, internalnorm = DEFAULT_NORM, kwargs...) where {uType, iip} uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0) @@ -122,14 +126,27 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args. f₍ₙₒᵣₘ₎₀ = f₍ₙₒᵣₘ₎ₙ₋₁ ℋ = fill(f₍ₙₒᵣₘ₎ₙ₋₁, M) + + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + T) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters, - internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0)) + internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), + termination_condition, storage) end function perform_step!(cache::DFSaneCache{true}) - @unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache + @unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache + termination_condition = cache.termination_condition(tc_storage) f = (dx, x) -> cache.prob.f(dx, x, cache.p) T = eltype(cache.uₙ) @@ -174,7 +191,7 @@ function perform_step!(cache::DFSaneCache{true}) f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ end - if cache.internalnorm(cache.fuₙ) < cache.abstol + if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol) cache.force_stop = true end @@ -205,8 +222,9 @@ function perform_step!(cache::DFSaneCache{true}) end function perform_step!(cache::DFSaneCache{false}) - @unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache + @unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache + termination_condition = cache.termination_condition(tc_storage) f = x -> cache.prob.f(x, cache.p) T = eltype(cache.uₙ) @@ -249,7 +267,7 @@ function perform_step!(cache::DFSaneCache{false}) f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ end - if cache.internalnorm(cache.fuₙ) < cache.abstol + if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol) cache.force_stop = true end @@ -296,7 +314,9 @@ function SciMLBase.solve!(cache::DFSaneCache) end function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.uₙ, u0) @@ -317,7 +337,14 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p T = eltype(cache.uₙ) cache.σₙ = T(cache.alg.σ_1) + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) + cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index a6ec1ae9b..42155072a 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -49,13 +49,16 @@ end function GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS, adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) - return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs) + return GaussNewton{_unwrap_val(concrete_jac)}(ad, + linsolve, + precs) end @concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip} f alg u + u_prev fu1 fu2 fu_new @@ -72,12 +75,17 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob stats::NLStats + tc_storage + termination_condition end function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton, - args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, + internalnorm = DEFAULT_NORM, kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob @@ -101,15 +109,29 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: JᵀJ, Jᵀf = nothing, nothing end - return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J, + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u); mode = NLSolveTerminationMode.AbsNorm) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf, + linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, - prob, NLStats(1, 0, 0, 0, 0)) + reltol, + prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition) end function perform_step!(cache::GaussNewtonCache{true}) - @unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache + @unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du, tc_storage = cache jacobian!!(J, cache) + termination_condition = cache.termination_condition(tc_storage) + if JᵀJ !== nothing __matmul!(JᵀJ, J', J) __matmul!(Jᵀf, J', fu1) @@ -127,9 +149,15 @@ function perform_step!(cache::GaussNewtonCache{true}) @. u = u - du f(cache.fu_new, u, p) - (cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol || - cache.internalnorm(cache.fu_new) < cache.abstol) && + (termination_condition(cache.fu_new .- cache.fu1, + cache.u, + u_prev, + cache.abstol, + cache.reltol) || + termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) && (cache.force_stop = true) + + @. u_prev = u cache.fu1 .= cache.fu_new cache.stats.nf += 1 cache.stats.njacs += 1 @@ -139,7 +167,9 @@ function perform_step!(cache::GaussNewtonCache{true}) end function perform_step!(cache::GaussNewtonCache{false}) - @unpack u, fu1, f, p, alg, linsolve = cache + @unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) cache.J = jacobian!!(cache.J, cache) @@ -164,7 +194,10 @@ function perform_step!(cache::GaussNewtonCache{false}) cache.u = @. u - cache.du # `u` might not support mutation cache.fu_new = f(cache.u, p) - (cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true) + termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) + + cache.u_prev = @. cache.u cache.fu1 = cache.fu_new cache.stats.nf += 1 cache.stats.njacs += 1 @@ -174,7 +207,9 @@ function perform_step!(cache::GaussNewtonCache{false}) end function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.u, u0) @@ -184,7 +219,14 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache cache.u = u0 cache.fu1 = cache.f(cache.u, p) end + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) + cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/klement.jl b/src/klement.jl index a16ed2873..8fc44be59 100644 --- a/src/klement.jl +++ b/src/klement.jl @@ -41,6 +41,7 @@ end f alg u + u_prev fu fu2 du @@ -57,15 +58,19 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob stats::NLStats lscache + termination_condition + tc_storage end get_fu(cache::GeneralKlementCache) = cache.fu function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) @@ -84,16 +89,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg) end - return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), du, p, linsolve, + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return GeneralKlementCache{iip}(f, alg, u, zero(u), fu, zero(fu), du, p, linsolve, J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false, - maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0), - init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))) + maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, + NLStats(1, 0, 0, 0, 0), + init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition, + storage) end function perform_step!(cache::GeneralKlementCache{true}) - @unpack u, fu, f, p, alg, J, linsolve, du = cache + @unpack u, u_prev, fu, f, p, alg, J, linsolve, du, tc_storage = cache T = eltype(J) + termination_condition = cache.termination_condition(tc_storage) + singular, fact_done = _try_factorize_and_check_singular!(linsolve, J) if singular @@ -118,7 +137,8 @@ function perform_step!(cache::GeneralKlementCache{true}) _axpy!(α, du, u) f(cache.fu2, u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.stats.nsolve += 1 cache.stats.nfactors += 1 @@ -138,13 +158,17 @@ function perform_step!(cache::GeneralKlementCache{true}) mul!(cache.J_cache2, cache.J_cache, J) J .+= cache.J_cache2 + @. u_prev = u cache.fu .= cache.fu2 return nothing end function perform_step!(cache::GeneralKlementCache{false}) - @unpack fu, f, p, alg, J, linsolve = cache + @unpack fu, f, p, alg, J, linsolve, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + T = eltype(J) singular, fact_done = _try_factorize_and_check_singular!(linsolve, J) @@ -174,7 +198,10 @@ function perform_step!(cache::GeneralKlementCache{false}) cache.u = @. cache.u + α * cache.du # `u` might not support mutation cache.fu2 = f(cache.u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) + + cache.u_prev = @. cache.u cache.stats.nf += 1 cache.stats.nsolve += 1 cache.stats.nfactors += 1 @@ -198,7 +225,9 @@ function perform_step!(cache::GeneralKlementCache{false}) end function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.u, u0) @@ -208,7 +237,14 @@ function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = ca cache.u = u0 cache.fu = cache.f(cache.u, p) end + + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/lbroyden.jl b/src/lbroyden.jl index d045d0b20..db4353b41 100644 --- a/src/lbroyden.jl +++ b/src/lbroyden.jl @@ -34,6 +34,7 @@ end f alg u + u_prev du fu fu2 @@ -53,17 +54,21 @@ end internalnorm retcode::ReturnCode.T abstol + reltol reset_tolerance reset_check prob stats::NLStats lscache + termination_condition + tc_storage end get_fu(cache::LimitedMemoryBroydenCache) = cache.fu function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemoryBroyden, - args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, internalnorm = DEFAULT_NORM, kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) @@ -80,23 +85,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemory reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) : alg.reset_tolerance reset_check = x -> abs(x) ≤ reset_tolerance - return LimitedMemoryBroydenCache{iip}(f, alg, u, du, fu, zero(fu), + + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return LimitedMemoryBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu), zero(fu), p, U, Vᵀ, similar(u, threshold), similar(u, 1, threshold), zero(u), zero(u), false, 0, 0, alg.max_resets, maxiters, internalnorm, - ReturnCode.Default, abstol, reset_tolerance, reset_check, prob, + ReturnCode.Default, abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), - init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))) + init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition, + storage) end function perform_step!(cache::LimitedMemoryBroydenCache{true}) - @unpack f, p, du, u = cache + @unpack f, p, du, u, tc_storage = cache T = eltype(u) + termination_condition = cache.termination_condition(tc_storage) + α = perform_linesearch!(cache.lscache, u, du) _axpy!(α, du, u) f(cache.fu2, u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.force_stop && return nothing @@ -138,20 +158,25 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true}) cache.iterations_since_reset += 1 end + cache.u_prev .= cache.u cache.fu .= cache.fu2 return nothing end function perform_step!(cache::LimitedMemoryBroydenCache{false}) - @unpack f, p = cache + @unpack f, p, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + T = eltype(cache.u) α = perform_linesearch!(cache.lscache, cache.u, cache.du) cache.u = cache.u .+ α * cache.du cache.fu2 = f(cache.u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.force_stop && return nothing @@ -194,6 +219,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false}) cache.iterations_since_reset += 1 end + cache.u_prev = @. cache.u cache.fu = cache.fu2 return nothing diff --git a/src/levenberg.jl b/src/levenberg.jl index b4132ec9e..1fe2feed6 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -79,7 +79,8 @@ routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more [this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in `DᵀD` to prevent the damping from being too small. Defaults to `1e-8`. """ -@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD} +@concrete struct LevenbergMarquardt{CJ, AD, T} <: + AbstractNewtonAlgorithm{CJ, AD} ad::AD linsolve precs @@ -114,6 +115,7 @@ end f alg u + u_prev fu1 fu2 du @@ -127,6 +129,7 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob DᵀD JᵀJ @@ -153,11 +156,15 @@ end rhs_tmp J² stats::NLStats + termination_condition + tc_storage end function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt, - args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, + internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob @@ -178,6 +185,11 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, v = similar(du) end + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u); mode = NLSolveTerminationMode.AbsNorm) + λ = convert(eltype(u), alg.damping_initial) λ_factor = convert(eltype(u), alg.damping_increase_factor) damping_increase_factor = convert(eltype(u), alg.damping_increase_factor) @@ -203,6 +215,11 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, make_new_J = true fu_tmp = zero(fu1) + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + if _unwrap_val(linsolve_with_JᵀJ) mat_tmp = zero(JᵀJ) rhs_tmp = nothing @@ -215,16 +232,22 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, linsolve = __setup_linsolve(mat_tmp, rhs_tmp, u, p, alg) end - return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, fu1, + return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, copy(u), + fu1, fu2, du, p, uf, linsolve, J, - jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD, + jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, + DᵀD, JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic, b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp, - zero(u), zero(fu1), mat_tmp, rhs_tmp, J², NLStats(1, 0, 0, 0, 0)) + zero(u), zero(fu1), mat_tmp, rhs_tmp, J², NLStats(1, 0, 0, 0, 0), + termination_condition, storage) end function perform_step!(cache::LevenbergMarquardtCache{true, fastls}) where {fastls} - @unpack fu1, f, make_new_J = cache + @unpack fu1, f, make_new_J, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + if iszero(fu1) cache.force_stop = true return nothing @@ -243,7 +266,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastls}) where {fast cache.make_new_J = false cache.stats.njacs += 1 end - @unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache + @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache # Usual Levenberg-Marquardt step ("velocity"). # The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp @@ -300,7 +323,11 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastls}) where {fast if (1 - β)^b_uphill * loss ≤ loss_old # Accept step. cache.u .+= δ - if loss < cache.abstol + if termination_condition(cache.fu_tmp, + cache.u, + u_prev, + cache.abstol, + cache.reltol) cache.force_stop = true return nothing end @@ -312,13 +339,17 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastls}) where {fast cache.make_new_J = true end end + @. u_prev = u cache.λ *= cache.λ_factor cache.λ_factor = cache.damping_increase_factor return nothing end function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fastls} - @unpack fu1, f, make_new_J = cache + @unpack fu1, f, make_new_J, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + if iszero(fu1) cache.force_stop = true return nothing @@ -340,7 +371,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas cache.make_new_J = false cache.stats.njacs += 1 end - @unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache + + @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache # Usual Levenberg-Marquardt step ("velocity"). if fastls @@ -393,7 +425,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas if (1 - β)^b_uphill * loss ≤ loss_old # Accept step. cache.u += δ - if loss < cache.abstol + if termination_condition(fu_new, cache.u, u_prev, cache.abstol, cache.reltol) cache.force_stop = true return nothing end @@ -405,6 +437,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas cache.make_new_J = true end end + cache.u_prev = @. cache.u cache.λ *= cache.λ_factor cache.λ_factor = cache.damping_increase_factor return nothing diff --git a/src/pseudotransient.jl b/src/pseudotransient.jl index a041c258c..c5871b12f 100644 --- a/src/pseudotransient.jl +++ b/src/pseudotransient.jl @@ -3,9 +3,9 @@ precs = DEFAULT_PRECS, alpha_initial = 1e-3, adkwargs...) An implementation of PseudoTransient method that is used to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping to -integrate an initial value of nonlinear problem until sufficient accuracy in the desired steady-state is achieved to switch over to Newton's method and +integrate an initial value of nonlinear problem until sufficient accuracy in the desired steady-state is achieved to switch over to Newton's method and gain a rapid convergence. This implementation specifically uses "switched evolution relaxation" SER method. For detail information about the time-stepping and algorithm, -please see the paper: [Coffey, Todd S. and Kelley, C. T. and Keyes, David E. (2003), Pseudotransient Continuation and Differential-Algebraic Equations, +please see the paper: [Coffey, Todd S. and Kelley, C. T. and Keyes, David E. (2003), Pseudotransient Continuation and Differential-Algebraic Equations, SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X) ### Keyword Arguments @@ -27,7 +27,7 @@ SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S1064 preconditioners. For more information on specifying preconditioners for LinearSolve algorithms, consult the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/). - - `alpha_initial` : the initial pseudo time step. it defaults to 1e-3. If it is small, + - `alpha_initial` : the initial pseudo time step. it defaults to 1e-3. If it is small, you are going to need more iterations to converge but it can be more stable. """ @concrete struct PseudoTransient{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD} @@ -52,6 +52,7 @@ end f alg u + u_prev fu1 fu2 du @@ -67,15 +68,19 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob stats::NLStats + termination_condition + tc_storage end isinplace(::PseudoTransientCache{iip}) where {iip} = iip function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransient, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @@ -93,16 +98,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransi alpha = convert(eltype(u), alg.alpha_initial) res_norm = internalnorm(fu1) - return PseudoTransientCache{iip}(f, alg, u, fu1, fu2, du, p, alpha, res_norm, uf, + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return PseudoTransientCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, alpha, res_norm, + uf, linsolve, J, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, - prob, NLStats(1, 0, 0, 0, 0)) + reltol, + prob, NLStats(1, 0, 0, 0, 0), termination_condition, storage) end function perform_step!(cache::PseudoTransientCache{true}) - @unpack u, fu1, f, p, alg, J, linsolve, du, alpha = cache + @unpack u, u_prev, fu1, f, p, alg, J, linsolve, du, alpha, tc_storage = cache jacobian!!(J, cache) J_new = J - (1 / alpha) * I + termination_condition = cache.termination_condition(tc_storage) + # u = u - J \ fu linres = dolinsolve(alg.precs, linsolve; A = J_new, b = _vec(fu1), linu = _vec(du), p, reltol = cache.abstol) @@ -114,7 +133,10 @@ function perform_step!(cache::PseudoTransientCache{true}) cache.alpha *= cache.res_norm / new_norm cache.res_norm = new_norm - new_norm < cache.abstol && (cache.force_stop = true) + termination_condition(fu1, u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) + + @. u_prev = u cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 @@ -123,7 +145,10 @@ function perform_step!(cache::PseudoTransientCache{true}) end function perform_step!(cache::PseudoTransientCache{false}) - @unpack u, fu1, f, p, alg, linsolve, alpha = cache + @unpack u, u_prev, fu1, f, p, alg, linsolve, alpha, tc_storage = cache + + tc_storage = cache.tc_storage + termination_condition = cache.termination_condition(tc_storage) cache.J = jacobian!!(cache.J, cache) # u = u - J \ fu @@ -141,7 +166,9 @@ function perform_step!(cache::PseudoTransientCache{false}) new_norm = cache.internalnorm(fu1) cache.alpha *= cache.res_norm / new_norm cache.res_norm = new_norm - new_norm < cache.abstol && (cache.force_stop = true) + termination_condition(fu1, cache.u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) + cache.u_prev = @. cache.u cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 @@ -167,7 +194,9 @@ end function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p, alpha_new, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.u, u0) @@ -177,9 +206,17 @@ function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = c cache.u = u0 cache.fu1 = cache.f(cache.u, p) end + + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) + cache.alpha = convert(eltype(cache.u), alpha_new) cache.res_norm = cache.internalnorm(cache.fu1) cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/raphson.jl b/src/raphson.jl index 4a2b53404..a34d860ce 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -30,7 +30,8 @@ for large-scale and numerically-difficult nonlinear systems. which means that no line search is performed. Algorithms from `LineSearches.jl` can be used here directly, and they will be converted to the correct `LineSearch`. """ -@concrete struct NewtonRaphson{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD} +@concrete struct NewtonRaphson{CJ, AD} <: + AbstractNewtonAlgorithm{CJ, AD} ad::AD linsolve precs @@ -45,13 +46,17 @@ function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch) - return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch) + return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, + linsolve, + precs, + linesearch) end @concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip} f alg u + u_prev fu1 fu2 du @@ -65,13 +70,18 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob stats::NLStats lscache + termination_condition + tc_storage end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, + internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob @@ -80,16 +90,29 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs) - return NewtonRaphsonCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J, - jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J, + jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), - init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip))) + init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), + termination_condition, storage) end function perform_step!(cache::NewtonRaphsonCache{true}) - @unpack u, fu1, f, p, alg, J, linsolve, du = cache + @unpack u, u_prev, fu1, f, p, alg, J, linsolve, du, tc_storage = cache jacobian!!(J, cache) + termination_condition = cache.termination_condition(tc_storage) + # u = u - J \ fu linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du), p, reltol = cache.abstol) @@ -100,7 +123,10 @@ function perform_step!(cache::NewtonRaphsonCache{true}) _axpy!(-α, du, u) f(cache.fu1, u, p) - cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu1, u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) + + @. u_prev = u cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 @@ -109,7 +135,9 @@ function perform_step!(cache::NewtonRaphsonCache{true}) end function perform_step!(cache::NewtonRaphsonCache{false}) - @unpack u, fu1, f, p, alg, linsolve = cache + @unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) cache.J = jacobian!!(cache.J, cache) # u = u - J \ fu @@ -126,7 +154,10 @@ function perform_step!(cache::NewtonRaphsonCache{false}) cache.u = @. u - α * cache.du # `u` might not support mutation cache.fu1 = f(cache.u, p) - cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu1, cache.u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) + + cache.u_prev = @. cache.u cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 @@ -135,7 +166,9 @@ function perform_step!(cache::NewtonRaphsonCache{false}) end function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.u, u0) @@ -145,7 +178,14 @@ function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cac cache.u = u0 cache.fu1 = cache.f(cache.u, p) end + + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 8e7118cc6..3b14d0a38 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -147,7 +147,8 @@ for large-scale and numerically-difficult nonlinear systems. `linsolve` and `precs` are used exclusively for the inplace version of the algorithm. Support for the OOP version is planned! """ -@concrete struct TrustRegion{CJ, AD, MTR} <: AbstractNewtonAlgorithm{CJ, AD} +@concrete struct TrustRegion{CJ, AD, MTR} <: + AbstractNewtonAlgorithm{CJ, AD} ad::AD linsolve precs @@ -200,6 +201,7 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob radius_update_scheme::RadiusUpdateSchemes.T trust_r::trustType @@ -227,10 +229,14 @@ end p4::floatType ϵ::floatType stats::NLStats + tc_storage + termination_condition end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-8, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, + internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob @@ -332,13 +338,23 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, initial_trust_radius = convert(trustType, 1.0) end + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J, - jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, + jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold, shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new, H, g, shrink_counter, du, u_tmp, u_gauss_newton, u_cauchy, fu_new, make_new_J, r, p1, p2, p3, p4, ϵ, - NLStats(1, 0, 0, 0, 0)) + NLStats(1, 0, 0, 0, 0), storage, termination_condition) end function perform_step!(cache::TrustRegionCache{true}) @@ -414,7 +430,10 @@ function retrospective_step!(cache::TrustRegionCache) end function trust_region_step!(cache::TrustRegionCache) - @unpack fu_new, du, g, H, loss, max_trust_r, radius_update_scheme = cache + @unpack fu_new, du, g, H, loss, max_trust_r, radius_update_scheme, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + cache.loss_new = get_loss(fu_new) # Compute the ratio of the actual reduction to the predicted reduction. @@ -444,8 +463,11 @@ function trust_region_step!(cache::TrustRegionCache) # No need to make a new J, no step was taken, so we try again with a smaller trust_r cache.make_new_J = false end - - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol + if iszero(cache.fu) || termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) cache.force_stop = true end @@ -513,7 +535,12 @@ function trust_region_step!(cache::TrustRegionCache) cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(du) - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || + if iszero(cache.fu) || + termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) || cache.internalnorm(g) < cache.ϵ cache.force_stop = true end @@ -538,7 +565,12 @@ function trust_region_step!(cache::TrustRegionCache) @unpack p1 = cache cache.trust_r = p1 * cache.internalnorm(jvp!(cache)) - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || + if iszero(cache.fu) || + termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) || cache.internalnorm(g) < cache.ϵ cache.force_stop = true end @@ -562,7 +594,12 @@ function trust_region_step!(cache::TrustRegionCache) @unpack p1 = cache cache.trust_r = p1 * (cache.internalnorm(cache.fu)^0.99) - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || + if iszero(cache.fu) || + termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) || cache.internalnorm(g) < cache.ϵ cache.force_stop = true end @@ -580,7 +617,11 @@ function trust_region_step!(cache::TrustRegionCache) cache.trust_r *= cache.p2 cache.shrink_counter += 1 end - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol + if iszero(cache.fu) || termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) cache.force_stop = true end end @@ -683,7 +724,9 @@ end get_fu(cache::TrustRegionCache) = cache.fu function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.u, u0) @@ -693,7 +736,14 @@ function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache cache.u = u0 cache.fu = cache.f(cache.u, p) end + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) + cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/utils.jl b/src/utils.jl index a4231cf9a..7aa19f6ca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -210,6 +210,63 @@ function __get_concrete_algorithm(alg, prob) return set_ad(alg, ad) end +function _get_tolerance(η, tc_η, ::Type{T}) where {T} + fallback_η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) + return T(ifelse(η !== nothing, η, ifelse(tc_η !== nothing, tc_η, fallback_η))) +end + +function _init_termination_elements(abstol, + reltol, + termination_condition, + ::Type{T}; mode = NLSolveTerminationMode.NLSolveDefault) where {T} + if termination_condition !== nothing + abstol !== nothing ? + (abstol != termination_condition.abstol ? + error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") : + nothing) : nothing + reltol !== nothing ? + (reltol != termination_condition.abstol ? + error("Incompatible relative tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") : + nothing) : nothing + abstol = _get_tolerance(abstol, termination_condition.abstol, T) + reltol = _get_tolerance(reltol, termination_condition.reltol, T) + return abstol, reltol, termination_condition + else + abstol = _get_tolerance(abstol, nothing, T) + reltol = _get_tolerance(reltol, nothing, T) + termination_condition = NLSolveTerminationCondition(mode; + abstol, + reltol) + return abstol, reltol, termination_condition + end +end + +function _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) + if termination_condition != cache.termination_condition + if abstol != cache.abstol + if abstol != termination_condition.abstol + error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") + end + end + + if reltol != cache.reltol + if reltol != termination_condition.reltol + error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") + end + end + termination_condition + else + # Build the termination_condition with new abstol and reltol + return NLSolveTerminationCondition{ + DiffEqBase.get_termination_mode(termination_condition), + eltype(abstol), + typeof(termination_condition.safe_termination_options), + }(abstol, + reltol, + termination_condition.safe_termination_options) + end +end + __init_identity_jacobian(u::Number, _) = u function __init_identity_jacobian(u, fu) return convert(parameterless_type(_mutable(u)), diff --git a/test/23_test_problems.jl b/test/23_test_problems.jl index bd2b932dc..091088ab2 100644 --- a/test/23_test_problems.jl +++ b/test/23_test_problems.jl @@ -75,7 +75,7 @@ end alg_ops = (DFSane(),) broken_tests = Dict(alg => Int[] for alg in alg_ops) - broken_tests[alg_ops[1]] = [1, 2, 3, 5, 6, 8, 12, 13, 14, 21] + broken_tests[alg_ops[1]] = [1, 2, 3, 5, 6, 21] test_on_library(problems, dicts, alg_ops, broken_tests) end diff --git a/test/basictests.jl b/test/basictests.jl index 1606270fc..02f100e71 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,5 +1,5 @@ using BenchmarkTools, LinearSolve, NonlinearSolve, StaticArrays, Random, LinearAlgebra, - Test, ForwardDiff, Zygote, Enzyme, SparseDiffTools + Test, ForwardDiff, Zygote, Enzyme, SparseDiffTools, DiffEqBase _nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x)) @@ -72,7 +72,7 @@ end end @testset "[OOP] [Immutable AD]" begin - for p in 1.0:0.1:100.0 + for p in [1.0, 100.0] @test begin res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) res_true = sqrt(p) @@ -123,6 +123,20 @@ end probN = NonlinearProblem(quadratic_f, u0, 2.0) @test all(solve(probN, NewtonRaphson(; autodiff)).u .≈ sqrt(2.0)) end + + @testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T), + u0 in (1.0, [1.0, 1.0]) + + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, NewtonRaphson(); termination_condition).u .≈ sqrt(2.0)) + end end # --- TrustRegion tests --- @@ -281,6 +295,20 @@ end @test sol_iip.u ≈ sol_oop.u end end + + @testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T), + u0 in (1.0, [1.0, 1.0]) + + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, TrustRegion(); termination_condition).u .≈ sqrt(2.0)) + end end # --- LevenbergMarquardt tests --- @@ -391,6 +419,20 @@ end @test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10) end end + + @testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T), + u0 in (1.0, [1.0, 1.0]) + + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, LevenbergMarquardt(); termination_condition).u .≈ sqrt(2.0)) + end end # --- DFSane tests --- @@ -543,6 +585,20 @@ end @test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10) end end + + @testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T), + u0 in (1.0, [1.0, 1.0]) + + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, DFSane(); termination_condition).u .≈ sqrt(2.0)) + end end # --- PseudoTransient tests --- @@ -664,6 +720,22 @@ end sol = solve(probN, PseudoTransient(alpha_initial = 1.0), abstol = 1e-10) @test all(abs.(newton_fails(sol.u, p)) .< 1e-10) end + + @testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T), + u0 in (1.0, [1.0, 1.0]) + + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, + PseudoTransient(; alpha_initial = 10.0); + termination_condition).u .≈ sqrt(2.0)) + end end # --- GeneralBroyden tests ---