diff --git a/Project.toml b/Project.toml index 88956b412..1dddc6ad5 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,7 @@ ForwardDiff = "0.10.3" RecursiveArrayTools = "2" Reexport = "0.2" Setfield = "0.7" -StaticArrays = "1.0" +StaticArrays = "0.12,1.0" UnPack = "1.0" julia = "1" diff --git a/src/scalar.jl b/src/scalar.jl index ccdc61ab1..10f907c7b 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -53,7 +53,7 @@ function solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V, end # avoid ambiguities -for Alg in [Bisection, Falsi] +for Alg in [Bisection] @eval function solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode) @@ -110,3 +110,61 @@ function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kw return BracketingSolution(left, right, MAXITERS_EXCEED) end + +function solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...) + f = Base.Fix2(prob.f, prob.p) + left, right = prob.u0 + fl, fr = f(left), f(right) + + if iszero(fl) + return BracketingSolution(left, right, EXACT_SOLUTION_LEFT) + end + + i = 1 + if !iszero(fr) + while i < maxiters + if nextfloat_tdir(left, prob.u0...) == right + return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + end + mid = (fr * left - fl * right) / (fr - fl) + for i in 1:10 + mid = max(left, prevfloat_tdir(mid, prob.u0...)) + end + if mid == right || mid == left + break + end + fm = f(mid) + if iszero(fm) + right = mid + break + end + if sign(fl) == sign(fm) + fl = fm + left = mid + else + fr = fm + right = mid + end + i += 1 + end + end + + while i < maxiters + mid = (left + right) / 2 + (mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT) + fm = f(mid) + if iszero(fm) + right = mid + fr = fm + elseif sign(fm) == sign(fl) + left = mid + fl = fm + else + right = mid + fr = fm + end + i += 1 + end + + return BracketingSolution(left, right, MAXITERS_EXCEED) +end diff --git a/test/runtests.jl b/test/runtests.jl index 0fb0c359a..8ab368153 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,7 @@ end # Scalar f, u0 = (u, p) -> u * u - p, 1.0 +# NewtonRaphson g = function (p) probN = NonlinearProblem{false}(f, oftype(p, u0), p) sol = solve(probN, NewtonRaphson()) @@ -69,6 +70,19 @@ for p in 1.1:0.1:100.0 @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p)) end +u0 = (1.0, 20.0) +# Falsi +g = function (p) + probN = NonlinearProblem{false}(f, typeof(p).(u0), p) + sol = solve(probN, Falsi()) + return sol.left +end + +for p in 1.1:0.1:100.0 + @test g(p) ≈ sqrt(p) + @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p)) +end + f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0) t = (p) -> [sqrt(p[2] / p[1])] p = [0.9, 50.0]