Skip to content

Commit

Permalink
Merge pull request #38 from CCsimon123/main
Browse files Browse the repository at this point in the history
Adding a Brent method
  • Loading branch information
ChrisRackauckas authored Jan 31, 2023
2 parents 883b789 + 790eb83 commit 6b5da4b
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 3 deletions.
6 changes: 4 additions & 2 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include("broyden.jl")
include("klement.jl")
include("trustRegion.jl")
include("ridder.jl")
include("brent.jl")
include("ad.jl")

import SnoopPrecompile
Expand All @@ -44,12 +45,13 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
=#

prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
for alg in (Bisection, Falsi, Ridder)
for alg in (Bisection, Falsi, Ridder, Brent)
solve(prob_brack, alg(), abstol = T(1e-2))
end
end end

# DiffEq styled algorithms
export Bisection, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion
export Bisection, Brent, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson,
SimpleTrustRegion

end # module
114 changes: 114 additions & 0 deletions lib/SimpleNonlinearSolve/src/brent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
`Brent()`
A non-allocating Brent method
"""
struct Brent <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
maxiters = 1000,
kwargs...)
f = Base.Fix2(prob.f, prob.p)
a, b = prob.tspan
fa, fb = f(a), f(b)
ϵ = eps(convert(typeof(fa), 1.0))

if iszero(fa)
return SciMLBase.build_solution(prob, alg, a, fa;
retcode = ReturnCode.ExactSolutionLeft, left = a,
right = b)
end
if abs(fa) < abs(fb)
c = b
b = a
a = c
tmp = fa
fa = fb
fb = tmp
end

c = a
d = c
i = 1
cond = true
if !iszero(fb)
while i < maxiters
fc = f(c)
if fa != fc && fb != fc
# Inverse quadratic interpolation
s = a * fb * fc / ((fa - fb) * (fa - fc)) +
b * fa * fc / ((fb - fa) * (fb - fc)) +
c * fa * fb / ((fc - fa) * (fc - fb))
else
# Secant method
s = b - fb * (b - a) / (fb - fa)
end
if (s < min((3 * a + b) / 4, b) || s > max((3 * a + b) / 4, b)) ||
(cond && abs(s - b) abs(b - c) / 2) ||
(!cond && abs(s - b) abs(c - d) / 2) ||
(cond && abs(b - c) ϵ) ||
(!cond && abs(c - d) ϵ)
# Bisection method
s = (a + b) / 2
(s == a || s == b) &&
return SciMLBase.build_solution(prob, alg, a, fa;
retcode = ReturnCode.FloatingPointLimit,
left = a, right = b)
cond = true
else
cond = false
end
fs = f(s)
if iszero(fs)
if b < a
a = b
fa = fb
end
b = s
fb = fs
break
end
if fa * fs < 0
d = c
c = b
b = s
fb = fs
else
a = s
fa = fs
end
if abs(fa) < abs(fb)
d = c
c = b
b = a
a = c
fc = fb
fb = fa
fa = fc
end
i += 1
end
end

while i < maxiters
c = (a + b) / 2
if (c == a || c == b)
return SciMLBase.build_solution(prob, alg, a, fa;
retcode = ReturnCode.FloatingPointLimit,
left = a, right = b)
end
fc = f(c)
if iszero(fc)
b = c
fb = fc
else
a = c
fa = fc
end
i += 1
end

return SciMLBase.build_solution(prob, alg, a, fa; retcode = ReturnCode.MaxIters,
left = a, right = b)
end
26 changes: 25 additions & 1 deletion lib/SimpleNonlinearSolve/test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,22 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
end

# Brent
g = function (p)
probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p)
sol = solve(probN, Brent())
return sol.left
end

for p in 1.1:0.1:100.0
@test g(p) sqrt(p)
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
end

f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
for alg in [Bisection(), Falsi(), Ridder()]
for alg in [Bisection(), Falsi(), Ridder(), Brent()]
global g, p
g = function (p)
probN = IntervalNonlinearProblem{false}(f, tspan, p)
Expand Down Expand Up @@ -200,6 +212,18 @@ probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, Ridder())
@test sol.left sqrt(2.0)

# Brent
sol = solve(probB, Brent())
@test sol.left sqrt(2.0)
tspan = (sqrt(2.0), 10.0)
probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, Brent())
@test sol.left sqrt(2.0)
tspan = (0.0, sqrt(2.0))
probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, Brent())
@test sol.left sqrt(2.0)

# Garuntee Tests for Bisection
f = function (u, p)
if u < 2.0
Expand Down

0 comments on commit 6b5da4b

Please sign in to comment.