diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 153644121..b0b9640ca 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -3,13 +3,13 @@ TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS, - max_trust_radius::Real = 0.0, - initial_trust_radius::Real = 0.0, - step_threshold::Real = 0.1, - shrink_threshold::Real = 0.25, - expand_threshold::Real = 0.75, - shrink_factor::Real = 0.25, - expand_factor::Real = 2.0, + max_trust_radius::Real = 0 // 1, + initial_trust_radius::Real = 0 // 1, + step_threshold::Real = 1 // 10, + shrink_threshold::Real = 1 // 4, + expand_threshold::Real = 3 // 4, + shrink_factor::Real = 1 // 4, + expand_factor::Real = 2 // 1, max_shrink_times::Int = 32) ``` @@ -98,13 +98,13 @@ function TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS, - max_trust_radius::Real = 0.0, - initial_trust_radius::Real = 0.0, - step_threshold::Real = 0.1, - shrink_threshold::Real = 0.25, - expand_threshold::Real = 0.75, - shrink_factor::Real = 0.25, - expand_factor::Real = 2.0, + max_trust_radius::Real = 0 // 1, + initial_trust_radius::Real = 0 // 1, + step_threshold::Real = 1 // 10, + shrink_threshold::Real = 1 // 4, + expand_threshold::Real = 3 // 4, + shrink_factor::Real = 1 // 4, + expand_factor::Real = 2 // 1, max_shrink_times::Int = 32) TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type, typeof(linsolve), typeof(precs), _unwrap_val(standardtag), @@ -121,7 +121,7 @@ end mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, INType, tolType, probType, ufType, L, jType, JC, floatType, - trustType, suType, tmpType} + trustType, suType, su2Type, tmpType} f::fType alg::algType u::uType @@ -140,12 +140,17 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, prob::probType trust_r::trustType max_trust_r::trustType + step_threshold::suType + shrink_threshold::trustType + expand_threshold::trustType + shrink_factor::trustType + expand_factor::trustType loss::floatType loss_new::floatType H::jType g::resType shrink_counter::Int - step_size::suType + step_size::su2Type u_tmp::tmpType fu_new::resType make_new_J::Bool @@ -157,24 +162,29 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, force_stop::Bool, maxiters::Int, internalnorm::INType, retcode::SciMLBase.ReturnCode.T, abstol::tolType, prob::probType, trust_r::trustType, - max_trust_r::trustType, loss::floatType, - loss_new::floatType, H::jType, g::resType, - shrink_counter::Int, step_size::suType, u_tmp::tmpType, - fu_new::resType, make_new_J::Bool, + max_trust_r::trustType, step_threshold::suType, + shrink_threshold::trustType, expand_threshold::trustType, + shrink_factor::trustType, expand_factor::trustType, + loss::floatType, loss_new::floatType, H::jType, + g::resType, shrink_counter::Int, step_size::su2Type, + u_tmp::tmpType, fu_new::resType, make_new_J::Bool, r::floatType) where {iip, fType, algType, uType, resType, pType, INType, tolType, probType, ufType, L, jType, JC, floatType, trustType, - suType, tmpType} + suType, su2Type, tmpType} new{iip, fType, algType, uType, resType, pType, INType, tolType, probType, ufType, L, jType, JC, floatType, - trustType, suType, tmpType}(f, alg, u, fu, p, uf, linsolve, J, - jac_config, iter, force_stop, - maxiters, internalnorm, retcode, - abstol, prob, trust_r, max_trust_r, loss, - loss_new, H, g, shrink_counter, - step_size, u_tmp, fu_new, - make_new_J, r) + trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J, + jac_config, iter, force_stop, + maxiters, internalnorm, retcode, + abstol, prob, trust_r, max_trust_r, + step_threshold, shrink_threshold, + expand_threshold, shrink_factor, + expand_factor, loss, + loss_new, H, g, shrink_counter, + step_size, u_tmp, fu_new, + make_new_J, r) end end @@ -228,15 +238,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, loss = get_loss(fu) uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip)) + max_trust_radius = convert(eltype(u), alg.max_trust_radius) + initial_trust_radius = convert(eltype(u), alg.initial_trust_radius) + step_threshold = convert(eltype(u), alg.step_threshold) + shrink_threshold = convert(eltype(u), alg.shrink_threshold) + expand_threshold = convert(eltype(u), alg.expand_threshold) + shrink_factor = convert(eltype(u), alg.shrink_factor) + expand_factor = convert(eltype(u), alg.expand_factor) # Set default trust region radius if not specified - max_trust_radius = alg.max_trust_radius - initial_trust_radius = alg.initial_trust_radius - if max_trust_radius == 0.0 - max_trust_radius = convert(typeof(max_trust_radius), - max(norm(fu), maximum(u) - minimum(u))) + if iszero(max_trust_radius) + max_trust_radius = convert(eltype(u), max(norm(fu), maximum(u) - minimum(u))) end - if initial_trust_radius == 0.0 - initial_trust_radius = max_trust_radius / 11 + if iszero(initial_trust_radius) + initial_trust_radius = convert(eltype(u), max_trust_radius / 11) end loss_new = loss @@ -251,8 +265,10 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config, 1, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, initial_trust_radius, - max_trust_radius, loss, loss_new, H, g, shrink_counter, - step_size, u_tmp, fu_new, make_new_J, r) + max_trust_radius, step_threshold, shrink_threshold, + expand_threshold, shrink_factor, expand_factor, loss, + loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new, + make_new_J, r) end function perform_step!(cache::TrustRegionCache{true}) @@ -301,7 +317,7 @@ function perform_step!(cache::TrustRegionCache{false}) end function trust_region_step!(cache::TrustRegionCache) - @unpack fu_new, step_size, g, H, loss, alg, max_trust_r = cache + @unpack fu_new, step_size, g, H, loss, max_trust_r = cache cache.loss_new = get_loss(fu_new) # Compute the ratio of the actual reduction to the predicted reduction. @@ -309,19 +325,19 @@ function trust_region_step!(cache::TrustRegionCache) @unpack r = cache # Update the trust region radius. - if r < alg.shrink_threshold - cache.trust_r *= alg.shrink_factor + if r < cache.shrink_threshold + cache.trust_r *= cache.shrink_factor cache.shrink_counter += 1 else cache.shrink_counter = 0 end - if r > alg.step_threshold + if r > cache.step_threshold take_step!(cache) cache.loss = cache.loss_new # Update the trust region radius. - if r > alg.expand_threshold - cache.trust_r = min(alg.expand_factor * cache.trust_r, max_trust_r) + if r > cache.expand_threshold + cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r) end cache.make_new_J = true diff --git a/test/basictests.jl b/test/basictests.jl index abf8eeb43..947975959 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -186,7 +186,7 @@ sol = benchmark_inplace(ffiip, u0) u0 = [1.0, 1.0] probN = NonlinearProblem{true}(ffiip, u0) solver = init(probN, TrustRegion(), abstol = 1e-9) -@test (@ballocated solve!(solver)) < 120 +@test (@ballocated solve!(solver)) < 200 # AD Tests using ForwardDiff