Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing some new Trust region radius update schemes #159

Merged
merged 15 commits into from
Mar 21, 2023
180 changes: 168 additions & 12 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,25 @@ for large-scale and numerically-difficult nonlinear systems.
Currently, the linear solver and chunk size choice only applies to in-place defined
`NonlinearProblem`s. That is expected to change in the future.
"""
struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
# Closed set of trust-region radius update schemes.
# `perform_step!` converts a member with `Int(...)` and wraps it in `Val` to pick the
# matching `trust_region_step!` method.
# NOTE(review): members are presumably numbered 0, 1, 2, 3 in declaration order (the
# enum default), which is what the `Val{0}` … `Val{3}` methods rely on — confirm.
EnumX.@enumx RadiusUpdateSchemes begin
# Conventional radius update (handled by the `Val{0}` method).
Simple
# Hei's radius update scheme (handled by the `Val{1}` method).
Hei
# Yuan's radius update scheme (handled by the `Val{2}` method).
Yuan
# Bastin's radius update scheme (handled by the `Val{3}` method).
Bastin
end

struct RadiusUpdate
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't they mutually exclusive? We should just use an enum.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the structs can be deleted.

simple::Bool
hei::Bool
yuan::Bool
bastin::Bool
end

struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR, RUS} <:
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::L
precs::P
radius_update_scheme::RUS
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
radius_update_scheme::RUS
radius_update_scheme::RadiusUpdateSchemes.T

It doesn't need to be parametric.

max_trust_radius::MTR
initial_trust_radius::MTR
step_threshold::MTR
Expand All @@ -94,10 +109,25 @@ struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
max_shrink_times::Int
end

"""
    RadiusUpdate(; hei = false, yuan = false, bastin = false)

Build a `RadiusUpdate` in which exactly one of the four flags is `true`.

If no keyword is `true`, the `simple` flag is set. Otherwise the first `true` keyword
in the order `hei`, `yuan`, `bastin` wins and any later ones are ignored.
"""
function RadiusUpdate(; hei::Bool = false,
                      yuan::Bool = false,
                      bastin::Bool = false)
    # Resolve the flags with the precedence hei > yuan > bastin.
    pick_hei = hei
    pick_yuan = !hei && yuan
    pick_bastin = !hei && !yuan && bastin
    # `simple` is the fallback used only when nothing was requested.
    pick_simple = !hei && !yuan && !bastin

    # Positional field order of the struct: simple, hei, yuan, bastin.
    return RadiusUpdate(pick_simple, pick_hei, pick_yuan, pick_bastin)
end

function TrustRegion(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
radius_update_scheme = RadiusUpdate(), #defaults to conventional radius update
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
Expand All @@ -108,8 +138,8 @@ function TrustRegion(; chunk_size = Val{0}(),
max_shrink_times::Int = 32)
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(max_trust_radius)
}(linsolve, precs, max_trust_radius,
_unwrap_val(concrete_jac), typeof(max_trust_radius), typeof(radius_update_scheme)
}(linsolve, precs, radius_update_scheme, max_trust_radius,
initial_trust_radius,
step_threshold,
shrink_threshold,
Expand All @@ -120,7 +150,7 @@ function TrustRegion(; chunk_size = Val{0}(),
end

mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
INType, tolType, probType, ufType, L, jType, JC, floatType, radType,
trustType, suType, su2Type, tmpType}
f::fType
alg::algType
Expand All @@ -138,6 +168,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
retcode::SciMLBase.ReturnCode.T
abstol::tolType
prob::probType
radius_update_scheme::radType
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too

trust_r::trustType
max_trust_r::trustType
step_threshold::suType
Expand All @@ -161,7 +192,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
jac_config::JC, iter::Int,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType, trust_r::trustType,
prob::probType, radius_update_scheme::radType, trust_r::trustType,
max_trust_r::trustType, step_threshold::suType,
shrink_threshold::trustType, expand_threshold::trustType,
shrink_factor::trustType, expand_factor::trustType,
Expand All @@ -172,13 +203,13 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
resType, pType, INType,
tolType, probType, ufType, L,
jType, JC, floatType, trustType,
suType, su2Type, tmpType}
suType, su2Type, tmpType, radType}
new{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
radType, trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r,
abstol, prob, radius_update_scheme, trust_r, max_trust_r,
step_threshold, shrink_threshold,
expand_threshold, shrink_factor,
expand_factor, loss,
Expand Down Expand Up @@ -238,6 +269,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

radius_update_scheme = alg.radius_update_scheme
max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
Expand All @@ -264,7 +296,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
ReturnCode.Default, abstol, prob, radius_update_scheme, initial_trust_radius,
max_trust_radius, step_threshold, shrink_threshold,
expand_threshold, shrink_factor, expand_factor, loss,
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
Expand All @@ -290,7 +322,16 @@ function perform_step!(cache::TrustRegionCache{true})
cache.u_tmp .= u .+ cache.step_size
f(cache.fu_new, cache.u_tmp, p)

trust_region_step!(cache)
@unpack radius_update_scheme = cache
if radius_update_scheme.simple
trust_region_step!(cache, Val(Int(RadiusUpdateSchemes.Simple)))
elseif radius_update_scheme.hei
trust_region_step!(cache, Val(Int(RadiusUpdateSchemes.Hei)))
elseif radius_update_scheme.yuan
trust_region_step!(cache, Val(Int(RadiusUpdateSchemes.Yuan)))
elseif radius_update_scheme.bastin
trust_region_step!(cache, Val(Int(RadiusUpdateSchemes.Bastin)))
end
return nothing
end

Expand All @@ -312,11 +353,20 @@ function perform_step!(cache::TrustRegionCache{false})
cache.u_tmp = u .+ cache.step_size
cache.fu_new = f(cache.u_tmp, p)

trust_region_step!(cache)
@unpack radius_update_scheme = cache
if radius_update_scheme.simple
trust_region_step!(cache, Val(Int(RadiusUpdateSchemes.Simple)))
elseif radius_update_scheme.hei
trust_region_step!(cache, Val(Int(RadiusUpdateSchemes.Hei)))
elseif radius_update_scheme.yuan
trust_region_step!(cache, Val(Int(RadiusUpdateSchemes.Yuan)))
elseif radius_update_scheme.bastin
trust_region_step!(cache, Val(Int(RadiusUpdateSchemes.Bastin)))
end
return nothing
end

function trust_region_step!(cache::TrustRegionCache)
function trust_region_step!(cache::TrustRegionCache, ::Val{0}) # conventional radius update scheme
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make a single dispatch and branch

Copy link
Member Author

@yash2798 yash2798 Mar 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function trust_region_step!(cache::TrustRegionCache, ::Val{0}) # conventional radius update scheme
function trust_region_step!(cache::TrustRegionCache)
@unpack radius_update_scheme = cache
if radius_update_scheme == RadiusUpdateSchemes.Simple
#method 1
elseif radius_update_scheme == RadiusUpdateSchemes.Hei
#method 2
...
end
end

Do you mean like this, in a single method?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Also, if I need some helper functions — for example, Hei's scheme uses a monotonic function that depends on the ratio of actual to predicted change — can those helpers go in utils.jl?

@unpack fu_new, step_size, g, H, loss, max_trust_r = cache
cache.loss_new = get_loss(fu_new)

Expand Down Expand Up @@ -351,6 +401,112 @@ function trust_region_step!(cache::TrustRegionCache)
end
end

# Hei's radius update scheme, selected by dispatching on `Val{1}`.
#
# Placeholder: the scheme is not implemented yet. The method executes no statements,
# so `cache` is left untouched and `nothing` is returned.
# TODO: implement the update. The earlier commented-out draft here was a ratio-based
# shrink/expand radius update (actual vs. predicted reduction) to be adapted.
function trust_region_step!(cache::TrustRegionCache, ::Val{1})
    return nothing
end

# Yuan's radius update scheme, selected by dispatching on `Val{2}`.
#
# Placeholder: the scheme is not implemented yet. The method executes no statements,
# so `cache` is left untouched and `nothing` is returned.
# TODO: implement the update. The earlier commented-out draft here was a ratio-based
# shrink/expand radius update (actual vs. predicted reduction) to be adapted.
function trust_region_step!(cache::TrustRegionCache, ::Val{2})
    return nothing
end

# Bastin's radius update scheme, selected by dispatching on `Val{3}`.
#
# Placeholder: the scheme is not implemented yet. The method executes no statements,
# so `cache` is left untouched and `nothing` is returned.
# TODO: implement the update. The earlier commented-out draft here was a ratio-based
# shrink/expand radius update (actual vs. predicted reduction) to be adapted.
function trust_region_step!(cache::TrustRegionCache, ::Val{3})
    return nothing
end

function dogleg!(cache::TrustRegionCache)
@unpack u_tmp, trust_r = cache

Expand Down