Skip to content

Commit

Permalink
Merge pull request #37 from CCsimon123/main
Browse files Browse the repository at this point in the history
Implementation of Ridder
  • Loading branch information
ChrisRackauckas authored Jan 29, 2023
2 parents 84e2dd1 + 2959eda commit 883b789
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 3 deletions.
5 changes: 3 additions & 2 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ include("raphson.jl")
include("broyden.jl")
include("klement.jl")
include("trustRegion.jl")
include("ridder.jl")
include("ad.jl")

import SnoopPrecompile
Expand All @@ -43,12 +44,12 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
=#

prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
for alg in (Bisection, Falsi)
for alg in (Bisection, Falsi, Ridder)
solve(prob_brack, alg(), abstol = T(1e-2))
end
end end

# DiffEq styled algorithms
export Bisection, Broyden, Falsi, Klement, SimpleNewtonRaphson, SimpleTrustRegion
export Bisection, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion

end # module
81 changes: 81 additions & 0 deletions lib/SimpleNonlinearSolve/src/ridder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
`Ridder()`
A non-allocating ridder method
"""
struct Ridder <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
maxiters = 1000,
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
fl, fr = f(left), f(right)

if iszero(fl)
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
end

xo = oftype(left, Inf)
i = 1
if !iszero(fr)
while i < maxiters
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
s = sqrt(fm^2 - fl * fr)
iszero(s) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.Failure,
left = left, right = right)
x = mid + (mid - left) * sign(fl - fr) * fm / s
fx = f(x)
xo = x
if iszero(fx)
right = x
fr = fx
break
end
if sign(fx) != sign(fm)
left = mid
fl = fm
right = x
fr = fx
elseif sign(fx) != sign(fl)
right = x
fr = fx
else
@assert sign(fx) != sign(fr)
left = x
fl = fx
end
i += 1
end
end

while i < maxiters
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
fr = fm
else
left = mid
fl = fm
end
i += 1
end

return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
left = left, right = right)
end
26 changes: 25 additions & 1 deletion lib/SimpleNonlinearSolve/test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,22 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
end

# Ridder
g = function (p)
probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p)
sol = solve(probN, Ridder())
return sol.left
end

for p in 1.1:0.1:100.0
@test g(p) sqrt(p)
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
end

f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
for alg in [Bisection(), Falsi()]
for alg in [Bisection(), Falsi(), Ridder()]
global g, p
g = function (p)
probN = IntervalNonlinearProblem{false}(f, tspan, p)
Expand Down Expand Up @@ -176,6 +188,18 @@ sol = solve(probB, Falsi())
sol = solve(probB, Bisection())
@test sol.left sqrt(2.0)

# Ridder
sol = solve(probB, Ridder())
@test sol.left sqrt(2.0)
tspan = (sqrt(2.0), 10.0)
probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, Ridder())
@test sol.left sqrt(2.0)
tspan = (0.0, sqrt(2.0))
probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, Ridder())
@test sol.left sqrt(2.0)

# Garuntee Tests for Bisection
f = function (u, p)
if u < 2.0
Expand Down

0 comments on commit 883b789

Please sign in to comment.