Skip to content

Commit

Permalink
Merge pull request #191 from yash2798/ys/bastin_new
Browse files Browse the repository at this point in the history
Bastin's radius update scheme
  • Loading branch information
ChrisRackauckas authored Jun 21, 2023
2 parents 8d009b9 + a8f4c33 commit 0abdc34
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 5 deletions.
66 changes: 63 additions & 3 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
trustType, suType, su2Type, tmpType}
f::fType
alg::algType
u_prev::uType
u::uType
fu_prev::resType
fu::resType
p::pType
uf::ufType
Expand Down Expand Up @@ -172,7 +174,8 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
ϵ::floatType
stats::NLStats

function TrustRegionCache{iip}(f::fType, alg::algType, u::uType, fu::resType, p::pType,
function TrustRegionCache{iip}(f::fType, alg::algType, u_prev::uType, u::uType,
fu_prev::resType, fu::resType, p::pType,
uf::ufType, linsolve::L, J::jType, jac_config::JC,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
Expand All @@ -194,7 +197,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
suType, su2Type, tmpType}
new{iip, fType, algType, uType, resType, pType,
INType, tolType, probType, ufType, L, jType, JC, floatType,
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
trustType, suType, su2Type, tmpType}(f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J,
jac_config, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, radius_update_scheme,
Expand Down Expand Up @@ -246,6 +249,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
else
u = deepcopy(prob.u0)
end
u_prev = zero(u)
f = prob.f
p = prob.p
if iip
Expand All @@ -254,6 +258,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
else
fu = f(u, p)
end
fu_prev = zero(fu)

loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))
Expand Down Expand Up @@ -325,9 +330,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
p3 = convert(eltype(u), 12) # c6
p4 = convert(eltype(u), 1.0e18) # M
initial_trust_radius = convert(eltype(u), p1 * (norm(fu)^0.99))
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
step_threshold = convert(eltype(u), 0.05)
shrink_threshold = convert(eltype(u), 0.05)
expand_threshold = convert(eltype(u), 0.9)
p1 = convert(eltype(u), 2.5) #alpha_1
p2 = convert(eltype(u), 0.25) # alpha_2
p3 = convert(eltype(u), 0) # not required
p4 = convert(eltype(u), 0) # not required
initial_trust_radius = convert(eltype(u), 1.0)
end

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J,
jac_config,
false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, radius_update_scheme,
initial_trust_radius,
Expand Down Expand Up @@ -388,6 +403,30 @@ function perform_step!(cache::TrustRegionCache{false})
return nothing
end

function retrospective_step!(cache::TrustRegionCache{true})
@unpack J, fu_prev, fu, u_prev, u = cache
jacobian!(J, cache)
mul!(cache.H, J, J)
mul!(cache.g, J, fu)
cache.stats.njacs += 1
@unpack H, g, step_size = cache

return -(get_loss(fu_prev) - get_loss(fu)) /
(step_size' * g + step_size' * H * step_size / 2)
end

function retrospective_step!(cache::TrustRegionCache{false})
@unpack J, fu_prev, fu, u_prev, u, f = cache
J = jacobian(cache, f)
cache.H = J * J
cache.g = J * fu
cache.stats.njacs += 1
@unpack H, g, step_size = cache

return -(get_loss(fu_prev) - get_loss(fu)) /
(step_size' * g + step_size' * H * step_size / 2)
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, step_size, g, H, loss, max_trust_r, radius_update_scheme = cache
cache.loss_new = get_loss(fu_new)
Expand Down Expand Up @@ -495,6 +534,23 @@ function trust_region_step!(cache::TrustRegionCache)
cache.internalnorm(g) < cache.ϵ
cache.force_stop = true
end
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
if retrospective_step!(cache) >= cache.expand_threshold
cache.trust_r = max(cache.p1 * cache.internalnorm(step_size), cache.trust_r)
end

else
cache.make_new_J = false
cache.trust_r *= cache.p2
cache.shrink_counter += 1
end
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end
end
end

Expand Down Expand Up @@ -526,12 +582,16 @@ function dogleg!(cache::TrustRegionCache)
end

function take_step!(cache::TrustRegionCache{true})
cache.u_prev .= cache.u
cache.u .= cache.u_tmp
cache.fu_prev .= cache.fu
cache.fu .= cache.fu_new
end

function take_step!(cache::TrustRegionCache{false})
cache.u_prev = cache.u
cache.u = cache.u_tmp
cache.fu_prev = cache.fu
cache.fu = cache.fu_new
end

Expand Down
58 changes: 56 additions & 2 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ end

u0 = [1.0, 1.0]
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei,
RadiusUpdateSchemes.Yuan,
RadiusUpdateSchemes.Fan]
RadiusUpdateSchemes.Yuan, RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]

for radius_update_scheme in radius_update_schemes
sol = benchmark_immutable(ff, cu0, radius_update_scheme)
Expand Down Expand Up @@ -286,6 +285,18 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
end

g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin),
abstol = 1e-9)
return sol.u[end]
end

for p in 1.1:0.1:100.0
@test g(p) sqrt(p)
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
end

# Scalar
f, u0 = (u, p) -> u * u - p, 1.0

Expand Down Expand Up @@ -344,6 +355,20 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
end

g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin),
abstol = 1e-10)
return sol.u
end

@test ForwardDiff.derivative(g, 3.0) 1 / (2 * sqrt(3.0))

for p in 1.1:0.1:100.0
@test g(p) sqrt(p)
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
end

f = (u, p) -> p[1] * u * u - p[2]
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
Expand Down Expand Up @@ -379,6 +404,14 @@ end
@test gnewton(p) [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)

gnewton = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin))
return [sol.u]
end
@test gnewton(p) [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)

# Iterator interface
f = (u, p) -> u * u - p
g = function (p_range)
Expand Down Expand Up @@ -432,6 +465,11 @@ probN = NonlinearProblem(f, u0)
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff = false)).u[end]
sqrt(2.0)

@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin)).u[end]
sqrt(2.0)
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff = false)).u[end]
sqrt(2.0)

for u0 in [1.0, [1, 1.0]]
local f, probN, sol
f = (u, p) -> u .* u .- 2.0
Expand Down Expand Up @@ -475,6 +513,17 @@ u = g(p)
f(u, p)
@test all(abs.(f(u, p)) .< 1e-10)

g = function (p)
probN = NonlinearProblem{false}(f, u0, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin),
abstol = 1e-10)
return sol.u
end
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
u = g(p)
f(u, p)
@test all(abs.(f(u, p)) .< 1e-10)

# Test kwars in `TrustRegion`
max_trust_radius = [10.0, 100.0, 1000.0]
initial_trust_radius = [10.0, 1.0, 0.1]
Expand Down Expand Up @@ -542,6 +591,11 @@ for maxiters in maxiterations
@test iip == oop
end

for maxiters in maxiterations
iip, oop = iip_oop(ff, ffiip, u0, RadiusUpdateSchemes.Bastin, maxiters)
@test iip == oop
end

# --- LevenbergMarquardt tests ---

function benchmark_immutable(f, u0)
Expand Down
40 changes: 40 additions & 0 deletions test/convergencetests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using NonlinearSolve
using StaticArrays
using BenchmarkTools
using Test

using SciMLNLSolve

###-----Trust Region tests-----###

# some simple functions #
function f_oop(u, p)
u .* u .- p
end

function f_iip(du, u, p)
du .= u .* u .- p
end

function f_scalar(u, p)
u * u - p
end

u0 = [1.0, 1.0]
csu0 = 1.0
p = [2.0, 2.0]
radius_update_scheme = RadiusUpdateSchemes.Simple
tol = 1e-9

function convergence_test_oop(f, u0, p, radius_update_scheme)
prob = NonlinearProblem{false}(f, oftype(p, u0), p)
cache = init(prob,
TrustRegion(radius_update_scheme = radius_update_scheme),
abstol = 1e-9)
sol = solve!(cache)
return cache.internalnorm(cache.u_prev - cache.u), cache.iter, sol.retcode
end

residual, iterations, return_code = convergence_test_oop(f_oop, u0, p, radius_update_scheme)
@test return_code === ReturnCode.Success
@test residual tol

0 comments on commit 0abdc34

Please sign in to comment.