Skip to content

Commit

Permalink
Added reinit! for TrustRegionCache
Browse files Browse the repository at this point in the history
  • Loading branch information
dawbarton committed Mar 29, 2023
1 parent 0f9858c commit c414a8f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
36 changes: 31 additions & 5 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ function trust_region_step!(cache::TrustRegionCache)
cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
@unpack r = cache

if radius_update_scheme === RadiusUpdateSchemes.Simple
if radius_update_scheme === RadiusUpdateSchemes.Simple
# Update the trust region radius.
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
Expand All @@ -389,13 +389,13 @@ function trust_region_step!(cache::TrustRegionCache)
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end

elseif radius_update_scheme === RadiusUpdateSchemes.Hei
if r > cache.step_threshold
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
else
else
cache.make_new_J = false
end
# Hei's radius update scheme
Expand Down Expand Up @@ -427,7 +427,7 @@ function trust_region_step!(cache::TrustRegionCache)
else
cache.make_new_J = false
end

if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined
cache.force_stop = true
end
Expand Down Expand Up @@ -489,3 +489,29 @@ function SciMLBase.solve!(cache::TrustRegionCache)
SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu;
retcode = cache.retcode)
end

function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
cache.abstol = abstol
cache.maxiters = maxiters
cache.iter = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
cache.make_new_J = true
cache.loss = get_loss(cache.fu)
cache.shrink_counter = 0
cache.trust_r = convert(eltype(cache.u), cache.alg.initial_trust_radius)
if iszero(cache.trust_r)
cache.trust_r = convert(eltype(cache.u), cache.max_trust_r / 11)
end
return cache
end
31 changes: 31 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,37 @@ end
@test gnewton(p) [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)

# Iterator interface
f = (u, p) -> u * u - p
g = function (p_range)
probN = NonlinearProblem{false}(f, 0.5, p_range[begin])
cache = init(probN, TrustRegion(); maxiters = 100, abstol=1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, cache.u; p = p)
sol = solve!(cache)
sols[i] = sol.u
end
return sols
end
p = range(0.01, 2, length = 200)
@test g(p) sqrt.(p)

f = (res, u, p) -> (res[begin] = u[1] * u[1] - p)
g = function (p_range)
probN = NonlinearProblem{true}(f, [0.5], p_range[begin])
cache = init(probN, TrustRegion(); maxiters = 100, abstol=1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, [cache.u[1]]; p = p)
sol = solve!(cache)
sols[i] = sol.u[1]
end
return sols
end
p = range(0.01, 2, length = 200)
@test g(p) sqrt.(p)

# Error Checks
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
probN = NonlinearProblem(f, u0)
Expand Down

0 comments on commit c414a8f

Please sign in to comment.