From 146dec98d2652dc3860788a995f824dff2e621e5 Mon Sep 17 00:00:00 2001 From: FHoltorf <32248677+FHoltorf@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:19:22 -0400 Subject: [PATCH] add NLsolve trust region updating scheme and change GN step to -J\fu to avoid growing ill-conditioning --- src/trustRegion.jl | 42 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index e4d40a747..5efc01fde 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -25,6 +25,13 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows: """ Simple + """ + `RadiusUpdateSchemes.NLsolve` + + The same updating rule as in NLsolve's trust region implementation + """ + NLsolve + """ `RadiusUpdateSchemes.Hei` @@ -244,7 +251,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, p3 = convert(eltype(u), 0.0) p4 = convert(eltype(u), 0.0) ϵ = convert(eltype(u), 1.0e-8) - if radius_update_scheme === RadiusUpdateSchemes.Hei + if radius_update_scheme === RadiusUpdateSchemes.NLsolve + p1 = convert(eltype(u), 0.5) + elseif radius_update_scheme === RadiusUpdateSchemes.Hei step_threshold = convert(eltype(u), 0.0) shrink_threshold = convert(eltype(u), 0.25) expand_threshold = convert(eltype(u), 0.25) @@ -310,8 +319,9 @@ function perform_step!(cache::TrustRegionCache{true}) cache.stats.njacs += 1 end - linres = dolinsolve(alg.precs, linsolve; A = cache.H, b = _vec(cache.g), - linu = _vec(u_tmp), p, reltol = cache.abstol) + linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu), # cache.H, b = _vec(cache.g), + linu = _vec(u_tmp), + p = p, reltol = cache.abstol) cache.linsolve = linres.cache cache.u_tmp .= -1 .* u_tmp dogleg!(cache) @@ -374,7 +384,7 @@ function trust_region_step!(cache::TrustRegionCache) # Compute the ratio of the actual reduction to the predicted reduction. cache.r = -(loss - cache.loss_new) / (dot(step_size, g) + dot(step_size, H, step_size) / 2) - @unpack r = cache + @unpack r = cache if radius_update_scheme === RadiusUpdateSchemes.Simple # Update the trust region radius. @@ -403,6 +413,30 @@ function trust_region_step!(cache::TrustRegionCache) cache.force_stop = true end + elseif radius_update_scheme === RadiusUpdateSchemes.NLsolve + # accept/reject decision + if r > cache.step_threshold # accept + take_step!(cache) + cache.loss = cache.loss_new + cache.make_new_J = true + else # reject + cache.make_new_J = false + end + + # trust region update + if r < cache.shrink_threshold # default 1 // 10 + cache.trust_r *= cache.shrink_factor # default 1 // 2 + elseif r >= cache.expand_threshold # default 9 // 10 + cache.trust_r = cache.expand_factor * norm(cache.step_size) # default 2 + elseif r >= cache.p1 # default 1 // 2 + cache.trust_r = max(cache.trust_r, cache.expand_factor * norm(cache.step_size)) + end + + # convergence test + if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol + cache.force_stop = true + end + elseif radius_update_scheme === RadiusUpdateSchemes.Hei if r > cache.step_threshold take_step!(cache)