Skip to content

Commit

Permalink
add NLsolve trust region updating scheme and change GN step to -J\fu …
Browse files Browse the repository at this point in the history
…to avoid growing ill-conditioning
  • Loading branch information
FHoltorf committed Sep 26, 2023
1 parent 6f3556e commit 146dec9
Showing 1 changed file with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
"""
Simple

"""
`RadiusUpdateSchemes.NLsolve`
The same updating rule as in NLsolve's trust region implementation
"""
NLsolve

"""
`RadiusUpdateSchemes.Hei`
Expand Down Expand Up @@ -244,7 +251,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
p3 = convert(eltype(u), 0.0)
p4 = convert(eltype(u), 0.0)
ϵ = convert(eltype(u), 1.0e-8)
if radius_update_scheme === RadiusUpdateSchemes.Hei
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
p1 = convert(eltype(u), 0.5)
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
step_threshold = convert(eltype(u), 0.0)
shrink_threshold = convert(eltype(u), 0.25)
expand_threshold = convert(eltype(u), 0.25)
Expand Down Expand Up @@ -310,8 +319,9 @@ function perform_step!(cache::TrustRegionCache{true})
cache.stats.njacs += 1
end

linres = dolinsolve(alg.precs, linsolve; A = cache.H, b = _vec(cache.g),
linu = _vec(u_tmp), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu), # cache.H, b = _vec(cache.g),
linu = _vec(u_tmp),
p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
cache.u_tmp .= -1 .* u_tmp
dogleg!(cache)
Expand Down Expand Up @@ -374,7 +384,7 @@ function trust_region_step!(cache::TrustRegionCache)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (dot(step_size, g) + dot(step_size, H, step_size) / 2)
@unpack r = cache
@unpack r = cache

if radius_update_scheme === RadiusUpdateSchemes.Simple
# Update the trust region radius.
Expand Down Expand Up @@ -403,6 +413,30 @@ function trust_region_step!(cache::TrustRegionCache)
cache.force_stop = true
end

elseif radius_update_scheme === RadiusUpdateSchemes.NLsolve
# accept/reject decision
if r > cache.step_threshold # accept
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
else # reject
cache.make_new_J = false
end

# trust region update
if r < cache.shrink_threshold # default 1 // 10
cache.trust_r *= cache.shrink_factor # default 1 // 2
elseif r >= cache.expand_threshold # default 9 // 10
cache.trust_r = cache.expand_factor * norm(cache.step_size) # default 2
elseif r >= cache.p1 # default 1 // 2
cache.trust_r = max(cache.trust_r, cache.expand_factor * norm(cache.step_size))
end

# convergence test
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end

elseif radius_update_scheme === RadiusUpdateSchemes.Hei
if r > cache.step_threshold
take_step!(cache)
Expand Down

0 comments on commit 146dec9

Please sign in to comment.