Skip to content

Commit

Permalink
Merge pull request #1 from CCsimon123/TrustRegion
Browse files Browse the repository at this point in the history
TrustRegion fix
  • Loading branch information
daviehh authored Jan 23, 2023
2 parents fd8c37e + 9ce3d1e commit 8107ad0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 34 deletions.
88 changes: 55 additions & 33 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0,
step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25,
expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25,
expand_factor::Real = 2.0,
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
```
Expand Down Expand Up @@ -98,13 +98,13 @@ function TrustRegion(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0,
step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25,
expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25,
expand_factor::Real = 2.0,
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
Expand Down Expand Up @@ -141,6 +141,11 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
prob::probType
trust_r::trustType
max_trust_r::trustType
step_threshold::trustType
shrink_threshold::trustType
expand_threshold::trustType
shrink_factor::trustType
expand_factor::trustType
loss::floatType
loss_new::floatType
H::jType
Expand All @@ -158,10 +163,12 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType, trust_r::trustType,
max_trust_r::trustType, loss::floatType,
loss_new::floatType, H::jType, g::resType,
shrink_counter::Int, step_size::uType, u_tmp::uType,
fu_new::resType, make_new_J::Bool,
max_trust_r::trustType, step_threshold::trustType,
shrink_threshold::trustType, expand_threshold::trustType,
shrink_factor::trustType, expand_factor::trustType,
loss::floatType, loss_new::floatType, H::jType,
g::resType, shrink_counter::Int, step_size::uType,
u_tmp::uType, fu_new::resType, make_new_J::Bool,
r::floatType) where {iip, fType, algType, uType,
resType, pType, INType,
tolType, probType, ufType, L,
Expand All @@ -171,7 +178,10 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r, loss,
abstol, prob, trust_r, max_trust_r,
step_threshold, shrink_threshold,
expand_threshold, shrink_factor,
expand_factor, loss,
loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new,
make_new_J, r)
Expand Down Expand Up @@ -228,25 +238,37 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
shrink_threshold = convert(eltype(u), alg.shrink_threshold)
expand_threshold = convert(eltype(u), alg.expand_threshold)
shrink_factor = convert(eltype(u), alg.shrink_factor)
expand_factor = convert(eltype(u), alg.expand_factor)
# Set default trust region radius if not specified
u_elType = uType <: Number ? uType : eltype(u)
max_trust_radius = u_elType(alg.max_trust_radius)
initial_trust_radius = u_elType(alg.initial_trust_radius)
if iszero(max_trust_radius)
max_trust_radius = max(norm(fu), maximum(u) - minimum(u))
max_trust_radius = convert(eltype(u), max(norm(fu), maximum(u) - minimum(u)))
end
if iszero(initial_trust_radius)
initial_trust_radius = max_trust_radius / 11
initial_trust_radius = convert(eltype(u), max_trust_radius / 11)
end

loss_new = loss
H = ArrayInterfaceCore.undefmatrix(u)
g = zero(fu)
shrink_counter = 0
step_size = zero(u)
fu_new = zero(fu)
make_new_J = true
r = loss

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
max_trust_radius, loss, loss, H, zero(fu), 0, zero(u),
u_tmp, zero(fu), true,
loss)
max_trust_radius, step_threshold, shrink_threshold,
expand_threshold, shrink_factor, expand_factor, loss,
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
make_new_J, r)
end

function perform_step!(cache::TrustRegionCache{true})
Expand Down Expand Up @@ -295,27 +317,27 @@ function perform_step!(cache::TrustRegionCache{false})
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, u_tmp, step_size, g, H, loss, alg, max_trust_r = cache
@unpack fu_new, step_size, g, H, loss, max_trust_r = cache
cache.loss_new = get_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
@unpack r = cache

# Update the trust region radius.
if r < alg.shrink_threshold
cache.trust_r *= alg.shrink_factor
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
if r > alg.step_threshold
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new

# Update the trust region radius.
if r > alg.expand_threshold
cache.trust_r = min(alg.expand_factor * cache.trust_r, max_trust_r)
if r > cache.expand_threshold
cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
end

cache.make_new_J = true
Expand Down
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ sol = benchmark_inplace(ffiip, u0)
u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(ffiip, u0)
solver = init(probN, TrustRegion(), abstol = 1e-9)
@test (@ballocated solve!(solver)) < 120
@test (@ballocated solve!(solver)) < 200

# AD Tests
using ForwardDiff
Expand Down

0 comments on commit 8107ad0

Please sign in to comment.