Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

Add Termination Conditions to Broyden #45

Merged
merged 4 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ SimpleBatchedNonlinearSolveExt = "NNlib"

[compat]
ArrayInterface = "6, 7"
DiffEqBase = "6.114"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
NNlib = "0.8"
Expand Down
28 changes: 18 additions & 10 deletions ext/SimpleBatchedNonlinearSolveExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module SimpleBatchedNonlinearSolveExt

using ArrayInterface, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase

isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)

_batch_transpose(x) = reshape(x, 1, size(x)...)
Expand Down Expand Up @@ -31,6 +32,8 @@ end

function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)

Expand All @@ -47,8 +50,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
end

atol = abstol !== nothing ? abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
(tc.abstol !== nothing ? tc.abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
rtol = reltol !== nothing ? reltol :
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))

if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
error("Broyden currently doesn't support SAFE_BEST termination modes")
end

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
termination_condition = tc(storage)

xₙ = x
xₙ₋₁ = x
Expand All @@ -63,14 +75,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))

iszero(fₙ) &&
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
retcode = ReturnCode.Success)

if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
retcode = ReturnCode.Success)
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
end

xₙ₋₁ = xₙ
fₙ₋₁ = fₙ
end
Expand Down
41 changes: 29 additions & 12 deletions src/broyden.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
Broyden(; batched = false)
Broyden(; batched = false,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing, reltol = nothing))

A low-overhead implementation of Broyden. This method is non-allocating on scalar
and static array problems.
Expand All @@ -9,12 +11,22 @@ and static array problems.
To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or
`import NNlib` must be present in your code.
"""
struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
Broyden(; batched = false) = new{batched}()
struct Broyden{batched, TC <: NLSolveTerminationCondition} <:
AbstractSimpleNonlinearSolveAlgorithm
termination_condition::TC

function Broyden(; batched = false,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
return new{batched, typeof(termination_condition)}(termination_condition)
end
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)

Expand All @@ -27,8 +39,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
end

atol = abstol !== nothing ? abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
(tc.abstol !== nothing ? tc.abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
rtol = reltol !== nothing ? reltol :
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))

if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
error("Broyden currently doesn't support SAFE_BEST termination modes")
end

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
termination_condition = tc(storage)

xₙ = x
xₙ₋₁ = x
Expand All @@ -41,14 +62,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
J⁻¹Δfₙ = J⁻¹ * Δfₙ
J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)

iszero(fₙ) &&
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
retcode = ReturnCode.Success)

if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
retcode = ReturnCode.Success)
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
end

xₙ₋₁ = xₙ
fₙ₋₁ = fₙ
end
Expand Down
90 changes: 58 additions & 32 deletions test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
using SimpleNonlinearSolve
using StaticArrays
using BenchmarkTools
using DiffEqBase
using Test

const BATCHED_BROYDEN_SOLVERS = Broyden[]
const BROYDEN_SOLVERS = Broyden[]

for mode in instances(NLSolveTerminationMode.T)
if mode ∈
(NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest,
NLSolveTerminationMode.AbsSafeBest)
continue
end

termination_condition = NLSolveTerminationCondition(mode; abstol = nothing,
reltol = nothing)
push!(BROYDEN_SOLVERS, Broyden(; batched = false, termination_condition))
push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition))
end

# SimpleNewtonRaphson
function benchmark_scalar(f, u0)
probN = NonlinearProblem{false}(f, u0)
Expand Down Expand Up @@ -50,16 +67,19 @@ if VERSION >= v"1.7"
end

# Broyden
function benchmark_scalar(f, u0)
function benchmark_scalar(f, u0, alg)
probN = NonlinearProblem{false}(f, u0)
sol = (solve(probN, Broyden()))
sol = (solve(probN, alg))
end

sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9
if VERSION >= v"1.7"
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
for alg in BROYDEN_SOLVERS
sol = benchmark_scalar(sf, csu0, alg)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9
# FIXME: Termination Condition Implementation is allocating. Not sure how to fix it.
# if VERSION >= v"1.7"
# @test (@ballocated benchmark_scalar($sf, $csu0, $termination_condition)) == 0
# end
end

# Klement
Expand Down Expand Up @@ -101,8 +121,8 @@ using ForwardDiff
# Immutable
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]

for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane())
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), BROYDEN_SOLVERS...)
g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, alg, abstol = 1e-9)
Expand All @@ -117,8 +137,8 @@ end

# Scalar
f, u0 = (u, p) -> u * u - p, 1.0
for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), Halley())
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, alg)
Expand Down Expand Up @@ -183,8 +203,8 @@ for alg in [Bisection(), Falsi(), Ridder(), Brent()]
@test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p)
end

for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), Halley())
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
global g, p
g = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
Expand All @@ -199,14 +219,15 @@ end
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
probN = NonlinearProblem(f, u0)

@test solve(probN, SimpleNewtonRaphson()).u[end] ≈ sqrt(2.0)
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0)
@test solve(probN, SimpleTrustRegion()).u[end] ≈ sqrt(2.0)
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u[end] ≈ sqrt(2.0)
@test solve(probN, Broyden()).u[end] ≈ sqrt(2.0)
@test solve(probN, LBroyden()).u[end] ≈ sqrt(2.0)
@test solve(probN, Klement()).u[end] ≈ sqrt(2.0)
@test solve(probN, SimpleDFSane()).u[end] ≈ sqrt(2.0)
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
SimpleTrustRegion(),
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), SimpleDFSane(),
BROYDEN_SOLVERS...)
sol = solve(probN, alg)

@test sol.retcode == ReturnCode.Success
@test sol.u[end] ≈ sqrt(2.0)
end

# Separate Error check for Halley; will be included in above error checks for the improved Halley
f, u0 = (u, p) -> u * u - 2.0, 1.0
Expand All @@ -220,18 +241,16 @@ for u0 in [1.0, [1, 1.0]]
probN = NonlinearProblem(f, u0)
sol = sqrt(2) * u0

@test solve(probN, SimpleNewtonRaphson()).u ≈ sol
@test solve(probN, SimpleNewtonRaphson()).u ≈ sol
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u ≈ sol

@test solve(probN, SimpleTrustRegion()).u ≈ sol
@test solve(probN, SimpleTrustRegion()).u ≈ sol
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u ≈ sol
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
SimpleTrustRegion(),
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(),
SimpleDFSane(),
BROYDEN_SOLVERS...)
sol2 = solve(probN, alg)

@test solve(probN, Broyden()).u ≈ sol
@test solve(probN, LBroyden()).u ≈ sol
@test solve(probN, Klement()).u ≈ sol
@test solve(probN, SimpleDFSane()).u ≈ sol
@test sol2.retcode == ReturnCode.Success
@test sol2.u ≈ sol
end
end

# Bisection Tests
Expand Down Expand Up @@ -411,3 +430,10 @@ probN = NonlinearProblem{false}(f, u0, p);
sol = solve(probN, Broyden(batched = true))

@test abs.(sol.u) ≈ sqrt.(p)

for alg in BATCHED_BROYDEN_SOLVERS
sol = solve(probN, alg)

@test sol.retcode == ReturnCode.Success
@test abs.(sol.u) ≈ sqrt.(p)
end