Skip to content

Commit

Permalink
Merge pull request #45 from avik-pal/ap/termination_broyden
Browse files Browse the repository at this point in the history
Add Termination Conditions to Broyden
  • Loading branch information
ChrisRackauckas authored Feb 19, 2023
2 parents 6cdd257 + e0ff9fc commit 49a4ce8
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 55 deletions.
1 change: 0 additions & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ SimpleBatchedNonlinearSolveExt = "NNlib"

[compat]
ArrayInterface = "6, 7"
DiffEqBase = "6.114"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
NNlib = "0.8"
Expand Down
28 changes: 18 additions & 10 deletions lib/SimpleNonlinearSolve/ext/SimpleBatchedNonlinearSolveExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module SimpleBatchedNonlinearSolveExt

using ArrayInterface, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase

isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)

_batch_transpose(x) = reshape(x, 1, size(x)...)
Expand Down Expand Up @@ -31,6 +32,8 @@ end

function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)

Expand All @@ -47,8 +50,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
end

atol = abstol !== nothing ? abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
(tc.abstol !== nothing ? tc.abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
rtol = reltol !== nothing ? reltol :
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))

if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
error("Broyden currently doesn't support SAFE_BEST termination modes")
end

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
termination_condition = tc(storage)

xₙ = x
xₙ₋₁ = x
Expand All @@ -63,14 +75,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))

iszero(fₙ) &&
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
retcode = ReturnCode.Success)

if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
retcode = ReturnCode.Success)
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
end

xₙ₋₁ = xₙ
fₙ₋₁ = fₙ
end
Expand Down
41 changes: 29 additions & 12 deletions lib/SimpleNonlinearSolve/src/broyden.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
Broyden(; batched = false)
Broyden(; batched = false,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing, reltol = nothing))
A low-overhead implementation of Broyden. This method is non-allocating on scalar
and static array problems.
Expand All @@ -9,12 +11,22 @@ and static array problems.
To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or
`import NNlib` must be present in your code.
"""
struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
Broyden(; batched = false) = new{batched}()
struct Broyden{batched, TC <: NLSolveTerminationCondition} <:
AbstractSimpleNonlinearSolveAlgorithm
termination_condition::TC

function Broyden(; batched = false,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
return new{batched, typeof(termination_condition)}(termination_condition)
end
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)

Expand All @@ -27,8 +39,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
end

atol = abstol !== nothing ? abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
(tc.abstol !== nothing ? tc.abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
rtol = reltol !== nothing ? reltol :
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))

if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
error("Broyden currently doesn't support SAFE_BEST termination modes")
end

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
termination_condition = tc(storage)

xₙ = x
xₙ₋₁ = x
Expand All @@ -41,14 +62,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
J⁻¹Δfₙ = J⁻¹ * Δfₙ
J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)

iszero(fₙ) &&
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
retcode = ReturnCode.Success)

if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
retcode = ReturnCode.Success)
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
end

xₙ₋₁ = xₙ
fₙ₋₁ = fₙ
end
Expand Down
90 changes: 58 additions & 32 deletions lib/SimpleNonlinearSolve/test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
using SimpleNonlinearSolve
using StaticArrays
using BenchmarkTools
using DiffEqBase
using Test

const BATCHED_BROYDEN_SOLVERS = Broyden[]
const BROYDEN_SOLVERS = Broyden[]

for mode in instances(NLSolveTerminationMode.T)
if mode
(NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest,
NLSolveTerminationMode.AbsSafeBest)
continue
end

termination_condition = NLSolveTerminationCondition(mode; abstol = nothing,
reltol = nothing)
push!(BROYDEN_SOLVERS, Broyden(; batched = false, termination_condition))
push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition))
end

# SimpleNewtonRaphson
function benchmark_scalar(f, u0)
probN = NonlinearProblem{false}(f, u0)
Expand Down Expand Up @@ -50,16 +67,19 @@ if VERSION >= v"1.7"
end

# Broyden
function benchmark_scalar(f, u0)
function benchmark_scalar(f, u0, alg)
probN = NonlinearProblem{false}(f, u0)
sol = (solve(probN, Broyden()))
sol = (solve(probN, alg))
end

sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9
if VERSION >= v"1.7"
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
for alg in BROYDEN_SOLVERS
sol = benchmark_scalar(sf, csu0, alg)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9
# FIXME: Termination Condition Implementation is allocating. Not sure how to fix it.
# if VERSION >= v"1.7"
# @test (@ballocated benchmark_scalar($sf, $csu0, $termination_condition)) == 0
# end
end

# Klement
Expand Down Expand Up @@ -101,8 +121,8 @@ using ForwardDiff
# Immutable
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]

for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane())
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), BROYDEN_SOLVERS...)
g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, alg, abstol = 1e-9)
Expand All @@ -117,8 +137,8 @@ end

# Scalar
f, u0 = (u, p) -> u * u - p, 1.0
for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), Halley())
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, alg)
Expand Down Expand Up @@ -183,8 +203,8 @@ for alg in [Bisection(), Falsi(), Ridder(), Brent()]
@test ForwardDiff.jacobian(g, p) ForwardDiff.jacobian(t, p)
end

for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), Halley())
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
global g, p
g = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
Expand All @@ -199,14 +219,15 @@ end
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
probN = NonlinearProblem(f, u0)

@test solve(probN, SimpleNewtonRaphson()).u[end] sqrt(2.0)
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)
@test solve(probN, SimpleTrustRegion()).u[end] sqrt(2.0)
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u[end] sqrt(2.0)
@test solve(probN, Broyden()).u[end] sqrt(2.0)
@test solve(probN, LBroyden()).u[end] sqrt(2.0)
@test solve(probN, Klement()).u[end] sqrt(2.0)
@test solve(probN, SimpleDFSane()).u[end] sqrt(2.0)
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
SimpleTrustRegion(),
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), SimpleDFSane(),
BROYDEN_SOLVERS...)
sol = solve(probN, alg)

@test sol.retcode == ReturnCode.Success
@test sol.u[end] sqrt(2.0)
end

# Separate Error check for Halley; will be included in above error checks for the improved Halley
f, u0 = (u, p) -> u * u - 2.0, 1.0
Expand All @@ -220,18 +241,16 @@ for u0 in [1.0, [1, 1.0]]
probN = NonlinearProblem(f, u0)
sol = sqrt(2) * u0

@test solve(probN, SimpleNewtonRaphson()).u sol
@test solve(probN, SimpleNewtonRaphson()).u sol
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u sol

@test solve(probN, SimpleTrustRegion()).u sol
@test solve(probN, SimpleTrustRegion()).u sol
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u sol
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
SimpleTrustRegion(),
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(),
SimpleDFSane(),
BROYDEN_SOLVERS...)
sol2 = solve(probN, alg)

@test solve(probN, Broyden()).u sol
@test solve(probN, LBroyden()).u sol
@test solve(probN, Klement()).u sol
@test solve(probN, SimpleDFSane()).u sol
@test sol2.retcode == ReturnCode.Success
@test sol2.u sol
end
end

# Bisection Tests
Expand Down Expand Up @@ -411,3 +430,10 @@ probN = NonlinearProblem{false}(f, u0, p);
sol = solve(probN, Broyden(batched = true))

@test abs.(sol.u) sqrt.(p)

for alg in BATCHED_BROYDEN_SOLVERS
sol = solve(probN, alg)

@test sol.retcode == ReturnCode.Success
@test abs.(sol.u) sqrt.(p)
end

0 comments on commit 49a4ce8

Please sign in to comment.