diff --git a/lib/SimpleNonlinearSolve/src/bisection.jl b/lib/SimpleNonlinearSolve/src/bisection.jl index 673f77067..24db4adcc 100644 --- a/lib/SimpleNonlinearSolve/src/bisection.jl +++ b/lib/SimpleNonlinearSolve/src/bisection.jl @@ -20,17 +20,21 @@ function Bisection(; exact_left = false, exact_right = false) end function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...; - maxiters = 1000, + maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])), kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.tspan fl, fr = f(left), f(right) - if iszero(fl) return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left = left, right = right) end + if iszero(fr) + return SciMLBase.build_solution(prob, alg, right, fr; + retcode = ReturnCode.ExactSolutionRight, left = left, + right = right) + end i = 1 if !iszero(fr) @@ -41,6 +45,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args... retcode = ReturnCode.FloatingPointLimit, left = left, right = right) fm = f(mid) + if abs((right - left) / 2) < abstol + return SciMLBase.build_solution(prob, alg, mid, fm; + retcode = ReturnCode.Success, + left = left, right = right) + end if iszero(fm) right = mid break @@ -63,6 +72,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args... retcode = ReturnCode.FloatingPointLimit, left = left, right = right) fm = f(mid) + if abs((right - left) / 2) < abstol + return SciMLBase.build_solution(prob, alg, mid, fm; + retcode = ReturnCode.Success, + left = left, right = right) + end if iszero(fm) right = mid fr = fm diff --git a/lib/SimpleNonlinearSolve/src/brent.jl b/lib/SimpleNonlinearSolve/src/brent.jl index 1cedad134..47e5495f0 100644 --- a/lib/SimpleNonlinearSolve/src/brent.jl +++ b/lib/SimpleNonlinearSolve/src/brent.jl @@ -7,7 +7,7 @@ A non-allocating Brent method struct Brent <: AbstractBracketingAlgorithm end function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; - maxiters = 1000, + maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])), kwargs...) f = Base.Fix2(prob.f, prob.p) a, b = prob.tspan @@ -18,6 +18,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; return SciMLBase.build_solution(prob, alg, a, fa; retcode = ReturnCode.ExactSolutionLeft, left = a, right = b) + elseif iszero(fb) + return SciMLBase.build_solution(prob, alg, b, fb; + retcode = ReturnCode.ExactSolutionRight, left = a, + right = b) end if abs(fa) < abs(fb) c = b @@ -60,6 +64,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; cond = false end fs = f(s) + if abs((b - a) / 2) < abstol + return SciMLBase.build_solution(prob, alg, s, fs; + retcode = ReturnCode.Success, + left = a, right = b) + end if iszero(fs) if b < a a = b @@ -99,6 +108,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; left = a, right = b) end fc = f(c) + if abs((b - a) / 2) < abstol + return SciMLBase.build_solution(prob, alg, c, fc; + retcode = ReturnCode.Success, + left = a, right = b) + end if iszero(fc) b = c fb = fc diff --git a/lib/SimpleNonlinearSolve/src/falsi.jl b/lib/SimpleNonlinearSolve/src/falsi.jl index cce11a811..de1079beb 100644 --- a/lib/SimpleNonlinearSolve/src/falsi.jl +++ b/lib/SimpleNonlinearSolve/src/falsi.jl @@ -4,7 +4,7 @@ struct Falsi <: AbstractBracketingAlgorithm end function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...; - maxiters = 1000, + maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])), kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.tspan @@ -14,6 +14,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...; return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left = left, right = right) + elseif iszero(fr) + return SciMLBase.build_solution(prob, alg, right, fr; + retcode = ReturnCode.ExactSolutionRight, left = left, + right = right) end i = 1 @@ -32,6 +36,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...; break end fm = f(mid) + if abs((right - left) / 2) < abstol + return SciMLBase.build_solution(prob, alg, mid, fm; + retcode = ReturnCode.Success, + left = left, right = right) + end if iszero(fm) right = mid break @@ -54,6 +63,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...; retcode = ReturnCode.FloatingPointLimit, left = left, right = right) fm = f(mid) + if abs((right - left) / 2) < abstol + return SciMLBase.build_solution(prob, alg, mid, fm; + retcode = ReturnCode.Success, + left = left, right = right) + end if iszero(fm) right = mid fr = fm diff --git a/lib/SimpleNonlinearSolve/src/itp.jl b/lib/SimpleNonlinearSolve/src/itp.jl index fa390aa45..f6688381c 100644 --- a/lib/SimpleNonlinearSolve/src/itp.jl +++ b/lib/SimpleNonlinearSolve/src/itp.jl @@ -59,7 +59,7 @@ struct ITP{T} <: AbstractBracketingAlgorithm end function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, - args...; abstol = 1.0e-15, + args...; abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])), maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.tspan # a and b @@ -111,6 +111,12 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, xp = mid - (σ * r) end + if abs((left - right) / 2) < ϵ + return SciMLBase.build_solution(prob, alg, mid, f(mid); + retcode = ReturnCode.Success, + left = left, right = right) + end + ## Update ## tmin, tmax = minmax(left, right) xp >= tmax && (xp = prevfloat(tmax)) diff --git a/lib/SimpleNonlinearSolve/src/ridder.jl b/lib/SimpleNonlinearSolve/src/ridder.jl index 62b5a931a..ce95a178a 100644 --- a/lib/SimpleNonlinearSolve/src/ridder.jl +++ b/lib/SimpleNonlinearSolve/src/ridder.jl @@ -7,7 +7,7 @@ A non-allocating ridder method struct Ridder <: AbstractBracketingAlgorithm end function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; - maxiters = 1000, + maxiters = 1000, abstol = min(eps(prob.tspan[1]), eps(prob.tspan[2])), kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.tspan @@ -17,6 +17,10 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left = left, right = right) + elseif iszero(fr) + return SciMLBase.build_solution(prob, alg, right, fr; + retcode = ReturnCode.ExactSolutionRight, left = left, + right = right) end xo = oftype(left, Inf) @@ -37,6 +41,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; x = mid + (mid - left) * sign(fl - fr) * fm / s fx = f(x) xo = x + if abs((right - left) / 2) < abstol + return SciMLBase.build_solution(prob, alg, mid, fm; + retcode = ReturnCode.Success, + left = left, right = right) + end if iszero(fx) right = x fr = fx @@ -66,6 +75,11 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; retcode = ReturnCode.FloatingPointLimit, left = left, right = right) fm = f(mid) + if abs((right - left) / 2) < abstol + return SciMLBase.build_solution(prob, alg, mid, fm; + retcode = ReturnCode.Success, + left = left, right = right) + end if iszero(fm) right = mid fr = fm diff --git a/lib/SimpleNonlinearSolve/test/basictests.jl b/lib/SimpleNonlinearSolve/test/basictests.jl index 6ab9e56d8..34468be58 100644 --- a/lib/SimpleNonlinearSolve/test/basictests.jl +++ b/lib/SimpleNonlinearSolve/test/basictests.jl @@ -373,6 +373,34 @@ probB = IntervalNonlinearProblem(f, tspan) sol = solve(probB, ITP()) @test sol.u ≈ sqrt(2.0) +# Tolerance tests for Interval methods +f, tspan = (u, p) -> u .* u .- 2.0, (1.0, 10.0) +probB = IntervalNonlinearProblem(f, tspan) +tols = [0.1, 0.01, 0.001, 0.0001, 1e-5, 1e-6, 1e-7] +ϵ = eps(1.0) #least possible tol for all methods + +for atol in tols + sol = solve(probB, Bisection(), abstol = atol) + @test abs(sol.u - sqrt(2)) < atol + @test abs(sol.u - sqrt(2)) > ϵ #test that the solution is not calculated upto max precision + sol = solve(probB, Falsi(), abstol = atol) + @test abs(sol.u - sqrt(2)) < atol + @test abs(sol.u - sqrt(2)) > ϵ + sol = solve(probB, ITP(), abstol = atol) + @test abs(sol.u - sqrt(2)) < atol + @test abs(sol.u - sqrt(2)) > ϵ +end + +tols = [0.1] # Ridder and Brent converge rapidly so as we lower tolerance below 0.01, it converges with max precision to the solution +for atol in tols + sol = solve(probB, Ridder(), abstol = atol) + @test abs(sol.u - sqrt(2)) < atol + @test abs(sol.u - sqrt(2)) > ϵ + sol = solve(probB, Brent(), abstol = atol) + @test abs(sol.u - sqrt(2)) < atol + @test abs(sol.u - sqrt(2)) > ϵ +end + # Garuntee Tests for Bisection f = function (u, p) if u < 2.0