Skip to content

Commit

Permalink
Use a mutable struct instead of Dict for Safe termination
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 12, 2023
1 parent 3b6e39b commit 0e1ce2a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
*.jl.*.mem
Manifest.toml
.DS_Store
.vscode
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqBase"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
authors = ["Chris Rackauckas <[email protected]>"]
version = "6.122.2"
version = "6.123.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
3 changes: 2 additions & 1 deletion src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ export initialize!, finalize!

export SensitivityADPassThrough

export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition
export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition,
NLSolveSafeTerminationResult

export KeywordArgError, KeywordArgWarn, KeywordArgSilent

Expand Down
45 changes: 36 additions & 9 deletions src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ struct NLSolveSafeTerminationOptions{T1, T2, T3}
min_max_factor::T3
end

TruncatedStacktraces.@truncate_stacktrace NLSolveSafeTerminationOptions

Base.@kwdef mutable struct NLSolveSafeTerminationResult{T}
best_objective_value::T = Inf64
best_objective_value_iteration::Int = 0
return_code::NLSolveSafeTerminationReturnCode.T = NLSolveSafeTerminationReturnCode.Failure
end

# Remove once support for AbstractDict has been dropped
function __setproperty!(n::NLSolveSafeTerminationResult, prop::Symbol, value)
setproperty!(n, prop, value)
end
function __setproperty!(d::AbstractDict, prop::Symbol, value)
d[prop] = value
end

const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault,
NLSolveTerminationMode.NLSolveDefault,
NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel,
Expand Down Expand Up @@ -89,6 +105,8 @@ struct NLSolveTerminationCondition{mode, T,
safe_termination_options::S
end

TruncatedStacktraces.@truncate_stacktrace NLSolveTerminationCondition 1

function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode}
print(io,
"NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)")
Expand Down Expand Up @@ -116,8 +134,14 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
return NLSolveTerminationCondition{mode, T, typeof(options)}(abstol, reltol, options)
end

function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Nothing})
function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict,
NLSolveSafeTerminationResult,
Nothing})
mode = get_termination_mode(cond)
if storage isa AbstractDict
Base.depwarn("`storage` of type ($(typeof(storage)) <: AbstractDict) has been deprecated. Pass in a `NLSolveSafeTerminationResult` instance instead",
:NLSolveTerminationCondition)
end
# We need both the dispatches to support solvers that don't use the integrator
# interface like SimpleNonlinearSolve
if mode in BASIC_TERMINATION_MODES
Expand All @@ -144,8 +168,8 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth
patience_objective_multiplier = cond.safe_termination_options.patience_objective_multiplier

if mode SAFE_BEST_TERMINATION_MODES
storage[:best_objective_value] = aType(Inf)
storage[:best_objective_value_iteration] = 0
__setproperty!(storage, :best_objective_value, aType(Inf))
__setproperty!(storage, :best_objective_value_iteration, 0)
end

if mode SAFE_BEST_TERMINATION_MODES
Expand All @@ -158,14 +182,15 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth

if mode SAFE_BEST_TERMINATION_MODES
if objective < storage[:best_objective_value]
storage[:best_objective_value] = objective
storage[:best_objective_value_iteration] = nstep + 1
__setproperty!(storage, :best_objective_value, objective)
__setproperty!(storage, :best_objective_value_iteration, nstep + 1)
end
end

# Main Termination Criteria
if objective <= criteria
storage[:return_code] = NLSolveSafeTerminationReturnCode.Success
__setproperty!(storage, :return_code,
NLSolveSafeTerminationReturnCode.Success)
return true
end

Expand All @@ -181,19 +206,21 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth
if maximum(last_k_values) <
typeof(criteria)(cond.safe_termination_options.min_max_factor) *
minimum(last_k_values)
storage[:return_code] = NLSolveSafeTerminationReturnCode.PatienceTermination
__setproperty!(storage, :return_code,
NLSolveSafeTerminationReturnCode.PatienceTermination)
return true
end
end
end

# Protective break
if objective >= objective_values[1] * protective_threshold * length(du)
storage[:return_code] = NLSolveSafeTerminationReturnCode.ProtectiveTermination
__setproperty!(storage, :return_code,
NLSolveSafeTerminationReturnCode.ProtectiveTermination)
return true
end

storage[:return_code] = NLSolveSafeTerminationReturnCode.Failure
__setproperty!(storage, :return_code, NLSolveSafeTerminationReturnCode.Failure)
return false
end
return _termination_condition_closure_safe
Expand Down

0 comments on commit 0e1ce2a

Please sign in to comment.