Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing some new Trust region radius update schemes #159

Merged
merged 15 commits into from
Mar 21, 2023
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "1.5.0"
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
3 changes: 3 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using ForwardDiff: Dual
using LinearAlgebra
using StaticArraysCore
using RecursiveArrayTools
import EnumX
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing from the Project.toml

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I added the package to the environment but haven't committed it yet. I will do that in the next commit.

import ArrayInterface
import LinearSolve
using DiffEqBase
Expand Down Expand Up @@ -59,6 +60,8 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
end
end end

export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt

end # module
141 changes: 113 additions & 28 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,18 @@ for large-scale and numerically-difficult nonlinear systems.
Currently, the linear solver and chunk size choice only applies to in-place defined
`NonlinearProblem`s. That is expected to change in the future.
"""
# Strategies for updating the trust-region radius, selected through the
# `radius_update_scheme` keyword of `TrustRegion`:
#   Simple - shrink/expand the radius by the fixed shrink/expand factors (the default).
#   Hei    - set the radius from the R-function `rfunc` times the norm of the step.
#   Yuan   - rescale the parameter μ (stored as `p1`) depending on the reduction ratio.
#   Bastin - declared only; no update branch appears to be implemented for it yet.
# NOTE(review): a docstring closes immediately above this declaration, so it presumably
# attaches to this enum rather than to `TrustRegion` - confirm and reorder if needed.
EnumX.@enumx RadiusUpdateSchemes Simple Hei Yuan Bastin

struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::L
precs::P
radius_update_scheme::RadiusUpdateSchemes.T
max_trust_radius::MTR
initial_trust_radius::MTR
step_threshold::MTR
Expand All @@ -98,6 +106,7 @@ function TrustRegion(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple, #defaults to conventional radius update
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
Expand All @@ -109,7 +118,7 @@ function TrustRegion(; chunk_size = Val{0}(),
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(max_trust_radius)
}(linsolve, precs, max_trust_radius,
}(linsolve, precs, radius_update_scheme, max_trust_radius,
initial_trust_radius,
step_threshold,
shrink_threshold,
Expand Down Expand Up @@ -138,6 +147,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
retcode::SciMLBase.ReturnCode.T
abstol::tolType
prob::probType
radius_update_scheme::RadiusUpdateSchemes.T
trust_r::trustType
max_trust_r::trustType
step_threshold::suType
Expand All @@ -155,20 +165,26 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
fu_new::resType
make_new_J::Bool
r::floatType
p1::floatType
p2::floatType
p3::floatType
p4::floatType
ϵ::floatType

function TrustRegionCache{iip}(f::fType, alg::algType, u::uType, fu::resType, p::pType,
uf::ufType, linsolve::L, J::jType,
jac_config::JC, iter::Int,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType, trust_r::trustType,
prob::probType, radius_update_scheme::RadiusUpdateSchemes.T, trust_r::trustType,
max_trust_r::trustType, step_threshold::suType,
shrink_threshold::trustType, expand_threshold::trustType,
shrink_factor::trustType, expand_factor::trustType,
loss::floatType, loss_new::floatType, H::jType,
g::resType, shrink_counter::Int, step_size::su2Type,
u_tmp::tmpType, fu_new::resType, make_new_J::Bool,
r::floatType) where {iip, fType, algType, uType,
r::floatType, p1::floatType, p2::floatType, p3::floatType,
p4::floatType, ϵ::floatType) where {iip, fType, algType, uType,
resType, pType, INType,
tolType, probType, ufType, L,
jType, JC, floatType, trustType,
Expand All @@ -178,13 +194,13 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r,
abstol, prob, radius_update_scheme, trust_r, max_trust_r,
step_threshold, shrink_threshold,
expand_threshold, shrink_factor,
expand_factor, loss,
loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new,
make_new_J, r)
make_new_J, r, p1, p2, p3, p4, ϵ)
end
end

Expand Down Expand Up @@ -238,6 +254,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

radius_update_scheme = alg.radius_update_scheme
max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
Expand All @@ -262,13 +279,37 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
make_new_J = true
r = loss

# Parameters for the Schemes
p1 = convert(eltype(u), 0.0)
p2 = convert(eltype(u), 0.0)
p3 = convert(eltype(u), 0.0)
p4 = convert(eltype(u), 0.0)
ϵ = convert(eltype(u), 1.0e-8)
if radius_update_scheme === RadiusUpdateSchemes.Hei
step_threshold = convert(eltype(u), 0.0)
shrink_threshold = convert(eltype(u), 0.25)
expand_threshold = convert(eltype(u), 0.25)
p1 = convert(eltype(u), 5.0) # M
p2 = convert(eltype(u), 0.1) # β
p3 = convert(eltype(u), 0.15) # γ1
p4 = convert(eltype(u), 0.15) # γ2
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
step_threshold = convert(eltype(u), 0.0001)
shrink_threshold = convert(eltype(u), 0.25)
expand_threshold = convert(eltype(u), 0.25)
p1 = convert(eltype(u), 2.0) # μ
p2 = convert(eltype(u), 1/6) # c5
p3 = convert(eltype(u), 6.0) # c6
p4 = convert(eltype(u), 0.0)
end

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
ReturnCode.Default, abstol, prob, radius_update_scheme, initial_trust_radius,
max_trust_radius, step_threshold, shrink_threshold,
expand_threshold, shrink_factor, expand_factor, loss,
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
make_new_J, r)
make_new_J, r, p1, p2, p3, p4, ϵ)
end

function perform_step!(cache::TrustRegionCache{true})
Expand All @@ -289,7 +330,6 @@ function perform_step!(cache::TrustRegionCache{true})
# Compute the potentially new u
cache.u_tmp .= u .+ cache.step_size
f(cache.fu_new, cache.u_tmp, p)

trust_region_step!(cache)
return nothing
end
Expand All @@ -311,43 +351,88 @@ function perform_step!(cache::TrustRegionCache{false})
# Compute the potentially new u
cache.u_tmp = u .+ cache.step_size
cache.fu_new = f(cache.u_tmp, p)

trust_region_step!(cache)
return nothing
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, step_size, g, H, loss, max_trust_r = cache
@unpack fu_new, step_size, g, H, loss, max_trust_r, radius_update_scheme = cache
cache.loss_new = get_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
@unpack r = cache

# Update the trust region radius.
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
if radius_update_scheme === RadiusUpdateSchemes.Simple
# Update the trust region radius.
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new

# Update the trust region radius.
if r > cache.expand_threshold
cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
end

cache.make_new_J = true
else
# No need to make a new J, no step was taken, so we try again with a smaller trust_r
cache.make_new_J = false
end

if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end

elseif radius_update_scheme === RadiusUpdateSchemes.Hei
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throw an error for schemes that are not implemented.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, like a TypeError when the radius_update_scheme type doesn't match RadiusUpdateSchemes.T ?

if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
else
cache.make_new_J = false
end
# Hei's radius update scheme
@unpack shrink_threshold, p1, p2, p3, p4 = cache
if rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) < cache.trust_r
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
end
cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) # parameters to be defined

# Update the trust region radius.
if r > cache.expand_threshold
cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
end
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined
cache.force_stop = true
end


elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
if r < cache.shrink_threshold
cache.p1 = cache.p2 * cache.p1
cache.shrink_counter += 1
elseif r >= cache.expand_threshold && cache.internalnorm(step_size) > cache.trust_r / 2
cache.p1 = cache.p3 * cache.p1
end
@unpack p1, fu, f, J = cache
#cache.trust_r = p1 * cache.internalnorm(jacobian!(J, cache) * fu) # we need the gradient at the new (k+1)th point WILL THIS BECOME ALLOCATING?

if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
else
# No need to make a new J, no step was taken, so we try again with a smaller trust_r
else
cache.make_new_J = false
end

if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
end
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined
cache.force_stop = true
end

#elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
end
end

Expand Down
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,11 @@ end
"""
    get_loss(fu)

Return the loss associated with the residual `fu`, computed as `norm(fu)^2 / 2`.
"""
function get_loss(fu)
    residual_norm = norm(fu)
    return residual_norm^2 / 2
end

"""
    rfunc(r, c2, M, γ1, γ2, β)

R-function used by the adaptive (Hei) trust-region radius update. Returns the factor by
which the norm of the last step is multiplied to obtain the new trust-region radius.

# Arguments
- `r`: ratio of actual to predicted reduction.
- `c2`: threshold on `r` separating the shrinking branch from the expanding branch.
- `M`: limit of the returned factor as `r → +∞`.
- `γ1`: the factor stays below `1 - γ1` for `r < c2`.
- `γ2`: the factor equals `1 + γ2` at `r == c2`.
- `β`: limit of the returned factor as `r → -∞`.

All arguments must share the same real type `R`.
"""
function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real}
    if r >= c2
        # Only the arctangent term is scaled by 2/π; the offset (1 + γ2) is not.
        # This gives 1 + γ2 at r == c2 and tends to M as r grows, so the factor never
        # drops below the r < c2 branch and the function stays non-decreasing.
        return 2 * (M - 1 - γ2) * atan(r - c2) / π + (1 + γ2)
    else
        # Equals (1 - γ1 - β) * exp(r - c2) + β: tends to β as r → -∞ and to 1 - γ1
        # as r approaches c2 from below.
        return (1 - γ1 - β) * (exp(r - c2) + β / (1 - γ1 - β))
    end
end