Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bastin's radius update scheme #191

Merged
merged 10 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
trustType, suType, su2Type, tmpType}
f::fType
alg::algType
u_prev::uType
u::uType
fu_prev::resType
fu::resType
p::pType
uf::ufType
Expand Down Expand Up @@ -172,7 +174,8 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
ϵ::floatType
stats::NLStats

function TrustRegionCache{iip}(f::fType, alg::algType, u::uType, fu::resType, p::pType,
function TrustRegionCache{iip}(f::fType, alg::algType, u_prev::uType, u::uType,
fu_prev::resType, fu::resType, p::pType,
uf::ufType, linsolve::L, J::jType, jac_config::JC,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
Expand All @@ -194,7 +197,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
suType, su2Type, tmpType}
new{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
trustType, suType, su2Type, tmpType}(f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J,
jac_config, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, radius_update_scheme,
Expand Down Expand Up @@ -246,6 +249,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
else
u = deepcopy(prob.u0)
end
u_prev = deepcopy(u)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only needed by some algorithms though, right? So in many cases it can be made as an empty array?

f = prob.f
p = prob.p
if iip
Expand All @@ -254,6 +258,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
else
fu = f(u, p)
end
fu_prev = deepcopy(fu)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved

loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))
Expand Down Expand Up @@ -325,9 +330,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
p3 = convert(eltype(u), 12) # c6
p4 = convert(eltype(u), 1.0e18) # M
initial_trust_radius = convert(eltype(u), p1 * (norm(fu)^0.99))
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
step_threshold = convert(eltype(u), 0.05)
shrink_threshold = convert(eltype(u), 0.05)
expand_threshold = convert(eltype(u), 0.9)
p1 = convert(eltype(u), 2.5) #alpha_1
p2 = convert(eltype(u), 0.25) # alpha_2
p3 = convert(eltype(u), 0) # not required
p4 = convert(eltype(u), 0) # not required
initial_trust_radius = convert(eltype(u), 1.0)
end

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J,
jac_config,
false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, radius_update_scheme,
initial_trust_radius,
Expand Down Expand Up @@ -388,6 +403,30 @@ function perform_step!(cache::TrustRegionCache{false})
return nothing
end

"""
    retrospective_step!(cache::TrustRegionCache{true})

In-place variant. Re-evaluate the Jacobian into `cache.J` with `jacobian!`,
overwrite `cache.H` and `cache.g` from it, bump `cache.stats.njacs`, and return
the ratio between the change in loss (`get_loss(fu_prev)` vs `get_loss(fu)`) and
the quadratic-model change along `cache.step_size`.

NOTE(review): `H` and `g` are formed here as `J * J` and `J * fu` (no transpose).
If `get_loss` is a least-squares loss, the Gauss-Newton quantities would be
`J' * J` and `J' * fu` -- confirm against how `perform_step!` builds them.
"""
function retrospective_step!(cache::TrustRegionCache{true})
    # Grab the buffers before the Jacobian is refreshed.
    jac = cache.J
    fu_old = cache.fu_prev
    fu_now = cache.fu

    jacobian!(jac, cache)
    mul!(cache.H, jac, jac)
    mul!(cache.g, jac, fu_now)
    cache.stats.njacs += 1

    hess = cache.H
    grad = cache.g
    s = cache.step_size

    loss_drop = get_loss(fu_old) - get_loss(fu_now)
    model_change = s' * grad + s' * hess * s / 2
    return -loss_drop / model_change
end

"""
    retrospective_step!(cache::TrustRegionCache{false})

Out-of-place variant. Compute a fresh Jacobian with `jacobian(cache, f)`, assign
new values to `cache.H` and `cache.g`, bump `cache.stats.njacs`, and return the
ratio between the change in loss (`get_loss(fu_prev)` vs `get_loss(fu)`) and the
quadratic-model change along `cache.step_size`.

The freshly computed Jacobian is only used locally; it is not written back to
`cache.J`.

NOTE(review): `H` and `g` are formed here as `J * J` and `J * fu` (no transpose).
If `get_loss` is a least-squares loss, the Gauss-Newton quantities would be
`J' * J` and `J' * fu` -- confirm against how `perform_step!` builds them.
"""
function retrospective_step!(cache::TrustRegionCache{false})
    fu_old = cache.fu_prev
    fu_now = cache.fu

    jac = jacobian(cache, cache.f)
    cache.H = jac * jac
    cache.g = jac * fu_now
    cache.stats.njacs += 1

    hess = cache.H
    grad = cache.g
    s = cache.step_size

    loss_drop = get_loss(fu_old) - get_loss(fu_now)
    model_change = s' * grad + s' * hess * s / 2
    return -loss_drop / model_change
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, step_size, g, H, loss, max_trust_r, radius_update_scheme = cache
cache.loss_new = get_loss(fu_new)
Expand Down Expand Up @@ -495,6 +534,23 @@ function trust_region_step!(cache::TrustRegionCache)
cache.internalnorm(g) < cache.ϵ
cache.force_stop = true
end
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
if retrospective_step!(cache) >= cache.expand_threshold
cache.trust_r = max(cache.p1 * cache.internalnorm(step_size), cache.trust_r)
end

else
cache.make_new_J = false
cache.trust_r *= cache.p2
cache.shrink_counter += 1
end
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end
end
end

Expand Down Expand Up @@ -526,12 +582,16 @@ function dogleg!(cache::TrustRegionCache)
end

"""
    take_step!(cache::TrustRegionCache{true})

In-place variant. Copy the current iterate and residual into the `*_prev`
buffers, then overwrite them with the trial values `u_tmp` and `fu_new`.
All four updates are element-wise copies into existing arrays.
"""
function take_step!(cache::TrustRegionCache{true})
    @unpack u_prev, u, u_tmp, fu_prev, fu, fu_new = cache

    # Order matters: save the old values before they are overwritten.
    u_prev .= u
    u .= u_tmp
    fu_prev .= fu
    fu .= fu_new
    return fu
end

"""
    take_step!(cache::TrustRegionCache{false})

Out-of-place variant. Rebind `u_prev`/`fu_prev` to the current iterate and
residual, then rebind `u`/`fu` to the trial values `u_tmp` and `fu_new`.
No data is copied; only the cache fields are reassigned.
"""
function take_step!(cache::TrustRegionCache{false})
    @unpack u, u_tmp, fu, fu_new = cache

    cache.u_prev, cache.u = u, u_tmp
    cache.fu_prev, cache.fu = fu, fu_new
    return fu_new
end

Expand Down
58 changes: 56 additions & 2 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ end

u0 = [1.0, 1.0]
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei,
RadiusUpdateSchemes.Yuan,
RadiusUpdateSchemes.Fan]
RadiusUpdateSchemes.Yuan, RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]

for radius_update_scheme in radius_update_schemes
sol = benchmark_immutable(ff, cu0, radius_update_scheme)
Expand Down Expand Up @@ -286,6 +285,18 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Solve the out-of-place problem built from `f` and `csu0` (both defined earlier
# in this file) for parameter `p` with the Bastin radius update scheme, and
# return the last component of the solution.
g = function (p)
    problem = NonlinearProblem{false}(f, csu0, p)
    bastin = TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin)
    return solve(problem, bastin; abstol = 1e-9).u[end]
end

# Sweep p over 1.1:0.1:100.0 and check both the value returned by `g` and its
# ForwardDiff derivative against the closed forms sqrt(p) and 1 / (2 * sqrt(p)).
for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Scalar
f, u0 = (u, p) -> u * u - p, 1.0

Expand Down Expand Up @@ -344,6 +355,20 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Solve the scalar out-of-place problem for parameter `p` with the Bastin radius
# update scheme. The initial guess `u0` is converted to the type of `p` so the
# closure can be differentiated with ForwardDiff. Returns the solution `sol.u`.
g = function (p)
    problem = NonlinearProblem{false}(f, oftype(p, u0), p)
    bastin = TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin)
    return solve(problem, bastin; abstol = 1e-10).u
end

# Spot-check the derivative of `g` at p = 3.0 against 1 / (2 * sqrt(p)).
@test ForwardDiff.derivative(g, 3.0) ≈ 1 / (2 * sqrt(3.0))

# Sweep p over 1.1:0.1:100.0 and check both the value returned by `g` and its
# ForwardDiff derivative against the closed forms sqrt(p) and 1 / (2 * sqrt(p)).
for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

f = (u, p) -> p[1] * u * u - p[2]
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
Expand Down Expand Up @@ -379,6 +404,14 @@ end
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

# Solve the scalar out-of-place problem from the fixed initial guess 0.5 with the
# Bastin radius update scheme and default tolerances. The scalar solution is
# wrapped in a one-element vector so it can be compared with `t(p)`.
gnewton = function (p)
    problem = NonlinearProblem{false}(f, 0.5, p)
    bastin = TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin)
    root = solve(problem, bastin).u
    return [root]
end
# The solver result must match the closed form sqrt(p[2] / p[1]), and its
# ForwardDiff Jacobian must match the Jacobian of the closed-form function `t`.
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

# Iterator interface
f = (u, p) -> u * u - p
g = function (p_range)
Expand Down Expand Up @@ -432,6 +465,11 @@ probN = NonlinearProblem(f, u0)
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff = false)).u[end] ≈
sqrt(2.0)

# Bastin radius update scheme: the last solution component must be sqrt(2.0),
# both with the default `TrustRegion` settings and with `autodiff = false`.
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin)).u[end] ≈
sqrt(2.0)
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff = false)).u[end] ≈
sqrt(2.0)

for u0 in [1.0, [1, 1.0]]
local f, probN, sol
f = (u, p) -> u .* u .- 2.0
Expand Down Expand Up @@ -475,6 +513,17 @@ u = g(p)
f(u, p)
@test all(abs.(f(u, p)) .< 1e-10)

# Solve the out-of-place problem built from `f` and `u0` (both defined earlier in
# this file) for parameter `p` with the Bastin radius update scheme, and return
# the full solution `sol.u`.
g = function (p)
    problem = NonlinearProblem{false}(f, u0, p)
    bastin = TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin)
    return solve(problem, bastin; abstol = 1e-10).u
end
# Solve with an all-zero 7-element parameter vector and check that every
# component of the residual f(u, p) is below 1e-10 in absolute value.
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
u = g(p)
# The result of this bare call is discarded; the residual is evaluated again
# inside the @test below.
f(u, p)
@test all(abs.(f(u, p)) .< 1e-10)

# Test kwargs in `TrustRegion`
max_trust_radius = [10.0, 100.0, 1000.0]
initial_trust_radius = [10.0, 1.0, 0.1]
Expand Down Expand Up @@ -542,6 +591,11 @@ for maxiters in maxiterations
@test iip == oop
end

# For each iteration cap in `maxiterations`, run `iip_oop` with the Bastin scheme
# and require its two returned values to be equal.
# NOTE(review): `iip_oop` is defined outside this view; presumably it returns the
# in-place and out-of-place results for the same problem -- confirm.
for maxiters in maxiterations
iip, oop = iip_oop(ff, ffiip, u0, RadiusUpdateSchemes.Bastin, maxiters)
@test iip == oop
end

# --- LevenbergMarquardt tests ---

function benchmark_immutable(f, u0)
Expand Down
40 changes: 40 additions & 0 deletions test/convergencetests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using NonlinearSolve
using StaticArrays
using BenchmarkTools
using Test

using SciMLNLSolve

###-----Trust Region tests-----###

# some simple functions #
"""
    f_oop(u, p)

Out-of-place residual: return the element-wise value of `u * u - p`.
"""
function f_oop(u, p)
    return @. u * u - p
end

"""
    f_iip(du, u, p)

In-place residual: write the element-wise value of `u * u - p` into `du` and
return `du`.
"""
function f_iip(du, u, p)
    @. du = u * u - p
    return du
end

"""
    f_scalar(u, p)

Scalar residual: return `u * u - p`.
"""
f_scalar(u, p) = u * u - p

# Shared inputs for the convergence test below.
u0 = [1.0, 1.0]        # vector initial guess
csu0 = 1.0             # scalar initial guess
p = [2.0, 2.0]         # problem parameters
radius_update_scheme = RadiusUpdateSchemes.Simple
tol = 1e-9             # value the final `residual` is compared against

"""
    convergence_test_oop(f, u0, p, radius_update_scheme; abstol = 1e-9)

Build an out-of-place `NonlinearProblem` from `f`, `u0` (converted to the type of
`p`) and `p`, initialise a `TrustRegion` cache with the given
`radius_update_scheme`, and run it to completion with `solve!`.

# Keywords
- `abstol`: absolute tolerance forwarded to `init`. Defaults to `1e-9`, the value
  that was previously hard-coded, so existing callers are unaffected.

# Returns
A tuple `(step_norm, iterations, retcode)` where `step_norm` is
`cache.internalnorm(cache.u_prev - cache.u)`, `iterations` is `cache.iter`, and
`retcode` is `sol.retcode`.
"""
function convergence_test_oop(f, u0, p, radius_update_scheme; abstol = 1e-9)
    prob = NonlinearProblem{false}(f, oftype(p, u0), p)
    alg = TrustRegion(; radius_update_scheme = radius_update_scheme)
    cache = init(prob, alg; abstol = abstol)
    sol = solve!(cache)

    step_norm = cache.internalnorm(cache.u_prev - cache.u)
    return step_norm, cache.iter, sol.retcode
end

# Run the out-of-place convergence test with the shared inputs defined above.
residual, iterations, return_code = convergence_test_oop(f_oop, u0, p, radius_update_scheme)
@test return_code === ReturnCode.Success
# NOTE(review): `residual` is the norm of `cache.u_prev - cache.u`, and `≈` with
# no `atol` applies only a small relative tolerance, so this passes only when
# that norm is almost exactly equal to `tol`. Confirm whether an upper bound
# (or an explicit `atol`) was intended.
@test residual ≈ tol