Skip to content

Commit

Permalink
Merge pull request #208 from utkarsh530/u/termination_condition
Browse files Browse the repository at this point in the history
Start using termination conditions from DiffEqBase
  • Loading branch information
ChrisRackauckas authored Oct 26, 2023
2 parents 191a237 + 350fac5 commit dbed34c
Show file tree
Hide file tree
Showing 13 changed files with 548 additions and 92 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase"]
53 changes: 44 additions & 9 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ end
f
alg
u
u_prev
du
fu
fu2
Expand All @@ -46,17 +47,21 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
reset_tolerance
reset_check
prob
stats::NLStats
lscache
termination_condition
tc_storage
end

get_fu(cache::GeneralBroydenCache) = cache.fu

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
Expand All @@ -65,23 +70,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance,
reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
storage)
end

function perform_step!(cache::GeneralBroydenCache{true})
@unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
T = eltype(u)

mul!(_vec(du), J⁻¹, -_vec(fu))
α = perform_linesearch!(cache.lscache, u, du)
_axpy!(α, du, u)
f(fu2, u, p)

cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
Expand All @@ -106,20 +126,25 @@ function perform_step!(cache::GeneralBroydenCache{true})
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
end
fu .= fu2
@. u_prev = u

return nothing
end

function perform_step!(cache::GeneralBroydenCache{false})
@unpack f, p = cache
@unpack f, p, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

T = eltype(cache.u)

cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu))
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
cache.u = cache.u .+ α * cache.du
cache.fu2 = f(cache.u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
Expand All @@ -142,12 +167,15 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
end
cache.fu = cache.fu2
cache.u_prev = @. cache.u

return nothing
end

function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
Expand All @@ -157,7 +185,14 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
41 changes: 34 additions & 7 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,16 @@ end
internalnorm
retcode::SciMLBase.ReturnCode.T
abstol
reltol
prob
stats::NLStats
termination_condition
tc_storage
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0)

Expand Down Expand Up @@ -122,14 +126,27 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
f₍ₙₒᵣₘ₎₀ = f₍ₙₒᵣₘ₎ₙ₋₁

= fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
T)

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters,
internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0))
internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
termination_condition, storage)
end

function perform_step!(cache::DFSaneCache{true})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
f = (dx, x) -> cache.prob.f(dx, x, cache.p)

T = eltype(cache.uₙ)
Expand Down Expand Up @@ -174,7 +191,7 @@ function perform_step!(cache::DFSaneCache{true})
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

if cache.internalnorm(cache.fuₙ) < cache.abstol
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
cache.force_stop = true
end

Expand Down Expand Up @@ -205,8 +222,9 @@ function perform_step!(cache::DFSaneCache{true})
end

function perform_step!(cache::DFSaneCache{false})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
f = x -> cache.prob.f(x, cache.p)

T = eltype(cache.uₙ)
Expand Down Expand Up @@ -249,7 +267,7 @@ function perform_step!(cache::DFSaneCache{false})
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

if cache.internalnorm(cache.fuₙ) < cache.abstol
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
cache.force_stop = true
end

Expand Down Expand Up @@ -296,7 +314,9 @@ function SciMLBase.solve!(cache::DFSaneCache)
end

function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.uₙ, u0)
Expand All @@ -317,7 +337,14 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p
T = eltype(cache.uₙ)
cache.σₙ = T(cache.alg.σ_1)

termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
62 changes: 52 additions & 10 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@ end
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
return GaussNewton{_unwrap_val(concrete_jac)}(ad,
linsolve,
precs)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
u_prev
fu1
fu2
fu_new
Expand All @@ -72,12 +75,17 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
prob
stats::NLStats
tc_storage
termination_condition
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
Expand All @@ -101,15 +109,29 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
JᵀJ, Jᵀf = nothing, nothing
end

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u); mode = NLSolveTerminationMode.AbsNorm)

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
prob, NLStats(1, 0, 0, 0, 0))
reltol,
prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition)
end

function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du, tc_storage = cache
jacobian!!(J, cache)

termination_condition = cache.termination_condition(tc_storage)

if JᵀJ !== nothing
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)
Expand All @@ -127,9 +149,15 @@ function perform_step!(cache::GaussNewtonCache{true})
@. u = u - du
f(cache.fu_new, u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(termination_condition(cache.fu_new .- cache.fu1,
cache.u,
u_prev,
cache.abstol,
cache.reltol) ||
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) &&
(cache.force_stop = true)

@. u_prev = u
cache.fu1 .= cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
Expand All @@ -139,7 +167,9 @@ function perform_step!(cache::GaussNewtonCache{true})
end

function perform_step!(cache::GaussNewtonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

cache.J = jacobian!!(cache.J, cache)

Expand All @@ -164,7 +194,10 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

(cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true)
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

cache.u_prev = @. cache.u
cache.fu1 = cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
Expand All @@ -174,7 +207,9 @@ function perform_step!(cache::GaussNewtonCache{false})
end

function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
Expand All @@ -184,7 +219,14 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
Loading

0 comments on commit dbed34c

Please sign in to comment.