diff --git a/lib/SimpleNonlinearSolve/src/itp.jl b/lib/SimpleNonlinearSolve/src/itp.jl index 648208b80..1f1efbafe 100644 --- a/lib/SimpleNonlinearSolve/src/itp.jl +++ b/lib/SimpleNonlinearSolve/src/itp.jl @@ -78,9 +78,9 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, k1 = alg.k1 k2 = alg.k2 n0 = alg.n0 - n_h = ceil(log2((right - left) / (2 * ϵ))) + n_h = ceil(log2(abs(right - left) / (2 * ϵ))) mid = (left + right) / 2 - x_f = (fr * left - fl * right) / (fr - fl) + x_f = left + (right - left) * (fl/(fl - fr)) xt = left xp = left r = zero(left) #minmax radius @@ -89,12 +89,12 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, ϵ_s = ϵ * 2^(n_h + n0) i = 0 #iteration while i <= maxiters - #mid = (left + right) / 2 - r = ϵ_s - ((right - left) / 2) - δ = k1 * ((right - left)^k2) + span = abs(right - left) + r = ϵ_s - (span / 2) + δ = k1 * (span^k2) ## Interpolation step ## - x_f = (fr * left - fl * right) / (fr - fl) + x_f = left + (right - left) * (fl/(fl - fr)) ## Truncation step ## σ = sign(mid - x_f) @@ -112,6 +112,9 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, end ## Update ## + tmin, tmax = minmax(left, right) + xp >= tmax && (xp = prevfloat(tmax)) + xp <= tmin && (xp = nextfloat(tmin)) yp = f(xp) yps = yp * sign(fr) if yps > 0 @@ -121,16 +124,17 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, left = xp fl = yp else - left = xp - right = xp + return SciMLBase.build_solution(prob, alg, xp, yps; + retcode = ReturnCode.Success, left = xp, + right = xp) end i += 1 mid = (left + right) / 2 ϵ_s /= 2 - if (right - left < 2 * ϵ) - return SciMLBase.build_solution(prob, alg, mid, f(mid); - retcode = ReturnCode.Success, left = left, + if nextfloat_tdir(left, prob.tspan...) == right + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = ReturnCode.FloatingPointLimit, left = left, right = right) end end diff --git a/lib/SimpleNonlinearSolve/test/basictests.jl b/lib/SimpleNonlinearSolve/test/basictests.jl index 9dc681f3a..6ab9e56d8 100644 --- a/lib/SimpleNonlinearSolve/test/basictests.jl +++ b/lib/SimpleNonlinearSolve/test/basictests.jl @@ -540,18 +540,19 @@ for alg in (SimpleNewtonRaphson(), SimpleTrustRegion()) @test abs.(sol.u) ≈ sqrt.(p) end -# Flipped signs test +# Flipped signs & reversed tspan test for bracketing algorithms f1(u, p) = u * u - p f2(u, p) = p - u * u -for Alg in (Alefeld, Bisection, Falsi, Brent, ITP, Ridder) - alg = Alg() +for alg in (Alefeld(), Bisection(), Falsi(), Brent(), ITP(), Ridder()) for p in 1:4 inp1 = IntervalNonlinearProblem(f1, (1.0, 2.0), p) inp2 = IntervalNonlinearProblem(f2, (1.0, 2.0), p) - sol = solve(inp1, alg) - @test abs.(sol.u) ≈ sqrt.(p) - sol = solve(inp2, alg) - @test abs.(sol.u) ≈ sqrt.(p) + inp3 = IntervalNonlinearProblem(f1, (2.0, 1.0), p) + inp4 = IntervalNonlinearProblem(f2, (2.0, 1.0), p) + @test abs.(solve(inp1, alg).u) ≈ sqrt.(p) + @test abs.(solve(inp2, alg).u) ≈ sqrt.(p) + @test abs.(solve(inp3, alg).u) ≈ sqrt.(p) + @test abs.(solve(inp4, alg).u) ≈ sqrt.(p) end end