Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Including jvp for radius update schemes #172

Merged
merged 12 commits into from
Apr 3, 2023
2 changes: 2 additions & 0 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,5 @@ function jacobian_autodiff(f, x::AbstractArray, nonlinfun, alg)
jac_prototype = jac_prototype, chunksize = chunk_size),
num_of_chunks)
end


41 changes: 31 additions & 10 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ EnumX.@enumx RadiusUpdateSchemes begin
Hei
Yuan
Bastin
Fan
end

struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
Expand Down Expand Up @@ -234,7 +235,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
args...;
alias_u0 = false,
maxiters = 1000,
abstol = 1e-6,
abstol = 1e-8,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
if alias_u0
Expand Down Expand Up @@ -301,7 +302,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
p2 = convert(eltype(u), 1/6) # c5
p3 = convert(eltype(u), 6.0) # c6
p4 = convert(eltype(u), 0.0)
end
end

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
Expand Down Expand Up @@ -402,10 +403,12 @@ function trust_region_step!(cache::TrustRegionCache)
@unpack shrink_threshold, p1, p2, p3, p4 = cache
if rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) < cache.trust_r
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) # parameters to be defined
cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size)

if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ
cache.force_stop = true
end

Expand All @@ -416,19 +419,20 @@ function trust_region_step!(cache::TrustRegionCache)
cache.shrink_counter += 1
elseif r >= cache.expand_threshold && cache.internalnorm(step_size) > cache.trust_r / 2
cache.p1 = cache.p3 * cache.p1
cache.shrink_counter = 0
end
@unpack p1, fu, f, J = cache
#cache.trust_r = p1 * cache.internalnorm(jacobian!(J, cache) * fu) # we need the gradient at the new (k+1)th point WILL THIS BECOME ALLOCATING?


if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
else
cache.make_new_J = false
end

if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined

@unpack p1= cache
cache.trust_r = p1 * cache.internalnorm(jvp!(cache))
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ
cache.force_stop = true
end

Expand All @@ -441,7 +445,7 @@ function dogleg!(cache::TrustRegionCache)

# Test if the full step is within the trust region.
if norm(u_tmp) ≤ trust_r
cache.step_size = u_tmp
cache.step_size = deepcopy(u_tmp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supposed to be a .=?

return
end

Expand Down Expand Up @@ -473,6 +477,23 @@ function take_step!(cache::TrustRegionCache{false})
cache.fu = cache.fu_new
end

"""
    jvp!(cache::TrustRegionCache{false})

Out-of-place variant used by the trust-region radius update.

Builds the one-argument residual `x -> f(x, p)` from the cached `f` and `p`. When the
state `u` is a `Number`, the result of `value_derivative` on that residual at `u` is
returned as-is. Otherwise the residual, `u`, and the cached residual value `fu` are
handed to `auto_jacvec`.

NOTE(review): presumably `auto_jacvec` yields the Jacobian-vector product `J(u) * fu`,
and `value_derivative` yields the residual value together with its derivative — neither
helper is visible in this excerpt, so confirm against their definitions.
"""
function jvp!(cache::TrustRegionCache{false})
    @unpack f, u, fu, p = cache
    # Close over the parameters once; both branches differentiate the same residual.
    residual = x -> f(x, p)
    return u isa Number ? value_derivative(residual, u) : auto_jacvec(residual, u, fu)
end

"""
    jvp!(cache::TrustRegionCache{true})

In-place variant used by the trust-region radius update.

When the state `u` is a `Number`, the result of `value_derivative` on `x -> f(x, p)` at
`u` is returned. Otherwise `auto_jacvec!` is called with the cached buffer `g` as its
destination, the in-place residual `(du, x) -> f(du, x, p)`, the state `u`, and the
cached residual value `fu` as the seed vector; whatever `auto_jacvec!` returns is
returned to the caller.

NOTE(review): presumably `auto_jacvec!` writes `J(u) * fu` into `g` and returns that
buffer (or a vectorized view of it) — the helper is not visible in this excerpt, so
confirm against its definition.
"""
function jvp!(cache::TrustRegionCache{true})
    @unpack g, f, u, fu, p = cache
    if isa(u, Number)
        return value_derivative(x -> f(x, p), u)
    end
    # The closure's output buffer is named `du` so it does not shadow the cached
    # residual `fu`, which is passed separately as the seed vector.
    # The previous trailing bare `g` after this `return` was unreachable and is removed.
    return auto_jacvec!(g, (du, x) -> f(du, x, p), u, fu)
end

function SciMLBase.solve!(cache::TrustRegionCache)
while !cache.force_stop && cache.iter < cache.maxiters &&
cache.shrink_counter < cache.alg.max_shrink_times
Expand Down
127 changes: 99 additions & 28 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,61 +163,69 @@ end

# --- TrustRegion tests ---

function benchmark_immutable(f, u0)
function benchmark_immutable(f, u0, radius_update_scheme)
probN = NonlinearProblem{false}(f, u0)
solver = init(probN, TrustRegion(), abstol = 1e-9)
solver = init(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9)
sol = solve!(solver)
end

function benchmark_mutable(f, u0)
function benchmark_mutable(f, u0, radius_update_scheme)
probN = NonlinearProblem{false}(f, u0)
solver = init(probN, TrustRegion(), abstol = 1e-9)
solver = init(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9)
sol = solve!(solver)
end

function benchmark_scalar(f, u0)
function benchmark_scalar(f, u0, radius_update_scheme)
probN = NonlinearProblem{false}(f, u0)
sol = (solve(probN, TrustRegion(), abstol = 1e-9))
sol = (solve(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9))
end

function ff(u, p)
function ff(u, p=nothing)
u .* u .- 2
end

function sf(u, p)
function sf(u, p=nothing)
u * u - 2
end

u0 = [1.0, 1.0]
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan]

sol = benchmark_immutable(ff, cu0)
@test sol.retcode === ReturnCode.Success
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_mutable(ff, u0)
@test sol.retcode === ReturnCode.Success
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Success
@test abs(sol.u * sol.u - 2) < 1e-9
for radius_update_scheme in radius_update_schemes
sol = benchmark_immutable(ff, cu0, radius_update_scheme)
@test sol.retcode === ReturnCode.Success
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_mutable(ff, u0, radius_update_scheme)
@test sol.retcode === ReturnCode.Success
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_scalar(sf, csu0, radius_update_scheme)
@test sol.retcode === ReturnCode.Success
@test abs(sol.u * sol.u - 2) < 1e-9
end

function benchmark_inplace(f, u0)

function benchmark_inplace(f, u0, radius_update_scheme)
probN = NonlinearProblem{true}(f, u0)
solver = init(probN, TrustRegion(), abstol = 1e-9)
solver = init(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9)
sol = solve!(solver)
end

function ffiip(du, u, p)
function ffiip(du, u, p=nothing)
du .= u .* u .- 2
end
u0 = [1.0, 1.0]

sol = benchmark_inplace(ffiip, u0)
@test sol.retcode === ReturnCode.Success
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
for radius_update_scheme in radius_update_schemes
sol = benchmark_inplace(ffiip, u0, radius_update_scheme)
@test sol.retcode === ReturnCode.Success
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
end

u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(ffiip, u0)
solver = init(probN, TrustRegion(), abstol = 1e-9)
@test (@ballocated solve!(solver)) < 200
for radius_update_scheme in radius_update_schemes
probN = NonlinearProblem{true}(ffiip, u0)
solver = init(probN, TrustRegion(radius_update_scheme = radius_update_scheme), abstol = 1e-9)
@test (@ballocated solve!(solver)) < 200
end

# AD Tests
using ForwardDiff
Expand All @@ -236,6 +244,29 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# AD check for TrustRegion with the `Hei` radius update scheme.
# `g(p)` solves the out-of-place problem built from `f`, `csu0`, and parameter `p`
# (`f` and `csu0` are defined outside this excerpt) and returns the last component
# of the solution.
g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Hei), abstol = 1e-9)
return sol.u[end]
end

# For p from 1.1 to 100.0 in steps of 0.1: the solution must match sqrt(p), and the
# ForwardDiff derivative of the solve with respect to p must match 1 / (2 * sqrt(p)).
for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# AD check for TrustRegion with the `Yuan` radius update scheme.
# NOTE(review): this case was originally marked as failing because the JVP could not
# accept functions that take parameters — confirm whether that still applies.
# `g(p)` solves the out-of-place problem built from `f`, `csu0`, and parameter `p`
# (`f` and `csu0` are defined outside this excerpt) and returns the last component
# of the solution.
g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan), abstol = 1e-9)
return sol.u[end]
end

# For p from 1.1 to 100.0 in steps of 0.1: the solution must match sqrt(p), and the
# ForwardDiff derivative of the solve with respect to p must match 1 / (2 * sqrt(p)).
for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Scalar
f, u0 = (u, p) -> u * u - p, 1.0

Expand All @@ -252,6 +283,32 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Scalar AD check for TrustRegion with the `Hei` radius update scheme.
# `g(p)` solves the scalar out-of-place problem for `f` (presumably the scalar residual
# `u * u - p` defined earlier in the file — confirm). The initial guess `u0` is
# converted to the type of `p` via `oftype` before being passed to the problem.
g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Hei), abstol = 1e-10)
return sol.u
end

# Single-point derivative check at p = 3.0.
@test ForwardDiff.derivative(g, 3.0) ≈ 1 / (2 * sqrt(3.0))

# Sweep p from 1.1 to 100.0 in steps of 0.1, checking both the root sqrt(p) and the
# derivative 1 / (2 * sqrt(p)).
for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Scalar AD check for TrustRegion with the `Yuan` radius update scheme.
# `g(p)` solves the scalar out-of-place problem for `f` (presumably the scalar residual
# `u * u - p` defined earlier in the file — confirm). The initial guess `u0` is
# converted to the type of `p` via `oftype` before being passed to the problem.
g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan), abstol = 1e-10)
return sol.u
end

# Single-point derivative check at p = 3.0.
@test ForwardDiff.derivative(g, 3.0) ≈ 1 / (2 * sqrt(3.0))

# Sweep p from 1.1 to 100.0 in steps of 0.1, checking both the root sqrt(p) and the
# derivative 1 / (2 * sqrt(p)).
for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

f = (u, p) -> p[1] * u * u - p[2]
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
Expand All @@ -263,6 +320,14 @@ end
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

# Jacobian check for TrustRegion with the `Hei` radius update scheme and a vector of
# parameters. `gnewton(p)` solves the scalar problem for `f` from the initial guess 0.5
# and wraps the scalar solution in a one-element vector so that `ForwardDiff.jacobian`
# can be applied. `f`, `t`, and `p` come from earlier in the file (presumably
# `p[1] * u * u - p[2]` and its closed-form root — confirm).
gnewton = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Hei))
return [sol.u]
end
# The solve must reproduce sqrt(p[2] / p[1]), and its Jacobian with respect to p must
# match the Jacobian of the reference function `t`.
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

# Iterator interface
f = (u, p) -> u * u - p
g = function (p_range)
Expand Down Expand Up @@ -295,12 +360,18 @@ p = range(0.01, 2, length = 200)
@test g(p) ≈ sqrt.(p)

# Error Checks
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
f, u0 = (u, p) -> u .* u .- 2, @SVector[1.0, 1.0]
probN = NonlinearProblem(f, u0)

@test solve(probN, TrustRegion()).u[end] ≈ sqrt(2.0)
@test solve(probN, TrustRegion(; autodiff = false)).u[end] ≈ sqrt(2.0)

# Same check as for the default TrustRegion, repeated for the `Hei` scheme, with the
# default settings and with `autodiff = false`: the last solution component must
# match sqrt(2.0).
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Hei)).u[end] ≈ sqrt(2.0)
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Hei, autodiff = false)).u[end] ≈ sqrt(2.0)

# And for the `Yuan` scheme, again with the default settings and with `autodiff = false`.
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)).u[end] ≈ sqrt(2.0)
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Yuan, autodiff = false)).u[end] ≈ sqrt(2.0)

for u0 in [1.0, [1, 1.0]]
local f, probN, sol
f = (u, p) -> u .* u .- 2.0
Expand Down