Skip to content

Commit

Permalink
Merge pull request #132 from daviehh/utype
Browse files Browse the repository at this point in the history
type instability fix
  • Loading branch information
ChrisRackauckas authored Jan 25, 2023
2 parents e1c35f5 + e438e12 commit 2ab52d6
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 44 deletions.
102 changes: 59 additions & 43 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0,
step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25,
expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25,
expand_factor::Real = 2.0,
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
```
Expand Down Expand Up @@ -98,13 +98,13 @@ function TrustRegion(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0,
step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25,
expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25,
expand_factor::Real = 2.0,
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
Expand All @@ -121,7 +121,7 @@ end

mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
trustType, suType, tmpType}
trustType, suType, su2Type, tmpType}
f::fType
alg::algType
u::uType
Expand All @@ -140,12 +140,17 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
prob::probType
trust_r::trustType
max_trust_r::trustType
step_threshold::suType
shrink_threshold::trustType
expand_threshold::trustType
shrink_factor::trustType
expand_factor::trustType
loss::floatType
loss_new::floatType
H::jType
g::resType
shrink_counter::Int
step_size::suType
step_size::su2Type
u_tmp::tmpType
fu_new::resType
make_new_J::Bool
Expand All @@ -157,24 +162,29 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType, trust_r::trustType,
max_trust_r::trustType, loss::floatType,
loss_new::floatType, H::jType, g::resType,
shrink_counter::Int, step_size::suType, u_tmp::tmpType,
fu_new::resType, make_new_J::Bool,
max_trust_r::trustType, step_threshold::suType,
shrink_threshold::trustType, expand_threshold::trustType,
shrink_factor::trustType, expand_factor::trustType,
loss::floatType, loss_new::floatType, H::jType,
g::resType, shrink_counter::Int, step_size::su2Type,
u_tmp::tmpType, fu_new::resType, make_new_J::Bool,
r::floatType) where {iip, fType, algType, uType,
resType, pType, INType,
tolType, probType, ufType, L,
jType, JC, floatType, trustType,
suType, tmpType}
suType, su2Type, tmpType}
new{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
trustType, suType, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r, loss,
loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new,
make_new_J, r)
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r,
step_threshold, shrink_threshold,
expand_threshold, shrink_factor,
expand_factor, loss,
loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new,
make_new_J, r)
end
end

Expand Down Expand Up @@ -228,15 +238,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
shrink_threshold = convert(eltype(u), alg.shrink_threshold)
expand_threshold = convert(eltype(u), alg.expand_threshold)
shrink_factor = convert(eltype(u), alg.shrink_factor)
expand_factor = convert(eltype(u), alg.expand_factor)
# Set default trust region radius if not specified
max_trust_radius = alg.max_trust_radius
initial_trust_radius = alg.initial_trust_radius
if max_trust_radius == 0.0
max_trust_radius = convert(typeof(max_trust_radius),
max(norm(fu), maximum(u) - minimum(u)))
if iszero(max_trust_radius)
max_trust_radius = convert(eltype(u), max(norm(fu), maximum(u) - minimum(u)))
end
if initial_trust_radius == 0.0
initial_trust_radius = max_trust_radius / 11
if iszero(initial_trust_radius)
initial_trust_radius = convert(eltype(u), max_trust_radius / 11)
end

loss_new = loss
Expand All @@ -251,8 +265,10 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
max_trust_radius, loss, loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new, make_new_J, r)
max_trust_radius, step_threshold, shrink_threshold,
expand_threshold, shrink_factor, expand_factor, loss,
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
make_new_J, r)
end

function perform_step!(cache::TrustRegionCache{true})
Expand Down Expand Up @@ -301,27 +317,27 @@ function perform_step!(cache::TrustRegionCache{false})
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, step_size, g, H, loss, alg, max_trust_r = cache
@unpack fu_new, step_size, g, H, loss, max_trust_r = cache
cache.loss_new = get_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
@unpack r = cache

# Update the trust region radius.
if r < alg.shrink_threshold
cache.trust_r *= alg.shrink_factor
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
if r > alg.step_threshold
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new

# Update the trust region radius.
if r > alg.expand_threshold
cache.trust_r = min(alg.expand_factor * cache.trust_r, max_trust_r)
if r > cache.expand_threshold
cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
end

cache.make_new_J = true
Expand Down
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ sol = benchmark_inplace(ffiip, u0)
u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(ffiip, u0)
solver = init(probN, TrustRegion(), abstol = 1e-9)
@test (@ballocated solve!(solver)) < 120
@test (@ballocated solve!(solver)) < 200

# AD Tests
using ForwardDiff
Expand Down

0 comments on commit 2ab52d6

Please sign in to comment.