Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

TrustRegion -> SimpleTrustRegion and specialize the number types #28

Merged
merged 4 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
solve(prob_no_brack, alg(), abstol = T(1e-2))
end

for alg in (TrustRegion(10.0),)
for alg in (SimpleTrustRegion(10.0),)
solve(prob_no_brack, alg, abstol = T(1e-2))
end

Expand All @@ -53,6 +53,6 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
end end

# DiffEq styled algorithms
export Bisection, Broyden, Falsi, Klement, SimpleNewtonRaphson, TrustRegion
export Bisection, Broyden, Falsi, Klement, SimpleNewtonRaphson, SimpleTrustRegion

end # module
57 changes: 30 additions & 27 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
```julia
TrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
SimpleTrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
autodiff = Val{true}(), diff_type = Val{:forward})
```

Expand Down Expand Up @@ -49,36 +49,39 @@ solver
- `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
"""
struct TrustRegion{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
max_trust_radius::Number
initial_trust_radius::Number
step_threshold::Number
shrink_threshold::Number
expand_threshold::Number
shrink_factor::Number
expand_factor::Number
struct SimpleTrustRegion{T, CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
max_trust_radius::T
initial_trust_radius::T
step_threshold::T
shrink_threshold::T
expand_threshold::T
shrink_factor::T
expand_factor::T
max_shrink_times::Int
function TrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
autodiff = Val{true}(),
diff_type = Val{:forward},
initial_trust_radius::Number = max_trust_radius / 11,
step_threshold::Number = 0.1,
shrink_threshold::Number = 0.25,
expand_threshold::Number = 0.75,
shrink_factor::Number = 0.25,
expand_factor::Number = 2.0,
max_shrink_times::Int = 32)
new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff),
SciMLBase._unwrap_val(diff_type)}(max_trust_radius, initial_trust_radius,
step_threshold,
shrink_threshold, expand_threshold,
shrink_factor,
expand_factor, max_shrink_times)
function SimpleTrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
autodiff = Val{true}(),
diff_type = Val{:forward},
initial_trust_radius::Number = max_trust_radius / 11,
step_threshold::Number = 0.1,
shrink_threshold::Number = 0.25,
expand_threshold::Number = 0.75,
shrink_factor::Number = 0.25,
expand_factor::Number = 2.0,
max_shrink_times::Int = 32)
new{typeof(initial_trust_radius), SciMLBase._unwrap_val(chunk_size),
SciMLBase._unwrap_val(autodiff), SciMLBase._unwrap_val(diff_type)}(max_trust_radius,
initial_trust_radius,
step_threshold,
shrink_threshold,
expand_threshold,
shrink_factor,
expand_factor,
max_shrink_times)
end
end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::TrustRegion, args...; abstol = nothing,
alg::SimpleTrustRegion, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
Expand All @@ -94,7 +97,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
max_shrink_times = alg.max_shrink_times

if SciMLBase.isinplace(prob)
error("TrustRegion currently only supports out-of-place nonlinear problems")
error("SimpleTrustRegion currently only supports out-of-place nonlinear problems")
end

atol = abstol !== nothing ? abstol :
Expand Down
42 changes: 21 additions & 21 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ if VERSION >= v"1.7"
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
end

# TrustRegion
# SimpleTrustRegion
function benchmark_scalar(f, u0)
probN = NonlinearProblem{false}(f, u0)
sol = (solve(probN, TrustRegion(10.0)))
sol = (solve(probN, SimpleTrustRegion(10.0)))
end

sol = benchmark_scalar(sf, csu0)
Expand All @@ -69,7 +69,7 @@ using ForwardDiff
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]

for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
TrustRegion(10.0)]
SimpleTrustRegion(10.0)]
g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, alg, abstol = 1e-9)
Expand All @@ -85,7 +85,7 @@ end
# Scalar
f, u0 = (u, p) -> u * u - p, 1.0
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
TrustRegion(10.0)]
SimpleTrustRegion(10.0)]
g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, alg)
Expand Down Expand Up @@ -127,7 +127,7 @@ for alg in [Bisection(), Falsi()]
end

for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
TrustRegion(10.0)]
SimpleTrustRegion(10.0)]
global g, p
g = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
Expand All @@ -144,8 +144,8 @@ probN = NonlinearProblem(f, u0)

@test solve(probN, SimpleNewtonRaphson()).u[end] ≈ sqrt(2.0)
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0)
@test solve(probN, TrustRegion(10.0)).u[end] ≈ sqrt(2.0)
@test solve(probN, TrustRegion(10.0; autodiff = false)).u[end] ≈ sqrt(2.0)
@test solve(probN, SimpleTrustRegion(10.0)).u[end] ≈ sqrt(2.0)
@test solve(probN, SimpleTrustRegion(10.0; autodiff = false)).u[end] ≈ sqrt(2.0)
@test solve(probN, Broyden()).u[end] ≈ sqrt(2.0)
@test solve(probN, Klement()).u[end] ≈ sqrt(2.0)

Expand All @@ -159,9 +159,9 @@ for u0 in [1.0, [1, 1.0]]
@test solve(probN, SimpleNewtonRaphson()).u ≈ sol
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u ≈ sol

@test solve(probN, TrustRegion(10.0)).u ≈ sol
@test solve(probN, TrustRegion(10.0)).u ≈ sol
@test solve(probN, TrustRegion(10.0; autodiff = false)).u ≈ sol
@test solve(probN, SimpleTrustRegion(10.0)).u ≈ sol
@test solve(probN, SimpleTrustRegion(10.0)).u ≈ sol
@test solve(probN, SimpleTrustRegion(10.0; autodiff = false)).u ≈ sol

@test solve(probN, Broyden()).u ≈ sol

Expand Down Expand Up @@ -205,7 +205,7 @@ sol = solve(probB, Bisection(; exact_left = true, exact_right = true); immutable
@test f(sol.right, nothing) >= 0.0
@test f(prevfloat(sol.right), nothing) <= 0.0

# Test that `TrustRegion` passes a test that `SimpleNewtonRaphson` fails on.
# Test that `SimpleTrustRegion` passes a test that `SimpleNewtonRaphson` fails on.
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
global g, f
f = (u, p) -> 0.010000000000000002 .+
Expand All @@ -219,15 +219,15 @@ f = (u, p) -> 0.010000000000000002 .+
.-p
g = function (p)
probN = NonlinearProblem{false}(f, u0, p)
sol = solve(probN, TrustRegion(100.0))
sol = solve(probN, SimpleTrustRegion(100.0))
return sol.u
end
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
u = g(p)
f(u, p)
@test all(f(u, p) .< 1e-10)

# Test kwars in `TrustRegion`
# Test kwars in `SimpleTrustRegion`
max_trust_radius = [10.0, 100.0, 1000.0]
initial_trust_radius = [10.0, 1.0, 0.1]
step_threshold = [0.0, 0.01, 0.25]
Expand All @@ -242,14 +242,14 @@ list_of_options = zip(max_trust_radius, initial_trust_radius, step_threshold,
expand_factor, max_shrink_times)
for options in list_of_options
local probN, sol, alg
alg = TrustRegion(options[1];
initial_trust_radius = options[2],
step_threshold = options[3],
shrink_threshold = options[4],
expand_threshold = options[5],
shrink_factor = options[6],
expand_factor = options[7],
max_shrink_times = options[8])
alg = SimpleTrustRegion(options[1];
initial_trust_radius = options[2],
step_threshold = options[3],
shrink_threshold = options[4],
expand_threshold = options[5],
shrink_factor = options[6],
expand_factor = options[7],
max_shrink_times = options[8])

probN = NonlinearProblem(f, u0, p)
sol = solve(probN, alg)
Expand Down