From d63e00b88bf41473979c0d500664d35e04c0b5bb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 29 Oct 2023 22:42:29 -0400 Subject: [PATCH 1/8] Rework the Termination Condition API to be type stable --- src/termination_conditions.jl | 191 ++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index ce310a490..a83cb51d9 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -1,3 +1,194 @@ +@enumx NonlinearSafeTerminationReturnCode begin + Success + Default + PatienceTermination + ProtectiveTermination + Failure +end + +abstract type AbstractNonlinearTerminationMode end +abstract type AbstractSafeNonlinearTerminationMode <: AbstractNonlinearTerminationMode end +abstract type AbstractSafeBestNonlinearTerminationMode <: + AbstractSafeNonlinearTerminationMode end + +# TODO: Add a mode where the user can pass in custom termination criteria function +for mode in (:SteadyStateDiffEqTerminationMode, :SimpleNonlinearSolveTerminationMode, + :NormTerminationMode, :RelTerminationMode, :RelNormTerminationMode, :AbsTerminationMode, + :AbsNormTerminationMode) + @eval begin + struct $(mode) <: AbstractNonlinearTerminationMode end + end +end + +for mode in (:RelSafeTerminationMode, :AbsSafeTerminationMode) + @eval begin + Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 30 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 + end + end +end + +for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode) + @eval begin + Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 30 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 + end + end +end + +mutable struct NonlinearTerminationModeCache{uType, T, + M <: AbstractNonlinearTerminationMode, I, OT} + u::uType + retcode::NonlinearSafeTerminationReturnCode.T + abstol::T + reltol::T + best_objective_value::T + mode::M + initial_objective::I + objectives_trace::OT + nsteps::Int +end + +function __update_u!!(cache::NonlinearTerminationModeCache, u) + cache.u === nothing && return + if ArrayInterface.can_setindex(cache.u) + copyto!(cache.u, u) + else + cache.u = u + end +end + +__cvt_real(::Type{T}, ::Nothing) where {T} = nothing +__cvt_real(::Type{T}, x) where {T} = real(T(x)) + +_get_tolerance(η, ::Type{T}) where {T} = __cvt_real(T, η) +function _get_tolerance(::Nothing, ::Type{T}) where {T} + η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) + return _get_tolerance(η, T) +end + +function SciMLBase.init(u::AbstractArray{T}, mode::AbstractNonlinearTerminationMode; + abstol = nothing, reltol = nothing, kwargs...) where {T} + abstol = _get_tolerance(abstol, T) + reltol = _get_tolerance(reltol, T) + best_value = __cvt_real(T, Inf) + TT = typeof(abstol) + u_ = mode isa AbstractSafeBestNonlinearTerminationMode ? + (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing + if mode isa AbstractSafeNonlinearTerminationMode + initial_objective = TT(0) + objectives_trace = Vector{TT}(undef, mode.patience_steps) + else + initial_objective = nothing + objectives_trace = nothing + end + return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode), + typeof(initial_objective), typeof(objectives_trace)}(u_, + NonlinearSafeTerminationReturnCode.Default, abstol, reltol, best_value, mode, + initial_objective, objectives_trace, 0) +end + +# This dispatch is needed based on how Terminating Callback works! +# This intentially drops the `abstol` and `reltol` arguments +function (cache::NonlinearTerminationModeCache)(integrator, _, _, min_t) + return cache(cache.mode, get_du(integrator), integrator.u, integrator.uprev) +end +(cache::NonlinearTerminationModeCache)(du, u, uprev) = cache(cache.mode, du, u, uprev) + +function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du, + u, uprev) + return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) +end + +function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode, + du, u, uprev) + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + objective = NLSOLVE_DEFAULT_NORM(du) + criteria = cache.abstol + else + objective = NLSOLVE_DEFAULT_NORM(du) / + (NLSOLVE_DEFAULT_NORM(du .+ u) + eps(cache.abstol)) + criteria = cache.reltol + end + + # Check if best solution + if mode isa AbstractSafeBestNonlinearTerminationMode && + objective < cache.best_objective_value + cache.best_objective_value = objective + __update_u!!(cache, u) + end + + # Main Termination Condition + if objective ≤ criteria + cache.retcode = NonlinearSafeTerminationReturnCode.Success + return true + end + + # Terminate if there has been no improvement for the last `patience_steps` + cache.nsteps += 1 + cache.nsteps == 1 && (cache.initial_objective = objective) + cache.objectives_trace[mod1(cache.nsteps, length(cache.objectives_trace))] = objective + + if objective ≤ cache.mode.patience_objective_multiplier * criteria + if cache.nsteps ≥ cache.mode.patience_steps + if cache.nsteps < length(cache.objectives_trace) + min_obj, max_obj = extrema(@view(cache.objectives_trace[1:cache.nsteps])) + else + min_obj, max_obj = extrema(cache.objectives_trace) + end + if min_obj < cache.mode.min_max_factor * max_obj + cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination + return true + end + end + end + + # Protective Break + if objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du) + cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination + return true + end + + cache.retcode = NonlinearSafeTerminationReturnCode.Failure + return false +end + +function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, + reltol) + return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) +end +function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, + reltol) + return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) || + isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol) +end +function check_convergence(::NormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + du_norm = NLSOLVE_DEFAULT_NORM(duₙ) + return du_norm ≤ abstol || du_norm ≤ reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ) +end +function check_convergence(::RelNormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return all(abs.(duₙ) .≤ reltol .* abs.(uₙ)) +end +function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, + RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return NLSOLVE_DEFAULT_NORM(duₙ) ≤ reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ) +end +function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return all(abs.(duₙ) .≤ abstol) +end +function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode, + AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return NLSOLVE_DEFAULT_NORM(duₙ) ≤ abstol +end + +# NOTE: Deprecate the following API eventually. This API leads to quite a bit of type +# instability @enumx NLSolveSafeTerminationReturnCode begin Success PatienceTermination From 7e3610e759c096e9776e5d146ef407cbd2bd9a6b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 29 Oct 2023 23:00:18 -0400 Subject: [PATCH 2/8] use full form --- src/common_defaults.jl | 2 +- src/termination_conditions.jl | 32 +++++++++++++++++--------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/common_defaults.jl b/src/common_defaults.jl index 2cf12e3c5..a17d81f21 100644 --- a/src/common_defaults.jl +++ b/src/common_defaults.jl @@ -42,7 +42,7 @@ end end @inline ODE_DEFAULT_NORM(u, t) = norm(u) -@inline NLSOLVE_DEFAULT_NORM(u) = ODE_DEFAULT_NORM(u, nothing) +@inline NONLINEARSOLVE_DEFAULT_NORM(u) = ODE_DEFAULT_NORM(u, nothing) @inline ODE_DEFAULT_ISOUTOFDOMAIN(u, p, t) = false @inline function ODE_DEFAULT_PROG_MESSAGE(dt, u::Array, p, t) diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index a83cb51d9..a24c1a434 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -109,11 +109,11 @@ end function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode, du, u, uprev) if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode - objective = NLSOLVE_DEFAULT_NORM(du) + objective = NONLINEARSOLVE_DEFAULT_NORM(du) criteria = cache.abstol else - objective = NLSOLVE_DEFAULT_NORM(du) / - (NLSOLVE_DEFAULT_NORM(du .+ u) + eps(cache.abstol)) + objective = NONLINEARSOLVE_DEFAULT_NORM(du) / + (NONLINEARSOLVE_DEFAULT_NORM(du .+ u) + eps(cache.abstol)) criteria = cache.reltol end @@ -138,7 +138,7 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi if objective ≤ cache.mode.patience_objective_multiplier * criteria if cache.nsteps ≥ cache.mode.patience_steps if cache.nsteps < length(cache.objectives_trace) - min_obj, max_obj = extrema(@view(cache.objectives_trace[1:cache.nsteps])) + min_obj, max_obj = extrema(@view(cache.objectives_trace[1:(cache.nsteps)])) else min_obj, max_obj = extrema(cache.objectives_trace) end @@ -169,22 +169,23 @@ function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, u isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol) end function check_convergence(::NormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) - du_norm = NLSOLVE_DEFAULT_NORM(duₙ) - return du_norm ≤ abstol || du_norm ≤ reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ) + du_norm = NONLINEARSOLVE_DEFAULT_NORM(duₙ) + return du_norm ≤ abstol || du_norm ≤ reltol * NONLINEARSOLVE_DEFAULT_NORM(duₙ .+ uₙ) end function check_convergence(::RelNormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) return all(abs.(duₙ) .≤ reltol .* abs.(uₙ)) end function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) - return NLSOLVE_DEFAULT_NORM(duₙ) ≤ reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ) + return NONLINEARSOLVE_DEFAULT_NORM(duₙ) ≤ + reltol * NONLINEARSOLVE_DEFAULT_NORM(duₙ .+ uₙ) end function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) return all(abs.(duₙ) .≤ abstol) end function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode, AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) - return NLSOLVE_DEFAULT_NORM(duₙ) ≤ abstol + return NONLINEARSOLVE_DEFAULT_NORM(duₙ) ≤ abstol end # NOTE: Deprecate the following API eventually. This API leads to quite a bit of type @@ -363,11 +364,11 @@ function (cond::NLSolveTerminationCondition)(storage::Union{ end if mode ∈ SAFE_BEST_TERMINATION_MODES - objective = NLSOLVE_DEFAULT_NORM(du) + objective = NONLINEARSOLVE_DEFAULT_NORM(du) criteria = abstol else - objective = NLSOLVE_DEFAULT_NORM(du) / - (NLSOLVE_DEFAULT_NORM(du .+ u) + eps(aType)) + objective = NONLINEARSOLVE_DEFAULT_NORM(du) / + (NONLINEARSOLVE_DEFAULT_NORM(du .+ u) + eps(aType)) criteria = reltol end @@ -426,18 +427,19 @@ end @inline @inbounds function _has_converged(du, u, uprev, mode, abstol, reltol) if mode == NLSolveTerminationMode.Norm - du_norm = NLSOLVE_DEFAULT_NORM(du) - return du_norm ≤ abstol || du_norm ≤ reltol * NLSOLVE_DEFAULT_NORM(du + u) + du_norm = NONLINEARSOLVE_DEFAULT_NORM(du) + return du_norm ≤ abstol || du_norm ≤ reltol * NONLINEARSOLVE_DEFAULT_NORM(du + u) elseif mode == NLSolveTerminationMode.Rel return all(abs.(du) .≤ reltol .* abs.(u)) elseif mode ∈ (NLSolveTerminationMode.RelNorm, NLSolveTerminationMode.RelSafe, NLSolveTerminationMode.RelSafeBest) - return NLSOLVE_DEFAULT_NORM(du) ≤ reltol * NLSOLVE_DEFAULT_NORM(du .+ u) + return NONLINEARSOLVE_DEFAULT_NORM(du) ≤ + reltol * NONLINEARSOLVE_DEFAULT_NORM(du .+ u) elseif mode == NLSolveTerminationMode.Abs return all(abs.(du) .≤ abstol) elseif mode ∈ (NLSolveTerminationMode.AbsNorm, NLSolveTerminationMode.AbsSafe, NLSolveTerminationMode.AbsSafeBest) - return NLSOLVE_DEFAULT_NORM(du) ≤ abstol + return NONLINEARSOLVE_DEFAULT_NORM(du) ≤ abstol elseif mode == NLSolveTerminationMode.SteadyStateDefault return all((abs.(du) .≤ abstol) .| (abs.(du) .≤ reltol .* abs.(u))) elseif mode == NLSolveTerminationMode.NLSolveDefault From 9b2890fa070eb9fe46af1d11cc515d90822bb4fd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 30 Oct 2023 13:19:31 -0400 Subject: [PATCH 3/8] Extend the API a bit more --- Project.toml | 2 +- src/DiffEqBase.jl | 8 ++++++-- src/termination_conditions.jl | 9 +++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index a41b61377..1a4f33d1a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.135.0" +version = "6.136.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 0ece946cb..c73daadc0 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -171,9 +171,13 @@ export initialize!, finalize! export SensitivityADPassThrough +export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode, + NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode, + AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode, + RelSafeBestTerminationMode, AbsSafeBestTerminationMode +# Deprecated API export NLSolveTerminationMode, - NLSolveSafeTerminationOptions, NLSolveTerminationCondition, - NLSolveSafeTerminationResult + NLSolveSafeTerminationOptions, NLSolveTerminationCondition, NLSolveSafeTerminationResult export KeywordArgError, KeywordArgWarn, KeywordArgSilent diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index a24c1a434..487bb8802 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -55,6 +55,10 @@ mutable struct NonlinearTerminationModeCache{uType, T, nsteps::Int end +get_termination_mode(cache::NonlinearTerminationModeCache) = cache.mode +get_abstol(cache::NonlinearTerminationModeCache) = cache.abstol +get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol + function __update_u!!(cache::NonlinearTerminationModeCache, u) cache.u === nothing && return if ArrayInterface.can_setindex(cache.u) @@ -73,8 +77,9 @@ function _get_tolerance(::Nothing, ::Type{T}) where {T} return _get_tolerance(η, T) end -function SciMLBase.init(u::AbstractArray{T}, mode::AbstractNonlinearTerminationMode; - abstol = nothing, reltol = nothing, kwargs...) where {T} +function SciMLBase.init(u::Union{AbstractArray{T}, T}, + mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing, + kwargs...) where {T <: Number} abstol = _get_tolerance(abstol, T) reltol = _get_tolerance(reltol, T) best_value = __cvt_real(T, Inf) From 89c784a556f6a9200c4dccbf2f6d264df3080d97 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 30 Oct 2023 15:39:55 -0400 Subject: [PATCH 4/8] Better initial objective --- src/termination_conditions.jl | 36 +++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 487bb8802..8ddd911b3 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -61,7 +61,7 @@ get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol function __update_u!!(cache::NonlinearTerminationModeCache, u) cache.u === nothing && return - if ArrayInterface.can_setindex(cache.u) + if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u) copyto!(cache.u, u) else cache.u = u @@ -77,21 +77,27 @@ function _get_tolerance(::Nothing, ::Type{T}) where {T} return _get_tolerance(η, T) end -function SciMLBase.init(u::Union{AbstractArray{T}, T}, +function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T}, mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing, kwargs...) where {T <: Number} abstol = _get_tolerance(abstol, T) reltol = _get_tolerance(reltol, T) - best_value = __cvt_real(T, Inf) TT = typeof(abstol) u_ = mode isa AbstractSafeBestNonlinearTerminationMode ? (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing if mode isa AbstractSafeNonlinearTerminationMode - initial_objective = TT(0) + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + initial_objective = NONLINEARSOLVE_DEFAULT_NORM(du) + else + initial_objective = NONLINEARSOLVE_DEFAULT_NORM(du) / + (NONLINEARSOLVE_DEFAULT_NORM(du .+ u) + eps(TT)) + end objectives_trace = Vector{TT}(undef, mode.patience_steps) + best_value = initial_objective else initial_objective = nothing objectives_trace = nothing + best_value = __cvt_real(T, Inf) end return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode), typeof(initial_objective), typeof(objectives_trace)}(u_, @@ -122,6 +128,13 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi criteria = cache.reltol end + # Protective Break + if isinf(objective) || isnan(objective) || + (objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du)) + cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination + return true + end + # Check if best solution if mode isa AbstractSafeBestNonlinearTerminationMode && objective < cache.best_objective_value @@ -154,12 +167,6 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi end end - # Protective Break - if objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du) - cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination - return true - end - cache.retcode = NonlinearSafeTerminationReturnCode.Failure return false end @@ -238,9 +245,10 @@ function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64, best_objective_value_iteration = 0, return_code = NLSolveSafeTerminationReturnCode.Failure) u = u !== nothing ? copy(u) : u + Base.depwarn("NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", + :NLSolveSafeTerminationResult) return NLSolveSafeTerminationResult{typeof(best_objective_value), typeof(u)}(u, - best_objective_value, - best_objective_value_iteration, return_code) + best_objective_value, best_objective_value_iteration, return_code) end const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault, @@ -296,6 +304,8 @@ Define the termination criteria for the NonlinearProblem or SteadyStateProblem. * `protective_threshold`: If the objective value increased by this factor wrt initial objective terminate immediately. * `patience_steps`: If objective is within `patience_objective_multiplier` factor of the criteria and no improvement within `min_max_factor` has happened then terminate. +!!! warning + This has been deprecated and will be removed in the next major release. Please use the new dispatch based termination conditions API. """ struct NLSolveTerminationCondition{mode, T, S <: Union{<:NLSolveSafeTerminationOptions, Nothing}} @@ -323,6 +333,8 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, protective_threshold = 1e3, patience_steps::Int = 30, patience_objective_multiplier = 3, min_max_factor = 1.3) where {T} + Base.depwarn("NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", + :NLSolveTerminationCondition) @assert mode ∈ instances(NLSolveTerminationMode.T) options = if mode ∈ SAFE_TERMINATION_MODES NLSolveSafeTerminationOptions(protective_threshold, patience_steps, From a510a8ff9c6087518bfccdcecee0c3b0fef02ba8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 30 Oct 2023 15:56:34 -0400 Subject: [PATCH 5/8] Fix dispatch --- src/termination_conditions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 8ddd911b3..18a5b144a 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -184,7 +184,7 @@ function check_convergence(::NormTerminationMode, duₙ, uₙ, uₙ₋₁, absto du_norm = NONLINEARSOLVE_DEFAULT_NORM(duₙ) return du_norm ≤ abstol || du_norm ≤ reltol * NONLINEARSOLVE_DEFAULT_NORM(duₙ .+ uₙ) end -function check_convergence(::RelNormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) +function check_convergence(::RelTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) return all(abs.(duₙ) .≤ reltol .* abs.(uₙ)) end function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, From 26b99eb2de9beb457984e883a88efda3f4135a8e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 30 Oct 2023 16:04:54 -0400 Subject: [PATCH 6/8] Update default --- src/termination_conditions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 18a5b144a..7917c9e83 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -35,7 +35,7 @@ for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode) @eval begin Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode protective_threshold::T1 = 1000 - patience_steps::Int = 30 + patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 end From 8fccdcca328edb4b135052b35ec7d9dcfb299181 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 30 Oct 2023 16:34:59 -0400 Subject: [PATCH 7/8] Docs --- src/termination_conditions.jl | 230 +++++++++++++++++++++++++++------- 1 file changed, 188 insertions(+), 42 deletions(-) diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 7917c9e83..67d19c2fb 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -1,8 +1,39 @@ +""" + NonlinearSafeTerminationReturnCode + +Return Codes for the safe nonlinear termination conditions. +""" @enumx NonlinearSafeTerminationReturnCode begin + """ + NonlinearSafeTerminationReturnCode.Success + + Termination Condition was satisfied! + """ Success + """ + NonlinearSafeTerminationReturnCode.Default + + Default Return Code. Used for type stability and conveys no additional information! + """ Default + """ + NonlinearSafeTerminationReturnCode.PatienceTermination + + Terminate if there has been no improvement for the last `patience_steps`. + """ PatienceTermination + """ + NonlinearSafeTerminationReturnCode.ProtectiveTermination + + Terminate if the objective value increased by this factor wrt initial objective or the + value diverged. + """ ProtectiveTermination + """ + NonlinearSafeTerminationReturnCode.Failure + + Termination Condition was not satisfied! + """ Failure end @@ -12,34 +43,149 @@ abstract type AbstractSafeBestNonlinearTerminationMode <: AbstractSafeNonlinearTerminationMode end # TODO: Add a mode where the user can pass in custom termination criteria function -for mode in (:SteadyStateDiffEqTerminationMode, :SimpleNonlinearSolveTerminationMode, - :NormTerminationMode, :RelTerminationMode, :RelNormTerminationMode, :AbsTerminationMode, - :AbsNormTerminationMode) - @eval begin - struct $(mode) <: AbstractNonlinearTerminationMode end - end + +""" + SteadyStateDiffEqTerminationMode <: AbstractNonlinearTerminationMode + +Check if all values of the derivative is close to zero wrt both relative and absolute +tolerance. + +The default used in SteadyStateDiffEq.jl! Not recommended for large problems, since the +convergence criteria is very strict and never reliably satisfied for most problems. +""" +struct SteadyStateDiffEqTerminationMode <: AbstractNonlinearTerminationMode end + +""" + SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode + +Check if all values of the derivative is close to zero wrt both relative and absolute +tolerance. Or check that the value of the current and previous state is within the specified +tolerances. + +The default used in SimpleNonlinearSolve.jl! Not recommended for large problems, since the +convergence criteria is very strict and never reliably satisfied for most problems. +""" +struct SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + NormTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if +``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` +or ``\| \frac{\partial u}{\partial t} \| \leq abstol`` +""" +struct NormTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + RelTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if +``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)``. +""" +struct RelTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + RelNormTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if +``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` +""" +struct RelNormTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + AbsTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)``. +""" +struct AbsTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + AbsNormTerminationMode <: AbstractNonlinearTerminationMode + +Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol``. +""" +struct AbsNormTerminationMode <: AbstractNonlinearTerminationMode end + +@doc doc""" + RelSafeTerminationMode <: AbstractSafeNonlinearTerminationMode + +Essentially [`RelNormTerminationMode`](@ref) + terminate if there has been no improvement +for the last `patience_steps` + terminate if the solution blows up (diverges). + +## Constructor + +```julia +RelSafeTerminationMode(; protective_threshold = 1e3, patience_steps = 100, + patience_objective_multiplier = 3, min_max_factor = 1.3) +``` +""" +Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3} <: + AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 100 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 end -for mode in (:RelSafeTerminationMode, :AbsSafeTerminationMode) - @eval begin - Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeNonlinearTerminationMode - protective_threshold::T1 = 1000 - patience_steps::Int = 30 - patience_objective_multiplier::T2 = 3 - min_max_factor::T3 = 1.3 - end - end +@doc doc""" + AbsSafeTerminationMode <: AbstractSafeNonlinearTerminationMode + +Essentially [`AbsNormTerminationMode`](@ref) + terminate if there has been no improvement +for the last `patience_steps` + terminate if the solution blows up (diverges). + +## Constructor + +```julia +AbsSafeTerminationMode(; protective_threshold = 1e3, patience_steps = 100, + patience_objective_multiplier = 3, min_max_factor = 1.3) +``` +""" +Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3} <: + AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 100 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 end -for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode) - @eval begin - Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode - protective_threshold::T1 = 1000 - patience_steps::Int = 100 - patience_objective_multiplier::T2 = 3 - min_max_factor::T3 = 1.3 - end - end +@doc doc""" + RelSafeBestTerminationMode <: AbstractSafeBestNonlinearTerminationMode + +Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found so far. + +## Constructor + +```julia +RelSafeBestTerminationMode(; protective_threshold = 1e3, patience_steps = 100, + patience_objective_multiplier = 3, min_max_factor = 1.3) +``` +""" +Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3} <: + AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 100 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 +end + +@doc doc""" + AbsSafeBestTerminationMode <: AbstractSafeBestNonlinearTerminationMode + +Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found so far. + +## Constructor + +```julia +AbsSafeBestTerminationMode(; protective_threshold = 1e3, patience_steps = 100, + patience_objective_multiplier = 3, min_max_factor = 1.3) +``` +""" +Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} <: + AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 100 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 end mutable struct NonlinearTerminationModeCache{uType, T, @@ -78,8 +224,8 @@ function _get_tolerance(::Nothing, ::Type{T}) where {T} end function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T}, - mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing, - kwargs...) where {T <: Number} + mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing, + kwargs...) where {T <: Number} abstol = _get_tolerance(abstol, T) reltol = _get_tolerance(reltol, T) TT = typeof(abstol) @@ -113,12 +259,12 @@ end (cache::NonlinearTerminationModeCache)(du, u, uprev) = cache(cache.mode, du, u, uprev) function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du, - u, uprev) + u, uprev) return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) end function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode, - du, u, uprev) + du, u, uprev) if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode objective = NONLINEARSOLVE_DEFAULT_NORM(du) criteria = cache.abstol @@ -130,7 +276,7 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi # Protective Break if isinf(objective) || isnan(objective) || - (objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du)) + (objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du)) cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination return true end @@ -172,11 +318,11 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi end function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, - reltol) + reltol) return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) end function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, - reltol) + reltol) return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) || isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol) end @@ -188,7 +334,7 @@ function check_convergence(::RelTerminationMode, duₙ, uₙ, uₙ₋₁, abstol return all(abs.(duₙ) .≤ reltol .* abs.(uₙ)) end function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, - RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) + RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) return NONLINEARSOLVE_DEFAULT_NORM(duₙ) ≤ reltol * NONLINEARSOLVE_DEFAULT_NORM(duₙ .+ uₙ) end @@ -196,7 +342,7 @@ function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol return all(abs.(duₙ) .≤ abstol) end function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode, - AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) + AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) return NONLINEARSOLVE_DEFAULT_NORM(duₙ) ≤ abstol end @@ -242,8 +388,8 @@ mutable struct NLSolveSafeTerminationResult{T, uType} end function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64, - best_objective_value_iteration = 0, - return_code = NLSolveSafeTerminationReturnCode.Failure) + best_objective_value_iteration = 0, + return_code = NLSolveSafeTerminationReturnCode.Failure) u = u !== nothing ? copy(u) : u Base.depwarn("NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", :NLSolveSafeTerminationResult) @@ -330,9 +476,9 @@ get_termination_mode(::NLSolveTerminationCondition{mode}) where {mode} = mode # Don't specify `mode` since the defaults would depend on the package function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, - protective_threshold = 1e3, patience_steps::Int = 30, - patience_objective_multiplier = 3, - min_max_factor = 1.3) where {T} + protective_threshold = 1e3, patience_steps::Int = 30, + patience_objective_multiplier = 3, + min_max_factor = 1.3) where {T} Base.depwarn("NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", :NLSolveTerminationCondition) @assert mode ∈ instances(NLSolveTerminationMode.T) @@ -346,9 +492,9 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, end function (cond::NLSolveTerminationCondition)(storage::Union{ - NLSolveSafeTerminationResult, - Nothing, -}) + NLSolveSafeTerminationResult, + Nothing, + }) mode = get_termination_mode(cond) # We need both the dispatches to support solvers that don't use the integrator # interface like SimpleNonlinearSolve @@ -438,7 +584,7 @@ end # Convergence Criteria @inline function _has_converged(du, u, uprev, cond::NLSolveTerminationCondition{mode}, - abstol = cond.abstol, reltol = cond.reltol) where {mode} + abstol = cond.abstol, reltol = cond.reltol) where {mode} return _has_converged(du, u, uprev, mode, abstol, reltol) end From 43206a74c280d8f7346b5efbe90e0750a7cfdcee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 30 Oct 2023 19:49:44 -0400 Subject: [PATCH 8/8] Cancel intermediate builds --- .github/workflows/CI.yml | 5 +++++ .github/workflows/Downstream.yml | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 866c78058..79b02d7cb 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -6,6 +6,11 @@ on: push: branches: - master +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: runs-on: ubuntu-latest diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 812a3aa63..efffee70d 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -4,7 +4,11 @@ on: branches: [master] tags: [v*] pull_request: - +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: name: ${{ matrix.package.repo }}/${{ matrix.package.group }}/${{ matrix.julia-version }}