Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing some new Trust region radius update schemes #159

Merged
merged 15 commits into from
Mar 21, 2023
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "1.5.0"
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
3 changes: 3 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using ForwardDiff: Dual
using LinearAlgebra
using StaticArraysCore
using RecursiveArrayTools
import EnumX
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing from the Project.toml

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah i added the pkg in the env but didn't commit it yet. will do it in the next commit

import ArrayInterface
import LinearSolve
using DiffEqBase
Expand Down Expand Up @@ -59,6 +60,8 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
end
end end

export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt

end # module
77 changes: 51 additions & 26 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,18 @@ for large-scale and numerically-difficult nonlinear systems.
Currently, the linear solver and chunk size choice only applies to in-place defined
`NonlinearProblem`s. That is expected to change in the future.
"""
EnumX.@enumx RadiusUpdateSchemes begin
Simple
Hei
Yuan
Bastin
end

struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::L
precs::P
radius_update_scheme::RadiusUpdateSchemes.T
max_trust_radius::MTR
initial_trust_radius::MTR
step_threshold::MTR
Expand All @@ -98,6 +106,7 @@ function TrustRegion(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple, #defaults to conventional radius update
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
Expand All @@ -109,7 +118,7 @@ function TrustRegion(; chunk_size = Val{0}(),
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(max_trust_radius)
}(linsolve, precs, max_trust_radius,
}(linsolve, precs, radius_update_scheme, max_trust_radius,
initial_trust_radius,
step_threshold,
shrink_threshold,
Expand Down Expand Up @@ -138,6 +147,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
retcode::SciMLBase.ReturnCode.T
abstol::tolType
prob::probType
radius_update_scheme::RadiusUpdateSchemes.T
trust_r::trustType
max_trust_r::trustType
step_threshold::suType
Expand All @@ -161,7 +171,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
jac_config::JC, iter::Int,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType, trust_r::trustType,
prob::probType, radius_update_scheme::RadiusUpdateSchemes.T, trust_r::trustType,
max_trust_r::trustType, step_threshold::suType,
shrink_threshold::trustType, expand_threshold::trustType,
shrink_factor::trustType, expand_factor::trustType,
Expand All @@ -178,7 +188,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r,
abstol, prob, radius_update_scheme, trust_r, max_trust_r,
step_threshold, shrink_threshold,
expand_threshold, shrink_factor,
expand_factor, loss,
Expand Down Expand Up @@ -238,6 +248,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

radius_update_scheme = alg.radius_update_scheme
max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
Expand All @@ -264,7 +275,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
ReturnCode.Default, abstol, prob, radius_update_scheme, initial_trust_radius,
max_trust_radius, step_threshold, shrink_threshold,
expand_threshold, shrink_factor, expand_factor, loss,
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
Expand All @@ -290,6 +301,7 @@ function perform_step!(cache::TrustRegionCache{true})
cache.u_tmp .= u .+ cache.step_size
f(cache.fu_new, cache.u_tmp, p)

@unpack radius_update_scheme = cache
trust_region_step!(cache)
return nothing
end
Expand All @@ -312,42 +324,55 @@ function perform_step!(cache::TrustRegionCache{false})
cache.u_tmp = u .+ cache.step_size
cache.fu_new = f(cache.u_tmp, p)

@unpack radius_update_scheme = cache
trust_region_step!(cache)
return nothing
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, step_size, g, H, loss, max_trust_r = cache
@unpack fu_new, step_size, g, H, loss, max_trust_r, radius_update_scheme = cache
cache.loss_new = get_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
@unpack r = cache

# Update the trust region radius.
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
if radius_update_scheme == RadiusUpdateSchemes.Simple
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
# Update the trust region radius.
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new

# Update the trust region radius.
if r > cache.expand_threshold
cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
end
# Update the trust region radius.
if r > cache.expand_threshold
cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
end

cache.make_new_J = true
else
# No need to make a new J, no step was taken, so we try again with a smaller trust_r
cache.make_new_J = false
end

if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end

elseif radius_update_scheme == RadiusUpdateSchemes.Hei


elseif radius_update_scheme == RadiusUpdateSchemes.Yuan


elseif radius_update_scheme == RadiusUpdateSchemes.Bastin
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved

cache.make_new_J = true
else
# No need to make a new J, no step was taken, so we try again with a smaller trust_r
cache.make_new_J = false
end

if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end
end

Expand Down
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,11 @@ end
function get_loss(fu)
return norm(fu)^2 / 2
end

function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} # R-function for adaptive trust region method
if (r >= c2)
return (2 * (M - 1 - γ2) * atan(r - c2) + (1 + γ2)) / π
else
return (1 - γ1 - β) * (exp(r - c2) + β / (1 - γ1 - β))
end
end