From 0f9858ca34f06218070c2a07c553b0716c3aa566 Mon Sep 17 00:00:00 2001 From: David AW Barton Date: Wed, 29 Mar 2023 16:51:29 +0100 Subject: [PATCH 1/3] Fix typo in default argument --- src/raphson.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/raphson.jl b/src/raphson.jl index bfec2830a..08a156190 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -199,7 +199,7 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache) retcode = cache.retcode) end -function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u0; p = cache.p, +function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p, abstol = cache.abstol, maxiters = cache.maxiters) where {iip} cache.p = p if iip From c414a8fcfdd7b57c2a283dc9e2946d2f308e60bd Mon Sep 17 00:00:00 2001 From: David AW Barton Date: Wed, 29 Mar 2023 17:02:04 +0100 Subject: [PATCH 2/3] Added reinit! for TrustRegionCache --- src/trustRegion.jl | 36 +++++++++++++++++++++++++++++++----- test/basictests.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index fc1662fc0..4f80ae783 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -363,7 +363,7 @@ function trust_region_step!(cache::TrustRegionCache) cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2) @unpack r = cache - if radius_update_scheme === RadiusUpdateSchemes.Simple + if radius_update_scheme === RadiusUpdateSchemes.Simple # Update the trust region radius. if r < cache.shrink_threshold cache.trust_r *= cache.shrink_factor @@ -389,13 +389,13 @@ function trust_region_step!(cache::TrustRegionCache) if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol cache.force_stop = true end - + elseif radius_update_scheme === RadiusUpdateSchemes.Hei - if r > cache.step_threshold + if r > cache.step_threshold take_step!(cache) cache.loss = cache.loss_new cache.make_new_J = true - else + else cache.make_new_J = false end # Hei's radius update scheme @@ -427,7 +427,7 @@ function trust_region_step!(cache::TrustRegionCache) else cache.make_new_J = false end - + if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined cache.force_stop = true end @@ -489,3 +489,29 @@ function SciMLBase.solve!(cache::TrustRegionCache) SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu; retcode = cache.retcode) end + +function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p, + abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + cache.p = p + if iip + recursivecopy!(cache.u, u0) + cache.f(cache.fu, cache.u, p) + else + # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter + cache.u = u0 + cache.fu = cache.f(cache.u, p) + end + cache.abstol = abstol + cache.maxiters = maxiters + cache.iter = 1 + cache.force_stop = false + cache.retcode = ReturnCode.Default + cache.make_new_J = true + cache.loss = get_loss(cache.fu) + cache.shrink_counter = 0 + cache.trust_r = convert(eltype(cache.u), cache.alg.initial_trust_radius) + if iszero(cache.trust_r) + cache.trust_r = convert(eltype(cache.u), cache.max_trust_r / 11) + end + return cache +end diff --git a/test/basictests.jl b/test/basictests.jl index ebaa80916..b6726798e 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -263,6 +263,37 @@ end @test gnewton(p) ≈ [sqrt(p[2] / p[1])] @test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) +# Iterator interface +f = (u, p) -> u * u - p +g = function (p_range) + probN = NonlinearProblem{false}(f, 0.5, p_range[begin]) + cache = init(probN, TrustRegion(); maxiters = 100, abstol=1e-10) + sols = zeros(length(p_range)) + for (i, p) in enumerate(p_range) + reinit!(cache, cache.u; p = p) + sol = solve!(cache) + sols[i] = sol.u + end + return sols +end +p = range(0.01, 2, length = 200) +@test g(p) ≈ sqrt.(p) + +f = (res, u, p) -> (res[begin] = u[1] * u[1] - p) +g = function (p_range) + probN = NonlinearProblem{true}(f, [0.5], p_range[begin]) + cache = init(probN, TrustRegion(); maxiters = 100, abstol=1e-10) + sols = zeros(length(p_range)) + for (i, p) in enumerate(p_range) + reinit!(cache, [cache.u[1]]; p = p) + sol = solve!(cache) + sols[i] = sol.u[1] + end + return sols +end +p = range(0.01, 2, length = 200) +@test g(p) ≈ sqrt.(p) + # Error Checks f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0] probN = NonlinearProblem(f, u0) From 160dd8454d4170221b15dab3103261a0bb7e3cca Mon Sep 17 00:00:00 2001 From: David AW Barton Date: Wed, 29 Mar 2023 17:44:30 +0100 Subject: [PATCH 3/3] Revert accidental whitespace changes --- src/trustRegion.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 4f80ae783..b84776dc9 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -363,7 +363,7 @@ function trust_region_step!(cache::TrustRegionCache) cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2) @unpack r = cache - if radius_update_scheme === RadiusUpdateSchemes.Simple + if radius_update_scheme === RadiusUpdateSchemes.Simple # Update the trust region radius. if r < cache.shrink_threshold cache.trust_r *= cache.shrink_factor @@ -389,13 +389,13 @@ function trust_region_step!(cache::TrustRegionCache) if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol cache.force_stop = true end - + elseif radius_update_scheme === RadiusUpdateSchemes.Hei - if r > cache.step_threshold + if r > cache.step_threshold take_step!(cache) cache.loss = cache.loss_new cache.make_new_J = true - else + else cache.make_new_J = false end # Hei's radius update scheme @@ -427,7 +427,7 @@ function trust_region_step!(cache::TrustRegionCache) else cache.make_new_J = false end - + if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined cache.force_stop = true end