diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 127785515..a35be7807 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -3,13 +3,13 @@ TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS, - max_trust_radius::Real = 0.0, - initial_trust_radius::Real = 0.0, - step_threshold::Real = 0.1, - shrink_threshold::Real = 0.25, - expand_threshold::Real = 0.75, - shrink_factor::Real = 0.25, - expand_factor::Real = 2.0, + max_trust_radius::Real = 0 // 1, + initial_trust_radius::Real = 0 // 1, + step_threshold::Real = 1 // 10, + shrink_threshold::Real = 1 // 4, + expand_threshold::Real = 3 // 4, + shrink_factor::Real = 1 // 4, + expand_factor::Real = 2 // 1, max_shrink_times::Int = 32) ``` @@ -98,13 +98,13 @@ function TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS, - max_trust_radius::Real = 0.0, - initial_trust_radius::Real = 0.0, - step_threshold::Real = 0.1, - shrink_threshold::Real = 0.25, - expand_threshold::Real = 0.75, - shrink_factor::Real = 0.25, - expand_factor::Real = 2.0, + max_trust_radius::Real = 0 // 1, + initial_trust_radius::Real = 0 // 1, + step_threshold::Real = 1 // 10, + shrink_threshold::Real = 1 // 4, + expand_threshold::Real = 3 // 4, + shrink_factor::Real = 1 // 4, + expand_factor::Real = 2 // 1, max_shrink_times::Int = 32) TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type, typeof(linsolve), typeof(precs), _unwrap_val(standardtag), @@ -141,6 +141,11 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, prob::probType trust_r::trustType max_trust_r::trustType + step_threshold::trustType + shrink_threshold::trustType + expand_threshold::trustType + shrink_factor::trustType + expand_factor::trustType loss::floatType loss_new::floatType H::jType @@ -158,10 +163,12 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, force_stop::Bool, maxiters::Int, internalnorm::INType, retcode::SciMLBase.ReturnCode.T, abstol::tolType, prob::probType, trust_r::trustType, - max_trust_r::trustType, loss::floatType, - loss_new::floatType, H::jType, g::resType, - shrink_counter::Int, step_size::uType, u_tmp::uType, - fu_new::resType, make_new_J::Bool, + max_trust_r::trustType, step_threshold::trustType, + shrink_threshold::trustType, expand_threshold::trustType, + shrink_factor::trustType, expand_factor::trustType, + loss::floatType, loss_new::floatType, H::jType, + g::resType, shrink_counter::Int, step_size::uType, + u_tmp::uType, fu_new::resType, make_new_J::Bool, r::floatType) where {iip, fType, algType, uType, resType, pType, INType, tolType, probType, ufType, L, @@ -171,7 +178,10 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType, }(f, alg, u, fu, p, uf, linsolve, J, jac_config, iter, force_stop, maxiters, internalnorm, retcode, - abstol, prob, trust_r, max_trust_r, loss, + abstol, prob, trust_r, max_trust_r, + step_threshold, shrink_threshold, + expand_threshold, shrink_factor, + expand_factor, loss, loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new, make_new_J, r) @@ -228,25 +238,37 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, loss = get_loss(fu) uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip)) + max_trust_radius = convert(eltype(u), alg.max_trust_radius) + initial_trust_radius = convert(eltype(u), alg.initial_trust_radius) + step_threshold = convert(eltype(u), alg.step_threshold) + shrink_threshold = convert(eltype(u), alg.shrink_threshold) + expand_threshold = convert(eltype(u), alg.expand_threshold) + shrink_factor = convert(eltype(u), alg.shrink_factor) + expand_factor = convert(eltype(u), alg.expand_factor) # Set default trust region radius if not specified - u_elType = uType <: Number ? uType : eltype(u) - max_trust_radius = u_elType(alg.max_trust_radius) - initial_trust_radius = u_elType(alg.initial_trust_radius) if iszero(max_trust_radius) - max_trust_radius = max(norm(fu), maximum(u) - minimum(u)) + max_trust_radius = convert(eltype(u), max(norm(fu), maximum(u) - minimum(u))) end if iszero(initial_trust_radius) - initial_trust_radius = max_trust_radius / 11 + initial_trust_radius = convert(eltype(u), max_trust_radius / 11) end + loss_new = loss H = ArrayInterfaceCore.undefmatrix(u) + g = zero(fu) + shrink_counter = 0 + step_size = zero(u) + fu_new = zero(fu) + make_new_J = true + r = loss return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config, 1, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, initial_trust_radius, - max_trust_radius, loss, loss, H, zero(fu), 0, zero(u), - u_tmp, zero(fu), true, - loss) + max_trust_radius, step_threshold, shrink_threshold, + expand_threshold, shrink_factor, expand_factor, loss, + loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new, + make_new_J, r) end function perform_step!(cache::TrustRegionCache{true}) @@ -295,7 +317,7 @@ function perform_step!(cache::TrustRegionCache{false}) end function trust_region_step!(cache::TrustRegionCache) - @unpack fu_new, u_tmp, step_size, g, H, loss, alg, max_trust_r = cache + @unpack fu_new, step_size, g, H, loss, max_trust_r = cache cache.loss_new = get_loss(fu_new) # Compute the ratio of the actual reduction to the predicted reduction. @@ -303,19 +325,19 @@ function trust_region_step!(cache::TrustRegionCache) @unpack r = cache # Update the trust region radius. - if r < alg.shrink_threshold - cache.trust_r *= alg.shrink_factor + if r < cache.shrink_threshold + cache.trust_r *= cache.shrink_factor cache.shrink_counter += 1 else cache.shrink_counter = 0 end - if r > alg.step_threshold + if r > cache.step_threshold take_step!(cache) cache.loss = cache.loss_new # Update the trust region radius. - if r > alg.expand_threshold - cache.trust_r = min(alg.expand_factor * cache.trust_r, max_trust_r) + if r > cache.expand_threshold + cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r) end cache.make_new_J = true diff --git a/test/basictests.jl b/test/basictests.jl index ed2907fdb..d0e959b6b 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -186,7 +186,7 @@ sol = benchmark_inplace(ffiip, u0) u0 = [1.0, 1.0] probN = NonlinearProblem{true}(ffiip, u0) solver = init(probN, TrustRegion(), abstol = 1e-9) -@test (@ballocated solve!(solver)) < 120 +@test (@ballocated solve!(solver)) < 200 # AD Tests using ForwardDiff