diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index c29fe63d6..33f5895d7 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -23,6 +23,7 @@ include("broyden.jl") include("klement.jl") include("trustRegion.jl") include("ridder.jl") +include("brent.jl") include("ad.jl") import SnoopPrecompile @@ -44,12 +45,13 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64) =# prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2)) - for alg in (Bisection, Falsi, Ridder) + for alg in (Bisection, Falsi, Ridder, Brent) solve(prob_brack, alg(), abstol = T(1e-2)) end end end # DiffEq styled algorithms -export Bisection, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion +export Bisection, Brent, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson, + SimpleTrustRegion end # module diff --git a/lib/SimpleNonlinearSolve/src/brent.jl b/lib/SimpleNonlinearSolve/src/brent.jl new file mode 100644 index 000000000..99f645f6a --- /dev/null +++ b/lib/SimpleNonlinearSolve/src/brent.jl @@ -0,0 +1,114 @@ +""" +`Brent()` + +A non-allocating Brent method + +""" +struct Brent <: AbstractBracketingAlgorithm end + +function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; + maxiters = 1000, + kwargs...) + f = Base.Fix2(prob.f, prob.p) + a, b = prob.tspan + fa, fb = f(a), f(b) + ϵ = eps(convert(typeof(fa), 1.0)) + + if iszero(fa) + return SciMLBase.build_solution(prob, alg, a, fa; + retcode = ReturnCode.ExactSolutionLeft, left = a, + right = b) + end + if abs(fa) < abs(fb) + c = b + b = a + a = c + tmp = fa + fa = fb + fb = tmp + end + + c = a + d = c + i = 1 + cond = true + if !iszero(fb) + while i < maxiters + fc = f(c) + if fa != fc && fb != fc + # Inverse quadratic interpolation + s = a * fb * fc / ((fa - fb) * (fa - fc)) + + b * fa * fc / ((fb - fa) * (fb - fc)) + + c * fa * fb / ((fc - fa) * (fc - fb)) + else + # Secant method + s = b - fb * (b - a) / (fb - fa) + end + if (s < min((3 * a + b) / 4, b) || s > max((3 * a + b) / 4, b)) || + (cond && abs(s - b) ≥ abs(b - c) / 2) || + (!cond && abs(s - b) ≥ abs(c - d) / 2) || + (cond && abs(b - c) ≤ ϵ) || + (!cond && abs(c - d) ≤ ϵ) + # Bisection method + s = (a + b) / 2 + (s == a || s == b) && + return SciMLBase.build_solution(prob, alg, a, fa; + retcode = ReturnCode.FloatingPointLimit, + left = a, right = b) + cond = true + else + cond = false + end + fs = f(s) + if iszero(fs) + if b < a + a = b + fa = fb + end + b = s + fb = fs + break + end + if fa * fs < 0 + d = c + c = b + b = s + fb = fs + else + a = s + fa = fs + end + if abs(fa) < abs(fb) + d = c + c = b + b = a + a = c + fc = fb + fb = fa + fa = fc + end + i += 1 + end + end + + while i < maxiters + c = (a + b) / 2 + if (c == a || c == b) + return SciMLBase.build_solution(prob, alg, a, fa; + retcode = ReturnCode.FloatingPointLimit, + left = a, right = b) + end + fc = f(c) + if iszero(fc) + b = c + fb = fc + else + a = c + fa = fc + end + i += 1 + end + + return SciMLBase.build_solution(prob, alg, a, fa; retcode = ReturnCode.MaxIters, + left = a, right = b) +end diff --git a/lib/SimpleNonlinearSolve/test/basictests.jl b/lib/SimpleNonlinearSolve/test/basictests.jl index 017569227..fa15e798e 100644 --- a/lib/SimpleNonlinearSolve/test/basictests.jl +++ b/lib/SimpleNonlinearSolve/test/basictests.jl @@ -121,10 +121,22 @@ for p in 1.1:0.1:100.0 @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) end +# Brent +g = function (p) + probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p) + sol = solve(probN, Brent()) + return sol.left +end + +for p in 1.1:0.1:100.0 + @test g(p) ≈ sqrt(p) + @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) +end + f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0) t = (p) -> [sqrt(p[2] / p[1])] p = [0.9, 50.0] -for alg in [Bisection(), Falsi(), Ridder()] +for alg in [Bisection(), Falsi(), Ridder(), Brent()] global g, p g = function (p) probN = IntervalNonlinearProblem{false}(f, tspan, p) @@ -200,6 +212,18 @@ probB = IntervalNonlinearProblem(f, tspan) sol = solve(probB, Ridder()) @test sol.left ≈ sqrt(2.0) +# Brent +sol = solve(probB, Brent()) +@test sol.left ≈ sqrt(2.0) +tspan = (sqrt(2.0), 10.0) +probB = IntervalNonlinearProblem(f, tspan) +sol = solve(probB, Brent()) +@test sol.left ≈ sqrt(2.0) +tspan = (0.0, sqrt(2.0)) +probB = IntervalNonlinearProblem(f, tspan) +sol = solve(probB, Brent()) +@test sol.left ≈ sqrt(2.0) + # Garuntee Tests for Bisection f = function (u, p) if u < 2.0