Skip to content

Commit

Permalink
Added reinit! for NewtonRaphson
Browse files Browse the repository at this point in the history
  • Loading branch information
dawbarton committed Mar 28, 2023
1 parent 85b94ac commit cf1486f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,22 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache)
SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu;
retcode = cache.retcode)
end

function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u0; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
cache.abstol = abstol
cache.maxiters = maxiters
cache.iter = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end
31 changes: 31 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,37 @@ end
@test gnewton(p) [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)

# Iterator interface
f = (u, p) -> u * u - p
g = function (p_range)
probN = NonlinearProblem{false}(f, 0.5, p_range[begin])
cache = init(probN, NewtonRaphson(); maxiters = 100, abstol=1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, cache.u; p = p)
sol = solve!(cache)
sols[i] = sol.u
end
return sols
end
p = range(0.01, 2, length = 200)
@test g(p) sqrt.(p)

f = (res, u, p) -> (res[begin] = u[1] * u[1] - p)
g = function (p_range)
probN = NonlinearProblem{true}(f, [0.5], p_range[begin])
cache = init(probN, NewtonRaphson(); maxiters = 100, abstol=1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, [cache.u[1]]; p = p)
sol = solve!(cache)
sols[i] = sol.u[1]
end
return sols
end
p = range(0.01, 2, length = 200)
@test g(p) sqrt.(p)

# Error Checks

f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
Expand Down

0 comments on commit cf1486f

Please sign in to comment.