diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 3e157f11a..2433a93dd 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -294,6 +294,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, p2 = convert(eltype(u), 0.1) # β p3 = convert(eltype(u), 0.15) # γ1 p4 = convert(eltype(u), 0.15) # γ2 + initial_trust_radius = convert(eltype(u), 1.0) elseif radius_update_scheme === RadiusUpdateSchemes.Yuan step_threshold = convert(eltype(u), 0.0001) shrink_threshold = convert(eltype(u), 0.25) @@ -302,6 +303,25 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, p2 = convert(eltype(u), 1/6) # c5 p3 = convert(eltype(u), 6.0) # c6 p4 = convert(eltype(u), 0.0) + if iip + auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu) + else + if isa(u, Number) + g = ForwardDiff.derivative(x -> f(x, p), u) + else + g = auto_jacvec(x -> f(x, p), u, fu) + end + end + initial_trust_radius = convert(eltype(u), p1 * norm(g)) + elseif radius_update_scheme === RadiusUpdateSchemes.Fan + step_threshold = convert(eltype(u), 0.0001) + shrink_threshold = convert(eltype(u), 0.25) + expand_threshold = convert(eltype(u), 0.75) + p1 = convert(eltype(u), 0.1) # μ + p2 = convert(eltype(u), 1/4) # c5 + p3 = convert(eltype(u), 12) # c6 + p4 = convert(eltype(u), 1.0e18) # M + initial_trust_radius = convert(eltype(u), p1 * (norm(fu)^0.99)) end return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config, @@ -435,8 +455,29 @@ function trust_region_step!(cache::TrustRegionCache) if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ cache.force_stop = true end + #Fan's update scheme + elseif radius_update_scheme === RadiusUpdateSchemes.Fan + if r < cache.shrink_threshold + cache.p1 *= cache.p2 + cache.shrink_counter += 1 + elseif r > cache.expand_threshold + cache.p1 = min(cache.p1*cache.p3, cache.p4) + cache.shrink_counter = 0 + end - #elseif radius_update_scheme === RadiusUpdateSchemes.Bastin + if r > cache.step_threshold + take_step!(cache) + cache.loss = cache.loss_new + cache.make_new_J = true + else + cache.make_new_J = false + end + + @unpack p1 = cache + cache.trust_r = p1 * (cache.internalnorm(cache.fu)^0.99) + if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ + cache.force_stop = true + end end end @@ -490,7 +531,7 @@ function jvp!(cache::TrustRegionCache{true}) if isa(u, Number) return value_derivative(x -> f(x, p), u) end - return auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu) + auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu) g end diff --git a/test/basictests.jl b/test/basictests.jl index 292f14611..2d338ec3b 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -189,7 +189,8 @@ function sf(u, p=nothing) end u0 = [1.0, 1.0] -radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan] +radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan, + RadiusUpdateSchemes.Fan] for radius_update_scheme in radius_update_schemes sol = benchmark_immutable(ff, cu0, radius_update_scheme) @@ -255,7 +256,6 @@ for p in 1.1:0.1:100.0 @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) end -## FAIL BECAUSE JVP CANNOT ACCEPT PARAMETERS IN FUNCTIONS g = function (p) probN = NonlinearProblem{false}(f, csu0, p) sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan), abstol = 1e-9) @@ -267,6 +267,17 @@ for p in 1.1:0.1:100.0 @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) end +g = function (p) + probN = NonlinearProblem{false}(f, csu0, p) + sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-9) + return sol.u[end] +end + +for p in 1.1:0.1:100.0 + @test g(p) ≈ sqrt(p) + @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) +end + # Scalar f, u0 = (u, p) -> u * u - p, 1.0 @@ -309,6 +320,19 @@ for p in 1.1:0.1:100.0 @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) end +g = function (p) + probN = NonlinearProblem{false}(f, oftype(p, u0), p) + sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10) + return sol.u +end + +@test ForwardDiff.derivative(g, 3.0) ≈ 1 / (2 * sqrt(3.0)) + +for p in 1.1:0.1:100.0 + @test g(p) ≈ sqrt(p) + @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) +end + f = (u, p) -> p[1] * u * u - p[2] t = (p) -> [sqrt(p[2] / p[1])] p = [0.9, 50.0] @@ -328,6 +352,22 @@ end @test gnewton(p) ≈ [sqrt(p[2] / p[1])] @test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) +gnewton = function (p) + probN = NonlinearProblem{false}(f, 0.5, p) + sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)) + return [sol.u] +end +@test gnewton(p) ≈ [sqrt(p[2] / p[1])] +@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) + +gnewton = function (p) + probN = NonlinearProblem{false}(f, 0.5, p) + sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan)) + return [sol.u] +end +@test gnewton(p) ≈ [sqrt(p[2] / p[1])] +@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) + # Iterator interface f = (u, p) -> u * u - p g = function (p_range) @@ -372,6 +412,9 @@ probN = NonlinearProblem(f, u0) @test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)).u[end] ≈ sqrt(2.0) @test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Yuan, autodiff = false)).u[end] ≈ sqrt(2.0) +@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan)).u[end] ≈ sqrt(2.0) +@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff = false)).u[end] ≈ sqrt(2.0) + for u0 in [1.0, [1, 1.0]] local f, probN, sol f = (u, p) -> u .* u .- 2.0 @@ -404,6 +447,16 @@ u = g(p) f(u, p) @test all(abs.(f(u, p)) .< 1e-10) +g = function (p) + probN = NonlinearProblem{false}(f, u0, p) + sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10) + return sol.u +end +p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +u = g(p) +f(u, p) +@test all(abs.(f(u, p)) .< 1e-10) + # Test kwars in `TrustRegion` max_trust_radius = [10.0, 100.0, 1000.0] initial_trust_radius = [10.0, 1.0, 0.1]