diff --git a/src/raphson.jl b/src/raphson.jl index bfec2830a..08a156190 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -199,7 +199,7 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache) retcode = cache.retcode) end -function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u0; p = cache.p, +function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p, abstol = cache.abstol, maxiters = cache.maxiters) where {iip} cache.p = p if iip diff --git a/src/trustRegion.jl b/src/trustRegion.jl index fc1662fc0..b84776dc9 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -489,3 +489,29 @@ function SciMLBase.solve!(cache::TrustRegionCache) SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu; retcode = cache.retcode) end + +function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p, + abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + cache.p = p + if iip + recursivecopy!(cache.u, u0) + cache.f(cache.fu, cache.u, p) + else + # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter + cache.u = u0 + cache.fu = cache.f(cache.u, p) + end + cache.abstol = abstol + cache.maxiters = maxiters + cache.iter = 1 + cache.force_stop = false + cache.retcode = ReturnCode.Default + cache.make_new_J = true + cache.loss = get_loss(cache.fu) + cache.shrink_counter = 0 + cache.trust_r = convert(eltype(cache.u), cache.alg.initial_trust_radius) + if iszero(cache.trust_r) + cache.trust_r = convert(eltype(cache.u), cache.max_trust_r / 11) + end + return cache +end diff --git a/test/basictests.jl b/test/basictests.jl index ebaa80916..b6726798e 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -263,6 +263,37 @@ end @test gnewton(p) ≈ [sqrt(p[2] / p[1])] @test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) +# Iterator interface +f = (u, p) -> u * u - p +g = function (p_range) + probN = NonlinearProblem{false}(f, 0.5, p_range[begin]) + cache = init(probN, TrustRegion(); maxiters = 100, abstol=1e-10) + sols = zeros(length(p_range)) + for (i, p) in enumerate(p_range) + reinit!(cache, cache.u; p = p) + sol = solve!(cache) + sols[i] = sol.u + end + return sols +end +p = range(0.01, 2, length = 200) +@test g(p) ≈ sqrt.(p) + +f = (res, u, p) -> (res[begin] = u[1] * u[1] - p) +g = function (p_range) + probN = NonlinearProblem{true}(f, [0.5], p_range[begin]) + cache = init(probN, TrustRegion(); maxiters = 100, abstol=1e-10) + sols = zeros(length(p_range)) + for (i, p) in enumerate(p_range) + reinit!(cache, [cache.u[1]]; p = p) + sol = solve!(cache) + sols[i] = sol.u[1] + end + return sols +end +p = range(0.01, 2, length = 200) +@test g(p) ≈ sqrt.(p) + # Error Checks f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0] probN = NonlinearProblem(f, u0)