diff --git a/src/trustRegion.jl b/src/trustRegion.jl index ef3446ebe..2433a93dd 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -304,11 +304,15 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, p3 = convert(eltype(u), 6.0) # c6 p4 = convert(eltype(u), 0.0) if iip - J = ForwardDiff.jacobian(f, fu, u) + auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu) else - J = ForwardDiff.jacobian(f, u) + if isa(u, Number) + g = ForwardDiff.derivative(x -> f(x, p), u) + else + g = auto_jacvec(x -> f(x, p), u, fu) + end end - initial_trust_radius = convert(eltype(u), p1 * norm(J * fu)) + initial_trust_radius = convert(eltype(u), p1 * norm(g)) elseif radius_update_scheme === RadiusUpdateSchemes.Fan step_threshold = convert(eltype(u), 0.0001) shrink_threshold = convert(eltype(u), 0.25) @@ -527,7 +531,7 @@ function jvp!(cache::TrustRegionCache{true}) if isa(u, Number) return value_derivative(x -> f(x, p), u) end - return auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu) + auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu) g end diff --git a/test/basictests.jl b/test/basictests.jl index a7c0f9e92..2d338ec3b 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -352,6 +352,22 @@ end @test gnewton(p) ≈ [sqrt(p[2] / p[1])] @test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) +gnewton = function (p) + probN = NonlinearProblem{false}(f, 0.5, p) + sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)) + return [sol.u] +end +@test gnewton(p) ≈ [sqrt(p[2] / p[1])] +@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) + +gnewton = function (p) + probN = NonlinearProblem{false}(f, 0.5, p) + sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan)) + return [sol.u] +end +@test gnewton(p) ≈ [sqrt(p[2] / p[1])] +@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) + # Iterator interface f = (u, p) -> u * u - p g = function (p_range) @@ -396,6 +412,9 @@ probN = NonlinearProblem(f, u0) @test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)).u[end] ≈ sqrt(2.0) @test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Yuan, autodiff = false)).u[end] ≈ sqrt(2.0) +@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan)).u[end] ≈ sqrt(2.0) +@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff = false)).u[end] ≈ sqrt(2.0) + for u0 in [1.0, [1, 1.0]] local f, probN, sol f = (u, p) -> u .* u .- 2.0 @@ -428,6 +447,16 @@ u = g(p) f(u, p) @test all(abs.(f(u, p)) .< 1e-10) +g = function (p) + probN = NonlinearProblem{false}(f, u0, p) + sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10) + return sol.u +end +p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +u = g(p) +f(u, p) +@test all(abs.(f(u, p)) .< 1e-10) + # Test kwars in `TrustRegion` max_trust_radius = [10.0, 100.0, 1000.0] initial_trust_radius = [10.0, 1.0, 0.1]