Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trust Region - Fan's method #178

Merged
merged 4 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
p2 = convert(eltype(u), 0.1) # β
p3 = convert(eltype(u), 0.15) # γ1
p4 = convert(eltype(u), 0.15) # γ2
initial_trust_radius = convert(eltype(u), 1.0)
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
step_threshold = convert(eltype(u), 0.0001)
shrink_threshold = convert(eltype(u), 0.25)
Expand All @@ -302,6 +303,25 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
p2 = convert(eltype(u), 1/6) # c5
p3 = convert(eltype(u), 6.0) # c6
p4 = convert(eltype(u), 0.0)
if iip
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
else
if isa(u, Number)
g = ForwardDiff.derivative(x -> f(x, p), u)
else
g = auto_jacvec(x -> f(x, p), u, fu)
end
end
initial_trust_radius = convert(eltype(u), p1 * norm(g))
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
step_threshold = convert(eltype(u), 0.0001)
shrink_threshold = convert(eltype(u), 0.25)
expand_threshold = convert(eltype(u), 0.75)
p1 = convert(eltype(u), 0.1) # μ
p2 = convert(eltype(u), 1/4) # c5
p3 = convert(eltype(u), 12) # c6
p4 = convert(eltype(u), 1.0e18) # M
initial_trust_radius = convert(eltype(u), p1 * (norm(fu)^0.99))
end

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
Expand Down Expand Up @@ -435,8 +455,29 @@ function trust_region_step!(cache::TrustRegionCache)
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ
cache.force_stop = true
end
#Fan's update scheme
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
if r < cache.shrink_threshold
cache.p1 *= cache.p2
cache.shrink_counter += 1
elseif r > cache.expand_threshold
cache.p1 = min(cache.p1*cache.p3, cache.p4)
cache.shrink_counter = 0
end

#elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new
cache.make_new_J = true
else
cache.make_new_J = false
end

@unpack p1 = cache
cache.trust_r = p1 * (cache.internalnorm(cache.fu)^0.99)
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ
cache.force_stop = true
end
end
end

Expand Down Expand Up @@ -490,7 +531,7 @@ function jvp!(cache::TrustRegionCache{true})
if isa(u, Number)
return value_derivative(x -> f(x, p), u)
end
return auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
g
end

Expand Down
57 changes: 55 additions & 2 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ function sf(u, p=nothing)
end

u0 = [1.0, 1.0]
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan]
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan,
RadiusUpdateSchemes.Fan]

for radius_update_scheme in radius_update_schemes
sol = benchmark_immutable(ff, cu0, radius_update_scheme)
Expand Down Expand Up @@ -255,7 +256,6 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

## FAIL BECAUSE JVP CANNOT ACCEPT PARAMETERS IN FUNCTIONS
g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan), abstol = 1e-9)
Expand All @@ -267,6 +267,17 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-9)
return sol.u[end]
end

for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Scalar
f, u0 = (u, p) -> u * u - p, 1.0

Expand Down Expand Up @@ -309,6 +320,19 @@ for p in 1.1:0.1:100.0
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10)
return sol.u
end

@test ForwardDiff.derivative(g, 3.0) ≈ 1 / (2 * sqrt(3.0))

for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

f = (u, p) -> p[1] * u * u - p[2]
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
Expand All @@ -328,6 +352,22 @@ end
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

gnewton = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan))
return [sol.u]
end
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

gnewton = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan))
return [sol.u]
end
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

# Iterator interface
f = (u, p) -> u * u - p
g = function (p_range)
Expand Down Expand Up @@ -372,6 +412,9 @@ probN = NonlinearProblem(f, u0)
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)).u[end] ≈ sqrt(2.0)
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Yuan, autodiff = false)).u[end] ≈ sqrt(2.0)

@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan)).u[end] ≈ sqrt(2.0)
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff = false)).u[end] ≈ sqrt(2.0)

for u0 in [1.0, [1, 1.0]]
local f, probN, sol
f = (u, p) -> u .* u .- 2.0
Expand Down Expand Up @@ -404,6 +447,16 @@ u = g(p)
f(u, p)
@test all(abs.(f(u, p)) .< 1e-10)

g = function (p)
probN = NonlinearProblem{false}(f, u0, p)
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10)
return sol.u
end
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
u = g(p)
f(u, p)
@test all(abs.(f(u, p)) .< 1e-10)

# Test kwars in `TrustRegion`
max_trust_radius = [10.0, 100.0, 1000.0]
initial_trust_radius = [10.0, 1.0, 0.1]
Expand Down