Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

ITP method implementation #67

Closed
wants to merge 17 commits into from
5 changes: 3 additions & 2 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ include("dfsane.jl")
include("ad.jl")
include("halley.jl")
include("alefeld.jl")
include("itp.jl")

import PrecompileTools

Expand All @@ -65,14 +66,14 @@ PrecompileTools.@compile_workload begin
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p,
T.((0.0, 2.0)),
T(2))
for alg in (Bisection, Falsi, Ridder, Brent, Alefeld)
for alg in (Bisection, Falsi, Ridder, Brent, Alefeld, Itp)
solve(prob_brack, alg(), abstol = T(1e-2))
end
end
end

# DiffEq styled algorithms
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, Itp

end # module
105 changes: 105 additions & 0 deletions src/itp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
```julia
Itp(; k1 = Val{1}(), k2 = Val{2}(), n0 = Val{1}())
```
ITP (Interpolate Truncate & Project)


"""

struct Itp <: AbstractBracketingAlgorithm
k1::Real
k2::Real
n0::Int
function Itp(; k1::Real = 0.007, k2::Real = 1.5, n0::Int = 10)
if k1 < 0
error("Hyper-parameter κ₁ should not be negative")
end
if n0 < 0
error("Hyper-parameter n₀ should not be negative")
end
if k2 < 1 || k2 > (1.5 + sqrt(5) / 2)
ArgumentError("Hyper-parameter κ₂ should be between 1 and 1 + ϕ where ϕ ≈ 1.618... is the golden ratio")
end
return new(k1, k2, n0)
end
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Itp,
args...; abstol = 1.0e-15,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan # a and b
fl, fr = f(left), f(right)
ϵ = abstol
if iszero(fl)
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
elseif iszero(fr)
return SciMLBase.build_solution(prob, alg, right, fr;
retcode = ReturnCode.ExactSolutionRight, left = left,
right = right)
end
#defining variables/cache
k1 = alg.k1
k2 = alg.k2
n0 = alg.n0
n_h = ceil(log2((right - left) / (2 * ϵ)))
mid = (left + right) / 2
x_f = (fr * left - fl * right) / (fr - fl)
xt = left
xp = left
r = zero(left) #minmax radius
δ = zero(left) # truncation error
σ = 1.0
ϵ_s = ϵ * 2^(n_h + n0)
i = 0 #iteration
while i <= maxiters
#mid = (left + right) / 2
r = ϵ_s - ((right - left) / 2)
δ = k1 * ((right - left)^k2)

## Interpolation step ##
x_f = (fr * left - fl * right) / (fr - fl)

## Truncation step ##
σ = sign(mid - x_f)
if δ <= abs(mid - x_f)
xt = x_f + (σ * δ)
else
xt = mid
end

## Projection step ##
if abs(xt - mid) <= r
xp = xt
else
xp = mid - (σ * r)
end

## Update ##
yp = f(xp)
if yp > 0
right = xp
fr = yp
elseif yp < 0
left = xp
fl = yp
else
left = xp
right = xp
end
i += 1
mid = (left + right) / 2
ϵ_s /= 2

if (right - left < 2 * ϵ)
return SciMLBase.build_solution(prob, alg, mid, f(mid);
retcode = ReturnCode.Success, left = left,
right = right)
end
end
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
left = left, right = right)
end
26 changes: 25 additions & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,18 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# ITP
g = function (p)
probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p)
sol = solve(probN, Itp())
return sol.u
end

for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Alefeld
g = function (p)
probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p)
Expand Down Expand Up @@ -242,7 +254,7 @@ end
f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
for alg in [Bisection(), Falsi(), Ridder(), Brent()]
for alg in [Bisection(), Falsi(), Ridder(), Brent(), Itp()]
global g, p
g = function (p)
probN = IntervalNonlinearProblem{false}(f, tspan, p)
Expand Down Expand Up @@ -345,6 +357,18 @@ probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, Alefeld())
@test sol.u ≈ sqrt(2.0)

# Itp
sol = solve(probB, Itp())
@test sol.u ≈ sqrt(2.0)
tspan = (sqrt(2.0), 10.0)
probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, Itp())
@test sol.u ≈ sqrt(2.0)
tspan = (0.0, sqrt(2.0))
probB = IntervalNonlinearProblem(f, tspan)
sol = solve(probB, Itp())
@test sol.u ≈ sqrt(2.0)

# Garuntee Tests for Bisection
f = function (u, p)
if u < 2.0
Expand Down