Skip to content

Commit

Permalink
Merge pull request #243 from SciML/dw/terminate_steadystate
Browse files Browse the repository at this point in the history
A few improvements of TerminateSteadyState
  • Loading branch information
ChrisRackauckas authored Dec 1, 2024
2 parents c08609c + 8cd54ab commit 88f8221
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions src/terminatesteadystate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
# Terminate when all derivatives fall below a threshold or
# when derivatives are smaller than a fraction of state
function allDerivPass(integrator, abstol, reltol, min_t)
# Early exit
if min_t !== nothing && integrator.t < min_t
return false
end

if DiffEqBase.isinplace(integrator.sol.prob)
testval = first(get_tmp_cache(integrator))
DiffEqBase.get_du!(testval, integrator)
Expand All @@ -16,21 +21,21 @@ function allDerivPass(integrator, abstol, reltol, min_t)
end

if integrator.u isa Array
any(abs(d) > abstol && abs(d) > reltol * abs(u)
return all(abs(d) <= max(abstol, reltol * abs(u))
for (d, abstol, reltol, u) in zip(testval, Iterators.cycle(abstol),
Iterators.cycle(reltol), integrator.u)) &&
(return false)
Iterators.cycle(reltol), integrator.u))
else
any((abs.(testval) .> abstol) .& (abs.(testval) .> reltol .* abs.(integrator.u))) &&
(return false)
return all(abs.(testval) .<= max.(abstol, reltol .* abs.(integrator.u)))
end
end

if min_t === nothing
return true
else
return integrator.t >= min_t
end
struct WrappedTest{T, A, R, M}
test::T
abstol::A
reltol::R
min_t::M
end
(f::WrappedTest)(u, t, integrator) = f.test(integrator, f.abstol, f.reltol, f.min_t)

# Allow user-defined tolerances and test functions but use sensible defaults
# test function must take integrator, time, followed by absolute
Expand All @@ -50,7 +55,7 @@ the [Steady State Solvers](https://docs.sciml.ai/DiffEqDocs/stable/solvers/stead
These tolerances may be specified as scalars or as arrays of the same length
as the states of the problem.
- `test` represents the function that evaluates the condition for termination. The default
condition is that all derivatives should become smaller than `abstol` and the states times
condition is that all derivatives should become smaller than `abstol` or the states times
`reltol`. The user can pass any other function to implement a different termination condition.
Such function should take four arguments: `integrator`, `abstol`, `reltol`, and `min_t`.
- `wrap_test` can be set to `Val(false)`, in which case `test` must have the definition
Expand All @@ -62,15 +67,14 @@ the [Steady State Solvers](https://docs.sciml.ai/DiffEqDocs/stable/solvers/stead
- `min_t` specifies an optional minimum `t` before the steady state calculations are allowed
to terminate.
"""
function TerminateSteadyState(abstol = 1e-8, reltol = 1e-6, test = allDerivPass;
min_t = nothing, wrap_test::Val{WT} = Val(true)) where {WT}
function TerminateSteadyState(abstol = 1e-8, reltol = 1e-6, test::T = allDerivPass;
min_t = nothing, wrap_test::Val{WT} = Val(true)) where {T, WT}
condition = if WT
(u, t, integrator) -> test(integrator, abstol, reltol, min_t)
WrappedTest(test, abstol, reltol, min_t)
else
test
end
affect! = (integrator) -> terminate!(integrator)
DiscreteCallback(condition, affect!; save_positions = (true, false))
DiscreteCallback(condition, terminate!; save_positions = (true, false))
end

export TerminateSteadyState

0 comments on commit 88f8221

Please sign in to comment.