Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing some new Trust region radius update schemes #159

Merged
merged 15 commits into from
Mar 21, 2023
104 changes: 93 additions & 11 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ for large-scale and numerically-difficult nonlinear systems.
Currently, the linear solver and chunk size choice only applies to in-place defined
`NonlinearProblem`s. That is expected to change in the future.
"""
struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR, RUS} <:
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::L
precs::P
radius_update_scheme::RUS
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
radius_update_scheme::RUS
radius_update_scheme::RadiusUpdateSchemes.T

It doesn't need to be parametric.

max_trust_radius::MTR
initial_trust_radius::MTR
step_threshold::MTR
Expand All @@ -94,10 +95,33 @@ struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
max_shrink_times::Int
end

struct RadiusUpdate{B}
simple::B
hei::B
yuan::B
bastin::B
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct RadiusUpdate{B}
simple::B
hei::B
yuan::B
bastin::B
struct RadiusUpdate
simple::Bool
hei::Bool
yuan::Bool
bastin::Bool

end

function RadiusUpdate(;simple::Bool = Val{true}(), # 3 different radius update schemes
hei::Bool = Val{false}(),
yuan::Bool = Val{false}(),
bastin::Bool = Val{false}())
if simple
return RadiusUpdate{Bool}(true, false, false, false)
elseif hei
return RadiusUpdate{Bool}(false, true, false, false)
elseif yuan
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't make sense. If it's going to be non-dispatching logic, use an enum from EnumX.jl instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, I am sorry; I misinterpreted the use of `Val`, but now it's clear. I have changed the overall approach.

return RadiusUpdate{Bool}(false, false, true, false)
elseif bastin
return RadiusUpdate{Bool}(false, false, false, true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return RadiusUpdate{Bool}(true, false, false, false)
elseif hei
return RadiusUpdate{Bool}(false, true, false, false)
elseif yuan
return RadiusUpdate{Bool}(false, false, true, false)
elseif bastin
return RadiusUpdate{Bool}(false, false, false, true)
return RadiusUpdate{Bool}(true, false, false, false)
elseif hei
return RadiusUpdate{Bool}(false, true, false, false)
elseif yuan
return RadiusUpdate(false, false, true, false)
elseif bastin
return RadiusUpdate(false, false, false, true)

end
end

function TrustRegion(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
radius_update_scheme = RadiusUpdate(simple = true), #defaults to conventional radius update
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
Expand All @@ -108,8 +132,8 @@ function TrustRegion(; chunk_size = Val{0}(),
max_shrink_times::Int = 32)
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(max_trust_radius)
}(linsolve, precs, max_trust_radius,
_unwrap_val(concrete_jac), typeof(max_trust_radius), typeof(radius_update_scheme)
}(linsolve, precs, radius_update_scheme, max_trust_radius,
initial_trust_radius,
step_threshold,
shrink_threshold,
Expand All @@ -120,7 +144,7 @@ function TrustRegion(; chunk_size = Val{0}(),
end

mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
INType, tolType, probType, ufType, L, jType, JC, floatType, radType,
trustType, suType, su2Type, tmpType}
f::fType
alg::algType
Expand All @@ -138,6 +162,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
retcode::SciMLBase.ReturnCode.T
abstol::tolType
prob::probType
radius_update_scheme::radType
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: this field doesn't need to be parametric either.

trust_r::trustType
max_trust_r::trustType
step_threshold::suType
Expand All @@ -161,7 +186,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
jac_config::JC, iter::Int,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType, trust_r::trustType,
prob::probType, radius_update_scheme::radType, trust_r::trustType,
max_trust_r::trustType, step_threshold::suType,
shrink_threshold::trustType, expand_threshold::trustType,
shrink_factor::trustType, expand_factor::trustType,
Expand All @@ -172,13 +197,13 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
resType, pType, INType,
tolType, probType, ufType, L,
jType, JC, floatType, trustType,
suType, su2Type, tmpType}
suType, su2Type, tmpType, radType}
new{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
radType, trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r,
abstol, prob, radius_update_scheme, trust_r, max_trust_r,
step_threshold, shrink_threshold,
expand_threshold, shrink_factor,
expand_factor, loss,
Expand Down Expand Up @@ -238,6 +263,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

radius_update_scheme = alg.radius_update_scheme
max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
Expand All @@ -264,7 +290,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
ReturnCode.Default, abstol, prob, radius_update_scheme, initial_trust_radius,
max_trust_radius, step_threshold, shrink_threshold,
expand_threshold, shrink_factor, expand_factor, loss,
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
Expand All @@ -290,7 +316,16 @@ function perform_step!(cache::TrustRegionCache{true})
cache.u_tmp .= u .+ cache.step_size
f(cache.fu_new, cache.u_tmp, p)

trust_region_step!(cache)
@unpack radius_update_scheme = cache
if radius_update_scheme.simple
trust_region_step!(cache)
elseif radius_update_scheme.hei

elseif radius_update_scheme.yuan

elseif radius_update_scheme.bastin

end
return nothing
end

Expand All @@ -312,10 +347,57 @@ function perform_step!(cache::TrustRegionCache{false})
cache.u_tmp = u .+ cache.step_size
cache.fu_new = f(cache.u_tmp, p)

trust_region_step!(cache)
@unpack radius_update_scheme = cache
if radius_update_scheme.simple
trust_region_step!(cache)
elseif radius_update_scheme.hei

elseif radius_update_scheme.yuan

elseif radius_update_scheme.bastin

end
return nothing
end

# Placeholder method for a modified trust-region radius update scheme.
#
# The body contains no executable statements, so calling this method does
# nothing to `cache` and returns `nothing`.
#
# NOTE(review): another `trust_region_step!(cache::TrustRegionCache)` method with
# the same signature is defined directly after this one. In Julia the later
# definition replaces the earlier one, so this placeholder is never the method
# that actually runs. Give the modified scheme its own name or dispatch type
# before filling it in.
function trust_region_step!(cache::TrustRegionCache)
## MODIFIED SCHEME GOES HERE

# The commented-out lines below appear to be a copy of the existing (simple)
# update scheme, kept as a starting point for the modified scheme -- TODO confirm
# against the live implementation and delete once the new scheme is written.

# @unpack fu_new, step_size, g, H, loss, max_trust_r = cache
# cache.loss_new = get_loss(fu_new)

# # Compute the ratio of the actual reduction to the predicted reduction.
# cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
# @unpack r = cache

# # Update the trust region radius.
# if r < cache.shrink_threshold
# cache.trust_r *= cache.shrink_factor
# cache.shrink_counter += 1
# else
# cache.shrink_counter = 0
# end
# if r > cache.step_threshold
# take_step!(cache)
# cache.loss = cache.loss_new

# # Update the trust region radius.
# if r > cache.expand_threshold
# cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
# end

# cache.make_new_J = true
# else
# # No need to make a new J, no step was taken, so we try again with a smaller trust_r
# cache.make_new_J = false
# end

# if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
# cache.force_stop = true
# end
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, step_size, g, H, loss, max_trust_r = cache
cache.loss_new = get_loss(fu_new)
Expand Down