diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 7917c9e83..67d19c2fb 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -1,8 +1,39 @@ +""" + NonlinearSafeTerminationReturnCode + +Return Codes for the safe nonlinear termination conditions. +""" @enumx NonlinearSafeTerminationReturnCode begin + """ + NonlinearSafeTerminationReturnCode.Success + + Termination Condition was satisfied! + """ Success + """ + NonlinearSafeTerminationReturnCode.Default + + Default Return Code. Used for type stability and conveys no additional information! + """ Default + """ + NonlinearSafeTerminationReturnCode.PatienceTermination + + Terminate if there has been no improvement for the last `patience_steps`. + """ PatienceTermination + """ + NonlinearSafeTerminationReturnCode.ProtectiveTermination + + Terminate if the objective value increased by this factor wrt initial objective or the + value diverged. + """ ProtectiveTermination + """ + NonlinearSafeTerminationReturnCode.Failure + + Termination Condition was not satisfied! + """ Failure end @@ -12,34 +43,149 @@ abstract type AbstractSafeBestNonlinearTerminationMode <: AbstractSafeNonlinearTerminationMode end # TODO: Add a mode where the user can pass in custom termination criteria function -for mode in (:SteadyStateDiffEqTerminationMode, :SimpleNonlinearSolveTerminationMode, - :NormTerminationMode, :RelTerminationMode, :RelNormTerminationMode, :AbsTerminationMode, - :AbsNormTerminationMode) - @eval begin - struct $(mode) <: AbstractNonlinearTerminationMode end - end + +""" + SteadyStateDiffEqTerminationMode <: AbstractNonlinearTerminationMode + +Check if all values of the derivative is close to zero wrt both relative and absolute +tolerance. + +The default used in SteadyStateDiffEq.jl! Not recommended for large problems, since the +convergence criteria is very strict and never reliably satisfied for most problems. +""" +struct SteadyStateDiffEqTerminationMode <: AbstractNonlinearTerminationMode end + +""" + SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode + +Check if all values of the derivative is close to zero wrt both relative and absolute +tolerance. Or check that the value of the current and previous state is within the specified +tolerances. + +The default used in SimpleNonlinearSolve.jl! Not recommended for large problems, since the +convergence criteria is very strict and never reliably satisfied for most problems. +""" +struct SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + NormTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if +``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` +or ``\| \frac{\partial u}{\partial t} \| \leq abstol`` +""" +struct NormTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + RelTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if +``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)``. +""" +struct RelTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + RelNormTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if +``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` +""" +struct RelNormTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + AbsTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)``. +""" +struct AbsTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + AbsNormTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol``. +""" +struct AbsNormTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + RelSafeTerminationMode <: AbstractSafeNonlinearTerminationMode + +Essentially [`RelNormTerminationMode`](@ref) + terminate if there has been no improvement +for the last `patience_steps` + terminate if the solution blows up (diverges). + +## Constructor + +```julia +RelSafeTerminationMode(; protective_threshold = 1e3, patience_steps = 100, + patience_objective_multiplier = 3, min_max_factor = 1.3) +``` +""" +Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3} <: + AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 100 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 end -for mode in (:RelSafeTerminationMode, :AbsSafeTerminationMode) - @eval begin - Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeNonlinearTerminationMode - protective_threshold::T1 = 1000 - patience_steps::Int = 30 - patience_objective_multiplier::T2 = 3 - min_max_factor::T3 = 1.3 - end - end +@doc doc""" + AbsSafeTerminationMode <: AbstractSafeNonlinearTerminationMode + +Essentially [`AbsNormTerminationMode`](@ref) + terminate if there has been no improvement +for the last `patience_steps` + terminate if the solution blows up (diverges). + +## Constructor + +```julia +AbsSafeTerminationMode(; protective_threshold = 1e3, patience_steps = 100, + patience_objective_multiplier = 3, min_max_factor = 1.3) +``` +""" +Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3} <: + AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 100 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 end -for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode) - @eval begin - Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode - protective_threshold::T1 = 1000 - patience_steps::Int = 100 - patience_objective_multiplier::T2 = 3 - min_max_factor::T3 = 1.3 - end - end +@doc doc""" + RelSafeBestTerminationMode <: AbstractSafeBestNonlinearTerminationMode + +Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found so far. + +## Constructor + +```julia +RelSafeBestTerminationMode(; protective_threshold = 1e3, patience_steps = 100, + patience_objective_multiplier = 3, min_max_factor = 1.3) +``` +""" +Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3} <: + AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 100 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 +end + +@doc doc""" + AbsSafeBestTerminationMode <: AbstractSafeBestNonlinearTerminationMode + +Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found so far. + +## Constructor + +```julia +AbsSafeBestTerminationMode(; protective_threshold = 1e3, patience_steps = 100, + patience_objective_multiplier = 3, min_max_factor = 1.3) +``` +""" +Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} <: + AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 100 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 end mutable struct NonlinearTerminationModeCache{uType, T, @@ -78,8 +224,8 @@ function _get_tolerance(::Nothing, ::Type{T}) where {T} end function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T}, - mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing, - kwargs...) where {T <: Number} + mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing, + kwargs...) where {T <: Number} abstol = _get_tolerance(abstol, T) reltol = _get_tolerance(reltol, T) TT = typeof(abstol) @@ -113,12 +259,12 @@ end (cache::NonlinearTerminationModeCache)(du, u, uprev) = cache(cache.mode, du, u, uprev) function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du, - u, uprev) + u, uprev) return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) end function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode, - du, u, uprev) + du, u, uprev) if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode objective = NONLINEARSOLVE_DEFAULT_NORM(du) criteria = cache.abstol @@ -130,7 +276,7 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi # Protective Break if isinf(objective) || isnan(objective) || - (objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du)) + (objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du)) cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination return true end @@ -172,11 +318,11 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi end function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, - reltol) + reltol) return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) end function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, - reltol) + reltol) return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) || isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol) end @@ -188,7 +334,7 @@ function check_convergence(::RelTerminationMode, duₙ, uₙ, uₙ₋₁, abstol return all(abs.(duₙ) .≤ reltol .* abs.(uₙ)) end function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, - RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) + RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) return NONLINEARSOLVE_DEFAULT_NORM(duₙ) ≤ reltol * NONLINEARSOLVE_DEFAULT_NORM(duₙ .+ uₙ) end @@ -196,7 +342,7 @@ function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol return all(abs.(duₙ) .≤ abstol) end function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode, - AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) + AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) return NONLINEARSOLVE_DEFAULT_NORM(duₙ) ≤ abstol end @@ -242,8 +388,8 @@ mutable struct NLSolveSafeTerminationResult{T, uType} end function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64, - best_objective_value_iteration = 0, - return_code = NLSolveSafeTerminationReturnCode.Failure) + best_objective_value_iteration = 0, + return_code = NLSolveSafeTerminationReturnCode.Failure) u = u !== nothing ? copy(u) : u Base.depwarn("NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", :NLSolveSafeTerminationResult) @@ -330,9 +476,9 @@ get_termination_mode(::NLSolveTerminationCondition{mode}) where {mode} = mode # Don't specify `mode` since the defaults would depend on the package function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, - protective_threshold = 1e3, patience_steps::Int = 30, - patience_objective_multiplier = 3, - min_max_factor = 1.3) where {T} + protective_threshold = 1e3, patience_steps::Int = 30, + patience_objective_multiplier = 3, + min_max_factor = 1.3) where {T} Base.depwarn("NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", :NLSolveTerminationCondition) @assert mode ∈ instances(NLSolveTerminationMode.T) @@ -346,9 +492,9 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, end function (cond::NLSolveTerminationCondition)(storage::Union{ - NLSolveSafeTerminationResult, - Nothing, -}) + NLSolveSafeTerminationResult, + Nothing, + }) mode = get_termination_mode(cond) # We need both the dispatches to support solvers that don't use the integrator # interface like SimpleNonlinearSolve @@ -438,7 +584,7 @@ end # Convergence Criteria @inline function _has_converged(du, u, uprev, cond::NLSolveTerminationCondition{mode}, - abstol = cond.abstol, reltol = cond.reltol) where {mode} + abstol = cond.abstol, reltol = cond.reltol) where {mode} return _has_converged(du, u, uprev, mode, abstol, reltol) end