Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

Commit

Permalink
Merge pull request #28 from SciML/ChrisRackauckas-patch-1
Browse files Browse the repository at this point in the history
TrustRegion -> SimpleTrustRegion and specialize the number types
  • Loading branch information
ChrisRackauckas authored Jan 17, 2023
2 parents 2b4cfa5 + 632910b commit 17fdeee
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
solve(prob_no_brack, alg(), abstol = T(1e-2))
end

for alg in (TrustRegion(10.0),)
for alg in (SimpleTrustRegion(10.0),)
solve(prob_no_brack, alg, abstol = T(1e-2))
end

Expand All @@ -53,6 +53,6 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
end end

# DiffEq styled algorithms
export Bisection, Broyden, Falsi, Klement, SimpleNewtonRaphson, TrustRegion
export Bisection, Broyden, Falsi, Klement, SimpleNewtonRaphson, SimpleTrustRegion

end # module
57 changes: 30 additions & 27 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
```julia
TrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
SimpleTrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
autodiff = Val{true}(), diff_type = Val{:forward})
```
Expand Down Expand Up @@ -49,36 +49,39 @@ solver
- `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
"""
struct TrustRegion{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
max_trust_radius::Number
initial_trust_radius::Number
step_threshold::Number
shrink_threshold::Number
expand_threshold::Number
shrink_factor::Number
expand_factor::Number
struct SimpleTrustRegion{T, CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
max_trust_radius::T
initial_trust_radius::T
step_threshold::T
shrink_threshold::T
expand_threshold::T
shrink_factor::T
expand_factor::T
max_shrink_times::Int
function TrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
autodiff = Val{true}(),
diff_type = Val{:forward},
initial_trust_radius::Number = max_trust_radius / 11,
step_threshold::Number = 0.1,
shrink_threshold::Number = 0.25,
expand_threshold::Number = 0.75,
shrink_factor::Number = 0.25,
expand_factor::Number = 2.0,
max_shrink_times::Int = 32)
new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff),
SciMLBase._unwrap_val(diff_type)}(max_trust_radius, initial_trust_radius,
step_threshold,
shrink_threshold, expand_threshold,
shrink_factor,
expand_factor, max_shrink_times)
function SimpleTrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
autodiff = Val{true}(),
diff_type = Val{:forward},
initial_trust_radius::Number = max_trust_radius / 11,
step_threshold::Number = 0.1,
shrink_threshold::Number = 0.25,
expand_threshold::Number = 0.75,
shrink_factor::Number = 0.25,
expand_factor::Number = 2.0,
max_shrink_times::Int = 32)
new{typeof(initial_trust_radius), SciMLBase._unwrap_val(chunk_size),
SciMLBase._unwrap_val(autodiff), SciMLBase._unwrap_val(diff_type)}(max_trust_radius,
initial_trust_radius,
step_threshold,
shrink_threshold,
expand_threshold,
shrink_factor,
expand_factor,
max_shrink_times)
end
end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::TrustRegion, args...; abstol = nothing,
alg::SimpleTrustRegion, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
Expand All @@ -94,7 +97,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
max_shrink_times = alg.max_shrink_times

if SciMLBase.isinplace(prob)
error("TrustRegion currently only supports out-of-place nonlinear problems")
error("SimpleTrustRegion currently only supports out-of-place nonlinear problems")
end

atol = abstol !== nothing ? abstol :
Expand Down
42 changes: 21 additions & 21 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ if VERSION >= v"1.7"
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
end

# TrustRegion
# SimpleTrustRegion
function benchmark_scalar(f, u0)
probN = NonlinearProblem{false}(f, u0)
sol = (solve(probN, TrustRegion(10.0)))
sol = (solve(probN, SimpleTrustRegion(10.0)))
end

sol = benchmark_scalar(sf, csu0)
Expand All @@ -69,7 +69,7 @@ using ForwardDiff
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]

for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
TrustRegion(10.0)]
SimpleTrustRegion(10.0)]
g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, alg, abstol = 1e-9)
Expand All @@ -85,7 +85,7 @@ end
# Scalar
f, u0 = (u, p) -> u * u - p, 1.0
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
TrustRegion(10.0)]
SimpleTrustRegion(10.0)]
g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, alg)
Expand Down Expand Up @@ -127,7 +127,7 @@ for alg in [Bisection(), Falsi()]
end

for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
TrustRegion(10.0)]
SimpleTrustRegion(10.0)]
global g, p
g = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
Expand All @@ -144,8 +144,8 @@ probN = NonlinearProblem(f, u0)

@test solve(probN, SimpleNewtonRaphson()).u[end] sqrt(2.0)
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)
@test solve(probN, TrustRegion(10.0)).u[end] sqrt(2.0)
@test solve(probN, TrustRegion(10.0; autodiff = false)).u[end] sqrt(2.0)
@test solve(probN, SimpleTrustRegion(10.0)).u[end] sqrt(2.0)
@test solve(probN, SimpleTrustRegion(10.0; autodiff = false)).u[end] sqrt(2.0)
@test solve(probN, Broyden()).u[end] sqrt(2.0)
@test solve(probN, Klement()).u[end] sqrt(2.0)

Expand All @@ -159,9 +159,9 @@ for u0 in [1.0, [1, 1.0]]
@test solve(probN, SimpleNewtonRaphson()).u sol
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u sol

@test solve(probN, TrustRegion(10.0)).u sol
@test solve(probN, TrustRegion(10.0)).u sol
@test solve(probN, TrustRegion(10.0; autodiff = false)).u sol
@test solve(probN, SimpleTrustRegion(10.0)).u sol
@test solve(probN, SimpleTrustRegion(10.0)).u sol
@test solve(probN, SimpleTrustRegion(10.0; autodiff = false)).u sol

@test solve(probN, Broyden()).u sol

Expand Down Expand Up @@ -205,7 +205,7 @@ sol = solve(probB, Bisection(; exact_left = true, exact_right = true); immutable
@test f(sol.right, nothing) >= 0.0
@test f(prevfloat(sol.right), nothing) <= 0.0

# Test that `TrustRegion` passes a test that `SimpleNewtonRaphson` fails on.
# Test that `SimpleTrustRegion` passes a test that `SimpleNewtonRaphson` fails on.
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
global g, f
f = (u, p) -> 0.010000000000000002 .+
Expand All @@ -219,15 +219,15 @@ f = (u, p) -> 0.010000000000000002 .+
.-p
g = function (p)
probN = NonlinearProblem{false}(f, u0, p)
sol = solve(probN, TrustRegion(100.0))
sol = solve(probN, SimpleTrustRegion(100.0))
return sol.u
end
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
u = g(p)
f(u, p)
@test all(f(u, p) .< 1e-10)

# Test kwars in `TrustRegion`
# Test kwars in `SimpleTrustRegion`
max_trust_radius = [10.0, 100.0, 1000.0]
initial_trust_radius = [10.0, 1.0, 0.1]
step_threshold = [0.0, 0.01, 0.25]
Expand All @@ -242,14 +242,14 @@ list_of_options = zip(max_trust_radius, initial_trust_radius, step_threshold,
expand_factor, max_shrink_times)
for options in list_of_options
local probN, sol, alg
alg = TrustRegion(options[1];
initial_trust_radius = options[2],
step_threshold = options[3],
shrink_threshold = options[4],
expand_threshold = options[5],
shrink_factor = options[6],
expand_factor = options[7],
max_shrink_times = options[8])
alg = SimpleTrustRegion(options[1];
initial_trust_radius = options[2],
step_threshold = options[3],
shrink_threshold = options[4],
expand_threshold = options[5],
shrink_factor = options[6],
expand_factor = options[7],
max_shrink_times = options[8])

probN = NonlinearProblem(f, u0, p)
sol = solve(probN, alg)
Expand Down

0 comments on commit 17fdeee

Please sign in to comment.