diff --git a/src/raphson.jl b/src/raphson.jl index bcbae54d5..bfec2830a 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -198,3 +198,22 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache) SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu; retcode = cache.retcode) end + +function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u0; p = cache.p, + abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + cache.p = p + if iip + recursivecopy!(cache.u, u0) + cache.f(cache.fu, cache.u, p) + else + # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter + cache.u = u0 + cache.fu = cache.f(cache.u, p) + end + cache.abstol = abstol + cache.maxiters = maxiters + cache.iter = 1 + cache.force_stop = false + cache.retcode = ReturnCode.Default + return cache +end diff --git a/test/basictests.jl b/test/basictests.jl index 947975959..ebaa80916 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -111,6 +111,37 @@ end @test gnewton(p) ≈ [sqrt(p[2] / p[1])] @test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) +# Iterator interface +f = (u, p) -> u * u - p +g = function (p_range) + probN = NonlinearProblem{false}(f, 0.5, p_range[begin]) + cache = init(probN, NewtonRaphson(); maxiters = 100, abstol=1e-10) + sols = zeros(length(p_range)) + for (i, p) in enumerate(p_range) + reinit!(cache, cache.u; p = p) + sol = solve!(cache) + sols[i] = sol.u + end + return sols +end +p = range(0.01, 2, length = 200) +@test g(p) ≈ sqrt.(p) + +f = (res, u, p) -> (res[begin] = u[1] * u[1] - p) +g = function (p_range) + probN = NonlinearProblem{true}(f, [0.5], p_range[begin]) + cache = init(probN, NewtonRaphson(); maxiters = 100, abstol=1e-10) + sols = zeros(length(p_range)) + for (i, p) in enumerate(p_range) + reinit!(cache, [cache.u[1]]; p = p) + sol = solve!(cache) + sols[i] = sol.u[1] + end + return sols +end +p = range(0.01, 2, length = 200) +@test g(p) ≈ sqrt.(p) + # Error Checks f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]