Skip to content

Commit

Permalink
Merge pull request #77 from yash2798/ys/tolerance
Browse files Browse the repository at this point in the history
Tolerance fix for interval solvers
  • Loading branch information
ChrisRackauckas authored Sep 12, 2023
2 parents 167db26 + 35247e6 commit 221bb41
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 6 deletions.
18 changes: 16 additions & 2 deletions lib/SimpleNonlinearSolve/src/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@ function Bisection(; exact_left = false, exact_right = false)
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...;
maxiters = 1000,
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
fl, fr = f(left), f(right)

if iszero(fl)
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
end
if iszero(fr)
return SciMLBase.build_solution(prob, alg, right, fr;
retcode = ReturnCode.ExactSolutionRight, left = left,
right = right)
end

i = 1
if !iszero(fr)
Expand All @@ -41,6 +45,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if abs((right - left) / 2) < abstol
return SciMLBase.build_solution(prob, alg, mid, fm;
retcode = ReturnCode.Success,
left = left, right = right)
end
if iszero(fm)
right = mid
break
Expand All @@ -63,6 +72,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if abs((right - left) / 2) < abstol
return SciMLBase.build_solution(prob, alg, mid, fm;
retcode = ReturnCode.Success,
left = left, right = right)
end
if iszero(fm)
right = mid
fr = fm
Expand Down
16 changes: 15 additions & 1 deletion lib/SimpleNonlinearSolve/src/brent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ A non-allocating Brent method
struct Brent <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
maxiters = 1000,
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
f = Base.Fix2(prob.f, prob.p)
a, b = prob.tspan
Expand All @@ -18,6 +18,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
return SciMLBase.build_solution(prob, alg, a, fa;
retcode = ReturnCode.ExactSolutionLeft, left = a,
right = b)
elseif iszero(fb)
return SciMLBase.build_solution(prob, alg, b, fb;
retcode = ReturnCode.ExactSolutionRight, left = a,
right = b)
end
if abs(fa) < abs(fb)
c = b
Expand Down Expand Up @@ -60,6 +64,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
cond = false
end
fs = f(s)
if abs((b - a) / 2) < abstol
return SciMLBase.build_solution(prob, alg, s, fs;
retcode = ReturnCode.Success,
left = a, right = b)
end
if iszero(fs)
if b < a
a = b
Expand Down Expand Up @@ -99,6 +108,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
left = a, right = b)
end
fc = f(c)
if abs((b - a) / 2) < abstol
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = a, right = b)
end
if iszero(fc)
b = c
fb = fc
Expand Down
16 changes: 15 additions & 1 deletion lib/SimpleNonlinearSolve/src/falsi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
struct Falsi <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
maxiters = 1000,
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
Expand All @@ -14,6 +14,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
elseif iszero(fr)
return SciMLBase.build_solution(prob, alg, right, fr;
retcode = ReturnCode.ExactSolutionRight, left = left,
right = right)
end

i = 1
Expand All @@ -32,6 +36,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
break
end
fm = f(mid)
if abs((right - left) / 2) < abstol
return SciMLBase.build_solution(prob, alg, mid, fm;
retcode = ReturnCode.Success,
left = left, right = right)
end
if iszero(fm)
right = mid
break
Expand All @@ -54,6 +63,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if abs((right - left) / 2) < abstol
return SciMLBase.build_solution(prob, alg, mid, fm;
retcode = ReturnCode.Success,
left = left, right = right)
end
if iszero(fm)
right = mid
fr = fm
Expand Down
8 changes: 7 additions & 1 deletion lib/SimpleNonlinearSolve/src/itp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct ITP{T} <: AbstractBracketingAlgorithm
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP,
args...; abstol = 1.0e-15,
args...; abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan # a and b
Expand Down Expand Up @@ -111,6 +111,12 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP,
xp = mid -* r)
end

if abs((left - right) / 2) < ϵ
return SciMLBase.build_solution(prob, alg, mid, f(mid);
retcode = ReturnCode.Success,
left = left, right = right)
end

## Update ##
tmin, tmax = minmax(left, right)
xp >= tmax && (xp = prevfloat(tmax))
Expand Down
16 changes: 15 additions & 1 deletion lib/SimpleNonlinearSolve/src/ridder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ A non-allocating ridder method
struct Ridder <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
maxiters = 1000,
maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])),
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
Expand All @@ -17,6 +17,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
elseif iszero(fr)
return SciMLBase.build_solution(prob, alg, right, fr;
retcode = ReturnCode.ExactSolutionRight, left = left,
right = right)
end

xo = oftype(left, Inf)
Expand All @@ -37,6 +41,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
x = mid + (mid - left) * sign(fl - fr) * fm / s
fx = f(x)
xo = x
if abs((right - left) / 2) < abstol
return SciMLBase.build_solution(prob, alg, mid, fm;
retcode = ReturnCode.Success,
left = left, right = right)
end
if iszero(fx)
right = x
fr = fx
Expand Down Expand Up @@ -66,6 +75,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if abs((right - left) / 2) < abstol
return SciMLBase.build_solution(prob, alg, mid, fm;
retcode = ReturnCode.Success,
left = left, right = right)
end
if iszero(fm)
right = mid
fr = fm
Expand Down
28 changes: 28 additions & 0 deletions lib/SimpleNonlinearSolve/test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,34 @@ probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, ITP())
@test sol.u sqrt(2.0)

# Tolerance tests for Interval methods
f, tspan = (u, p) -> u .* u .- 2.0, (1.0, 10.0)
probB = IntervalNonlinearProblem(f, tspan)
tols = [0.1, 0.01, 0.001, 0.0001, 1e-5, 1e-6, 1e-7]
ϵ = eps(1.0) #least possible tol for all methods

for atol in tols
sol = solve(probB, Bisection(), abstol = atol)
@test abs(sol.u - sqrt(2)) < atol
@test abs(sol.u - sqrt(2)) > ϵ #test that the solution is not calculated upto max precision
sol = solve(probB, Falsi(), abstol = atol)
@test abs(sol.u - sqrt(2)) < atol
@test abs(sol.u - sqrt(2)) > ϵ
sol = solve(probB, ITP(), abstol = atol)
@test abs(sol.u - sqrt(2)) < atol
@test abs(sol.u - sqrt(2)) > ϵ
end

tols = [0.1] # Ridder and Brent converge rapidly so as we lower tolerance below 0.01, it converges with max precision to the solution
for atol in tols
sol = solve(probB, Ridder(), abstol = atol)
@test abs(sol.u - sqrt(2)) < atol
@test abs(sol.u - sqrt(2)) > ϵ
sol = solve(probB, Brent(), abstol = atol)
@test abs(sol.u - sqrt(2)) < atol
@test abs(sol.u - sqrt(2)) > ϵ
end

# Garuntee Tests for Bisection
f = function (u, p)
if u < 2.0
Expand Down

0 comments on commit 221bb41

Please sign in to comment.