Skip to content

Commit

Permalink
Merge pull request #125 from SciML/ap/tr_nlsolve_update
Browse files Browse the repository at this point in the history
Add NLsolve update rule to SimpleTrustRegion
  • Loading branch information
ChrisRackauckas authored Feb 9, 2024
2 parents 559c51b + 6ac95bd commit 0a86a9e
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 27 deletions.
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.3.2"
version = "1.4.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
68 changes: 52 additions & 16 deletions lib/SimpleNonlinearSolve/src/nlsolve/trustRegion.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""
SimpleTrustRegion(; autodiff = AutoForwardDiff(), max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0, step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25, expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25, expand_factor::Real = 2.0, max_shrink_times::Int = 32)
SimpleTrustRegion(; autodiff = AutoForwardDiff(), max_trust_radius = 0.0,
initial_trust_radius = 0.0, step_threshold = nothing,
shrink_threshold = nothing, expand_threshold = nothing,
shrink_factor = 0.25, expand_factor = 2.0, max_shrink_times::Int = 32,
nlsolve_update_rule = Val(false))
A low-overhead implementation of a trust-region solver. This method is non-allocating on
scalar and static array problems.
Expand Down Expand Up @@ -36,17 +37,22 @@ scalar and static array problems.
`expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`.
- `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
- `nlsolve_update_rule`: If set to `Val(true)`, updates the trust region radius using the
update rule from NLSolve.jl. Defaults to `Val(false)`. If set to `Val(true)`, few of the
radius update parameters -- `step_threshold = 0.05`, `expand_threshold = 0.9`, and
`shrink_factor = 0.5` -- have different defaults.
"""
@kwdef @concrete struct SimpleTrustRegion <: AbstractNewtonAlgorithm
autodiff = nothing
max_trust_radius = 0.0
initial_trust_radius = 0.0
step_threshold = 0.0001
shrink_threshold = 0.25
expand_threshold = 0.75
shrink_factor = 0.25
shrink_threshold = nothing
expand_threshold = nothing
shrink_factor = nothing
expand_factor = 2.0
max_shrink_times::Int = 32
nlsolve_update_rule = Val(false)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...;
Expand All @@ -57,14 +63,27 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
Δₘₐₓ = T(alg.max_trust_radius)
Δ = T(alg.initial_trust_radius)
η₁ = T(alg.step_threshold)
η₂ = T(alg.shrink_threshold)
η₃ = T(alg.expand_threshold)
t₁ = T(alg.shrink_factor)
if alg.shrink_threshold === nothing
η₂ = _unwrap_val(alg.nlsolve_update_rule) ? T(0.05) : T(0.25)
else
η₂ = T(alg.shrink_threshold)
end
if alg.expand_threshold === nothing
η₃ = _unwrap_val(alg.nlsolve_update_rule) ? T(0.9) : T(0.75)
else
η₃ = T(alg.expand_threshold)
end
if alg.shrink_factor === nothing
t₁ = _unwrap_val(alg.nlsolve_update_rule) ? T(0.5) : T(0.25)
else
t₁ = T(alg.shrink_factor)
end
t₂ = T(alg.expand_factor)
max_shrink_times = alg.max_shrink_times
autodiff = __get_concrete_autodiff(prob, alg.autodiff)

fx = _get_fx(prob, x)
norm_fx = norm(fx)
@bb xo = copy(x)
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
Expand All @@ -73,10 +92,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
termination_condition)

# Set default trust region radius if not specified by user.
Δₘₐₓ == 0 && (Δₘₐₓ = max(norm(fx), maximum(x) - minimum(x)))
Δ == 0 &&= Δₘₐₓ / 11)
Δₘₐₓ == 0 && (Δₘₐₓ = max(norm_fx, maximum(x) - minimum(x)))
if Δ == 0
if _unwrap_val(alg.nlsolve_update_rule)
norm_x = norm(x)
Δ = T(ifelse(norm_x > 0, norm_x, 1))
else
Δ = T(Δₘₐₓ / 11)
end
end

fₖ = 0.5 * norm(fx)^2
fₖ = 0.5 * norm_fx^2
H = ∇f' * ∇f
g = _restructure(x, ∇f' * _vec(fx))
shrink_counter = 0
Expand All @@ -87,7 +113,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
@bb= copy(x)
dogleg_cache = (; δsd, δN_δsd, δN)

for k in 1:maxiters
for _ in 1:maxiters
# Solve the trust region subproblem.
δ = dogleg_method!!(dogleg_cache, ∇f, fx, g, Δ)
@bb @. x = xo + δ
Expand All @@ -107,7 +133,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
Δ = t₁ * Δ
shrink_counter += 1
shrink_counter > max_shrink_times && return build_solution(prob, alg, x, fx;
retcode = ReturnCode.ConvergenceFailure)
retcode = ReturnCode.ShrinkThresholdExceeded)
end

if r η₁
Expand All @@ -121,12 +147,22 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)

# Update the trust region radius.
(r > η₃) && (norm(δ) Δ) &&= min(t₂ * Δ, Δₘₐₓ))
if !_unwrap_val(alg.nlsolve_update_rule) && r > η₃
Δ = min(t₂ * Δ, Δₘₐₓ)
end
fₖ = fₖ₊₁

@bb H = transpose(∇f) × ∇f
@bb g = transpose(∇f) × vec(fx)
end

if _unwrap_val(alg.nlsolve_update_rule)
if r > η₃
Δ = t₂ * norm(δ)
elseif r > 0.5
Δ = max(Δ, t₂ * norm(δ))
end
end
end

return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
Expand Down
6 changes: 4 additions & 2 deletions lib/SimpleNonlinearSolve/test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ end
end

@testcase "SimpleTrustRegion 23 Test Problems" begin
alg_ops = (SimpleTrustRegion(),)
alg_ops = (SimpleTrustRegion(),
SimpleTrustRegion(; nlsolve_update_rule = Val(true)))

# dictionary with indices of test problems where method does not converge to small residual
broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [3, 6, 15, 16, 21]
broken_tests[alg_ops[1]] = [3, 15, 16, 21]
broken_tests[alg_ops[2]] = [15, 16]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
Expand Down
10 changes: 7 additions & 3 deletions lib/SimpleNonlinearSolve/test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ end

# --- SimpleNewtonRaphson tests ---

@testcase "$(alg)" for alg in (SimpleNewtonRaphson, SimpleTrustRegion)
@testcase "$(alg)" for alg in (SimpleNewtonRaphson, SimpleTrustRegion,
(args...; kwargs...) -> SimpleTrustRegion(args...; nlsolve_update_rule = Val(true),
kwargs...))
@testset "AutoDiff: $(_nameof(autodiff))" for autodiff in (AutoFiniteDiff(),
AutoForwardDiff(), AutoPolyesterForwardDiff())
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
Expand Down Expand Up @@ -110,7 +112,8 @@ end
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

@testcase "$(alg)" for alg in (SimpleDFSane(), SimpleTrustRegion(), SimpleHalley())
@testcase "$(alg)" for alg in (SimpleDFSane(), SimpleTrustRegion(), SimpleHalley(),
SimpleTrustRegion(; nlsolve_update_rule = Val(true)))
sol = benchmark_nlsolve_oop(newton_fails, u0, p; solver = alg)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(newton_fails(sol.u, p)) .< 1e-9)
Expand All @@ -122,7 +125,8 @@ end
## SimpleDFSane needs to allocate a history vector
@testcase "Allocation Checks: $(_nameof(alg))" for alg in (SimpleNewtonRaphson(),
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleLimitedMemoryBroyden(),
SimpleTrustRegion(), SimpleDFSane(), SimpleBroyden(; linesearch = Val(true)),
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
SimpleDFSane(), SimpleBroyden(; linesearch = Val(true)),
SimpleLimitedMemoryBroyden(; linesearch = Val(true)))
@check_allocs nlsolve(prob, alg) = SciMLBase.solve(prob, alg; abstol = 1e-9)

Expand Down
7 changes: 4 additions & 3 deletions lib/SimpleNonlinearSolve/test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ f!(du, u, p) = du .= u .* u .- 2

@testset "Solving on GPUs" begin
@testcase "$(alg)" for alg in (SimpleNewtonRaphson(), SimpleDFSane(),
SimpleTrustRegion(), SimpleBroyden(), SimpleLimitedMemoryBroyden(), SimpleKlement(),
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
SimpleBroyden(), SimpleLimitedMemoryBroyden(), SimpleKlement(),
SimpleHalley(), SimpleBroyden(; linesearch = Val(true)),
SimpleLimitedMemoryBroyden(; linesearch = Val(true)))
# Static Arrays
Expand Down Expand Up @@ -44,8 +45,8 @@ end
prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0])

@testcase "$(alg)" for alg in (SimpleNewtonRaphson(), SimpleDFSane(),
SimpleTrustRegion(), SimpleBroyden(),
SimpleLimitedMemoryBroyden(), SimpleKlement(), SimpleHalley(),
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
SimpleBroyden(), SimpleLimitedMemoryBroyden(), SimpleKlement(), SimpleHalley(),
SimpleBroyden(; linesearch = Val(true)),
SimpleLimitedMemoryBroyden(; linesearch = Val(true)))
@test begin
Expand Down
3 changes: 2 additions & 1 deletion lib/SimpleNonlinearSolve/test/forward_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ __compatible(::AbstractSimpleNonlinearSolveAlgorithm, ::Val{:oop}) = true
__compatible(::SimpleHalley, ::Val{:iip}) = false

@testcase "ForwardDiff.jl Integration: $(alg)" for alg in (SimpleNewtonRaphson(),
SimpleTrustRegion(), SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())
us = (2.0, @SVector[1.0, 1.0], [1.0, 1.0], ones(2, 2), @SArray ones(2, 2))

@testset "Scalar AD" begin
Expand Down
3 changes: 2 additions & 1 deletion lib/SimpleNonlinearSolve/test/matrix_resizing_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ vecprob = NonlinearProblem(ff, vec(u0), p)
prob = NonlinearProblem(ff, u0, p)

@testcase "$(alg)" for alg in (SimpleKlement(), SimpleBroyden(), SimpleNewtonRaphson(),
SimpleDFSane(), SimpleLimitedMemoryBroyden(; threshold = Val(2)), SimpleTrustRegion())
SimpleDFSane(), SimpleLimitedMemoryBroyden(; threshold = Val(2)), SimpleTrustRegion(),
SimpleTrustRegion(; nlsolve_update_rule = Val(true)))
@test vec(solve(prob, alg).u) solve(vecprob, alg).u
end

0 comments on commit 0a86a9e

Please sign in to comment.