Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

type instability fix #132

Merged
merged 12 commits into from
Jan 25, 2023
98 changes: 57 additions & 41 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0,
step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25,
expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25,
expand_factor::Real = 2.0,
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
```

Expand Down Expand Up @@ -98,13 +98,13 @@ function TrustRegion(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0,
step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25,
expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25,
expand_factor::Real = 2.0,
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
Expand Down Expand Up @@ -140,6 +140,11 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
prob::probType
trust_r::trustType
max_trust_r::trustType
step_threshold::trustType
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
shrink_threshold::trustType
expand_threshold::trustType
shrink_factor::trustType
expand_factor::trustType
loss::floatType
loss_new::floatType
H::jType
Expand All @@ -157,24 +162,29 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType, trust_r::trustType,
max_trust_r::trustType, loss::floatType,
loss_new::floatType, H::jType, g::resType,
shrink_counter::Int, step_size::suType, u_tmp::tmpType,
fu_new::resType, make_new_J::Bool,
max_trust_r::trustType, step_threshold::suType,
shrink_threshold::trustType, expand_threshold::trustType,
shrink_factor::trustType, expand_factor::trustType,
loss::floatType, loss_new::floatType, H::jType,
g::resType, shrink_counter::Int, step_size::uType,
u_tmp::tmpType, fu_new::resType, make_new_J::Bool,
r::floatType) where {iip, fType, algType, uType,
resType, pType, INType,
tolType, probType, ufType, L,
jType, JC, floatType, trustType,
suType, tmpType}
new{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
trustType, suType, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r, loss,
loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new,
make_new_J, r)
INType, tolType, probType, ufType, L, jType, JC, floatType, trustType,
suType, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r,
step_threshold, shrink_threshold,
expand_threshold, shrink_factor,
expand_factor, loss,
loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new,
make_new_J, r)
end
end

Expand Down Expand Up @@ -228,15 +238,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
shrink_threshold = convert(eltype(u), alg.shrink_threshold)
expand_threshold = convert(eltype(u), alg.expand_threshold)
shrink_factor = convert(eltype(u), alg.shrink_factor)
expand_factor = convert(eltype(u), alg.expand_factor)
# Set default trust region radius if not specified
max_trust_radius = alg.max_trust_radius
initial_trust_radius = alg.initial_trust_radius
if max_trust_radius == 0.0
max_trust_radius = convert(typeof(max_trust_radius),
max(norm(fu), maximum(u) - minimum(u)))
if iszero(max_trust_radius)
max_trust_radius = convert(eltype(u), max(norm(fu), maximum(u) - minimum(u)))
end
if initial_trust_radius == 0.0
initial_trust_radius = max_trust_radius / 11
if iszero(initial_trust_radius)
initial_trust_radius = convert(eltype(u), max_trust_radius / 11)
end

loss_new = loss
Expand All @@ -251,8 +265,10 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
max_trust_radius, loss, loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new, make_new_J, r)
max_trust_radius, step_threshold, shrink_threshold,
expand_threshold, shrink_factor, expand_factor, loss,
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
make_new_J, r)
end

function perform_step!(cache::TrustRegionCache{true})
Expand Down Expand Up @@ -301,27 +317,27 @@ function perform_step!(cache::TrustRegionCache{false})
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, step_size, g, H, loss, alg, max_trust_r = cache
@unpack fu_new, step_size, g, H, loss, max_trust_r = cache
cache.loss_new = get_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
@unpack r = cache

# Update the trust region radius.
if r < alg.shrink_threshold
cache.trust_r *= alg.shrink_factor
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
if r > alg.step_threshold
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new

# Update the trust region radius.
if r > alg.expand_threshold
cache.trust_r = min(alg.expand_factor * cache.trust_r, max_trust_r)
if r > cache.expand_threshold
cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
end

cache.make_new_J = true
Expand Down
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ sol = benchmark_inplace(ffiip, u0)
u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(ffiip, u0)
solver = init(probN, TrustRegion(), abstol = 1e-9)
@test (@ballocated solve!(solver)) < 120
@test (@ballocated solve!(solver)) < 200

# AD Tests
using ForwardDiff
Expand Down