From fd8c37e573148fd10a06132f47675c44cd754814 Mon Sep 17 00:00:00 2001 From: daviehh <25255906+daviehh@users.noreply.github.com> Date: Sun, 22 Jan 2023 22:19:00 -0500 Subject: [PATCH 1/8] u eltype cast for TrustRegion --- src/trustRegion.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 0c19dfaf5..127785515 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -229,14 +229,16 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip)) # Set default trust region radius if not specified - max_trust_radius = alg.max_trust_radius - initial_trust_radius = alg.initial_trust_radius - if max_trust_radius == 0.0 + u_elType = uType <: Number ? uType : eltype(u) + max_trust_radius = u_elType(alg.max_trust_radius) + initial_trust_radius = u_elType(alg.initial_trust_radius) + if iszero(max_trust_radius) max_trust_radius = max(norm(fu), maximum(u) - minimum(u)) end - if initial_trust_radius == 0.0 + if iszero(initial_trust_radius) initial_trust_radius = max_trust_radius / 11 end + H = ArrayInterfaceCore.undefmatrix(u) return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config, From fd96bb077045ebbfe37fddef57410a8931f7e600 Mon Sep 17 00:00:00 2001 From: Simon Carlson Date: Mon, 23 Jan 2023 10:48:59 +0100 Subject: [PATCH 2/8] fixing TrustRegion --- src/trustRegion.jl | 74 ++++++++++++++++++++++++++++++---------------- test/basictests.jl | 2 +- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 127785515..b20979ff9 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -98,13 +98,13 @@ function TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS, - max_trust_radius::Real = 0.0, - initial_trust_radius::Real = 0.0, - step_threshold::Real = 0.1, - shrink_threshold::Real = 0.25, - expand_threshold::Real = 0.75, - shrink_factor::Real = 0.25, - expand_factor::Real = 2.0, + max_trust_radius::Real = 0 // 1, + initial_trust_radius::Real = 0 // 1, + step_threshold::Real = 1 // 10, + shrink_threshold::Real = 1 // 4, + expand_threshold::Real = 3 // 4, + shrink_factor::Real = 1 // 4, + expand_factor::Real = 2 // 1, max_shrink_times::Int = 32) TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type, typeof(linsolve), typeof(precs), _unwrap_val(standardtag), @@ -141,6 +141,11 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, prob::probType trust_r::trustType max_trust_r::trustType + step_threshold::trustType + shrink_threshold::trustType + expand_threshold::trustType + shrink_factor::trustType + expand_factor::trustType loss::floatType loss_new::floatType H::jType @@ -158,10 +163,12 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, force_stop::Bool, maxiters::Int, internalnorm::INType, retcode::SciMLBase.ReturnCode.T, abstol::tolType, prob::probType, trust_r::trustType, - max_trust_r::trustType, loss::floatType, - loss_new::floatType, H::jType, g::resType, - shrink_counter::Int, step_size::uType, u_tmp::uType, - fu_new::resType, make_new_J::Bool, + max_trust_r::trustType, step_threshold::trustType, + shrink_threshold::trustType, expand_threshold::trustType, + shrink_factor::trustType, expand_factor::trustType, + loss::floatType, loss_new::floatType, H::jType, + g::resType, shrink_counter::Int, step_size::uType, + u_tmp::uType, fu_new::resType, make_new_J::Bool, r::floatType) where {iip, fType, algType, uType, resType, pType, INType, tolType, probType, ufType, L, @@ -171,7 +178,10 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, }(f, alg, u, fu, p, uf, linsolve, J, jac_config, iter, force_stop, maxiters, internalnorm, retcode, - abstol, prob, trust_r, max_trust_r, loss, + abstol, prob, trust_r, max_trust_r, + step_threshold, shrink_threshold, + expand_threshold, shrink_factor, + expand_factor, loss, loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new, make_new_J, r) @@ -228,25 +238,37 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, loss = get_loss(fu) uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip)) + max_trust_radius = convert(eltype(u), alg.max_trust_radius) + initial_trust_radius = convert(eltype(u), alg.initial_trust_radius) + step_threshold = convert(eltype(u), alg.step_threshold) + shrink_threshold = convert(eltype(u), alg.shrink_threshold) + expand_threshold = convert(eltype(u), alg.expand_threshold) + shrink_factor = convert(eltype(u), alg.shrink_factor) + expand_factor = convert(eltype(u), alg.expand_factor) # Set default trust region radius if not specified - u_elType = uType <: Number ? uType : eltype(u) - max_trust_radius = u_elType(alg.max_trust_radius) - initial_trust_radius = u_elType(alg.initial_trust_radius) if iszero(max_trust_radius) - max_trust_radius = max(norm(fu), maximum(u) - minimum(u)) + max_trust_radius = convert(eltype(u), max(norm(fu), maximum(u) - minimum(u))) end if iszero(initial_trust_radius) - initial_trust_radius = max_trust_radius / 11 + initial_trust_radius = convert(eltype(u), max_trust_radius / 11) end + loss_new = loss H = ArrayInterfaceCore.undefmatrix(u) + g = zero(fu) + shrink_counter = 0 + step_size = zero(u) + fu_new = zero(fu) + make_new_J = true + r = loss return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config, 1, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, initial_trust_radius, - max_trust_radius, loss, loss, H, zero(fu), 0, zero(u), - u_tmp, zero(fu), true, - loss) + max_trust_radius, step_threshold, shrink_threshold, + expand_threshold, shrink_factor, expand_factor, loss, + loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new, + make_new_J, r) end function perform_step!(cache::TrustRegionCache{true}) @@ -295,7 +317,7 @@ function perform_step!(cache::TrustRegionCache{false}) end function trust_region_step!(cache::TrustRegionCache) - @unpack fu_new, u_tmp, step_size, g, H, loss, alg, max_trust_r = cache + @unpack fu_new, step_size, g, H, loss, max_trust_r = cache cache.loss_new = get_loss(fu_new) # Compute the ratio of the actual reduction to the predicted reduction. @@ -303,19 +325,19 @@ function trust_region_step!(cache::TrustRegionCache) @unpack r = cache # Update the trust region radius. - if r < alg.shrink_threshold - cache.trust_r *= alg.shrink_factor + if r < cache.shrink_threshold + cache.trust_r *= cache.shrink_factor cache.shrink_counter += 1 else cache.shrink_counter = 0 end - if r > alg.step_threshold + if r > cache.step_threshold take_step!(cache) cache.loss = cache.loss_new # Update the trust region radius. - if r > alg.expand_threshold - cache.trust_r = min(alg.expand_factor * cache.trust_r, max_trust_r) + if r > cache.expand_threshold + cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r) end cache.make_new_J = true diff --git a/test/basictests.jl b/test/basictests.jl index ed2907fdb..d0e959b6b 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -186,7 +186,7 @@ sol = benchmark_inplace(ffiip, u0) u0 = [1.0, 1.0] probN = NonlinearProblem{true}(ffiip, u0) solver = init(probN, TrustRegion(), abstol = 1e-9) -@test (@ballocated solve!(solver)) < 120 +@test (@ballocated solve!(solver)) < 200 # AD Tests using ForwardDiff From 9ce3d1e511282bdba3dc86fc3df6765650bd2736 Mon Sep 17 00:00:00 2001 From: Simon Carlson Date: Mon, 23 Jan 2023 15:36:29 +0100 Subject: [PATCH 3/8] fixing docs --- src/trustRegion.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index b20979ff9..a35be7807 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -3,13 +3,13 @@ TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS, - max_trust_radius::Real = 0.0, - initial_trust_radius::Real = 0.0, - step_threshold::Real = 0.1, - shrink_threshold::Real = 0.25, - expand_threshold::Real = 0.75, - shrink_factor::Real = 0.25, - expand_factor::Real = 2.0, + max_trust_radius::Real = 0 // 1, + initial_trust_radius::Real = 0 // 1, + step_threshold::Real = 1 // 10, + shrink_threshold::Real = 1 // 4, + expand_threshold::Real = 3 // 4, + shrink_factor::Real = 1 // 4, + expand_factor::Real = 2 // 1, max_shrink_times::Int = 32) ``` From 8adda3894877a38ee3d256ad41899be428141758 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 24 Jan 2023 06:20:14 -0500 Subject: [PATCH 4/8] Fix merge error --- src/trustRegion.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index ed118719d..fe5ae23bc 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -174,8 +174,8 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, jType, JC, floatType, trustType, suType, tmpType} new{iip, fType, algType, uType, resType, pType, - INType, tolType, probType, ufType, L, jType, JC, floatType, trustType - }(f, alg, u, fu, p, uf, linsolve, J, + INType, tolType, probType, ufType, L, jType, JC, floatType, trustType, + suType, tmpType}(f, alg, u, fu, p, uf, linsolve, J, jac_config, iter, force_stop, maxiters, internalnorm, retcode, abstol, prob, trust_r, max_trust_r, From 3f66a32a7c40c56d9bce0c256f3b9a3b0861027e Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 24 Jan 2023 06:27:18 -0500 Subject: [PATCH 5/8] Update trustRegion.jl --- src/trustRegion.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index fe5ae23bc..95bf78aa7 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -162,12 +162,12 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, force_stop::Bool, maxiters::Int, internalnorm::INType, retcode::SciMLBase.ReturnCode.T, abstol::tolType, prob::probType, trust_r::trustType, - max_trust_r::trustType, step_threshold::trustType, + max_trust_r::trustType, step_threshold::suType, shrink_threshold::trustType, expand_threshold::trustType, shrink_factor::trustType, expand_factor::trustType, loss::floatType, loss_new::floatType, H::jType, g::resType, shrink_counter::Int, step_size::uType, - u_tmp::uType, fu_new::resType, make_new_J::Bool, + u_tmp::tmpType, fu_new::resType, make_new_J::Bool, r::floatType) where {iip, fType, algType, uType, resType, pType, INType, tolType, probType, ufType, L, From 30d8944af6d11f3b098374bb03dc4c996c0bf806 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 24 Jan 2023 15:41:41 -0500 Subject: [PATCH 6/8] Update src/trustRegion.jl --- src/trustRegion.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 95bf78aa7..aaeaf7d34 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -140,7 +140,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, prob::probType trust_r::trustType max_trust_r::trustType - step_threshold::trustType + step_threshold::suType shrink_threshold::trustType expand_threshold::trustType shrink_factor::trustType From 6947f6a91f1a6d527a7fce2a7ebc054aa5d39fe7 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 25 Jan 2023 07:05:17 -0500 Subject: [PATCH 7/8] split types --- src/trustRegion.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 95bf78aa7..90a19712e 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -121,7 +121,7 @@ end mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, INType, tolType, probType, ufType, L, jType, JC, floatType, - trustType, suType, tmpType} + trustType, suType, su2Type, tmpType} f::fType alg::algType u::uType @@ -140,7 +140,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, prob::probType trust_r::trustType max_trust_r::trustType - step_threshold::trustType + step_threshold::suType shrink_threshold::trustType expand_threshold::trustType shrink_factor::trustType @@ -150,7 +150,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, H::jType g::resType shrink_counter::Int - step_size::suType + step_size::su2Type u_tmp::tmpType fu_new::resType make_new_J::Bool @@ -166,16 +166,16 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, shrink_threshold::trustType, expand_threshold::trustType, shrink_factor::trustType, expand_factor::trustType, loss::floatType, loss_new::floatType, H::jType, - g::resType, shrink_counter::Int, step_size::uType, + g::resType, shrink_counter::Int, step_size::su2Type, u_tmp::tmpType, fu_new::resType, make_new_J::Bool, r::floatType) where {iip, fType, algType, uType, resType, pType, INType, tolType, probType, ufType, L, jType, JC, floatType, trustType, - suType, tmpType} + suType, su2Type, tmpType} new{iip, fType, algType, uType, resType, pType, - INType, tolType, probType, ufType, L, jType, JC, floatType, trustType, - suType, tmpType}(f, alg, u, fu, p, uf, linsolve, J, + INType, tolType, probType, ufType, L, jType, JC, floatType, + trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J, jac_config, iter, force_stop, maxiters, internalnorm, retcode, abstol, prob, trust_r, max_trust_r, From e438e125b5786755574df3fcd6dac83f1b5a1597 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 25 Jan 2023 10:09:10 -0500 Subject: [PATCH 8/8] format --- src/trustRegion.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 90a19712e..b0b9640ca 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -174,17 +174,17 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, jType, JC, floatType, trustType, suType, su2Type, tmpType} new{iip, fType, algType, uType, resType, pType, - INType, tolType, probType, ufType, L, jType, JC, floatType, - trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J, - jac_config, iter, force_stop, - maxiters, internalnorm, retcode, - abstol, prob, trust_r, max_trust_r, - step_threshold, shrink_threshold, - expand_threshold, shrink_factor, - expand_factor, loss, - loss_new, H, g, shrink_counter, - step_size, u_tmp, fu_new, - make_new_J, r) + INType, tolType, probType, ufType, L, jType, JC, floatType, + trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J, + jac_config, iter, force_stop, + maxiters, internalnorm, retcode, + abstol, prob, trust_r, max_trust_r, + step_threshold, shrink_threshold, + expand_threshold, shrink_factor, + expand_factor, loss, + loss_new, H, g, shrink_counter, + step_size, u_tmp, fu_new, + make_new_J, r) end end