Skip to content

Commit

Permalink
fix: reinit! on forwarddiff cache
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 1, 2024
1 parent 5c722c0 commit 3ef3ace
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
3 changes: 1 addition & 2 deletions src/forward_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ function InternalAPI.reinit!(
cache::NonlinearSolveForwardDiffCache, args...;
p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
)
inner_cache = InternalAPI.reinit!(
InternalAPI.reinit!(
cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs...
)
cache.cache = inner_cache
cache.p = p
cache.values_p = nodual_value(p)
cache.partials_p = ForwardDiff.partials(p)
Expand Down
31 changes: 31 additions & 0 deletions test/forward_ad_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,34 @@ end

@test hess1hess2 atol=1e-3
end

@testitem "reinit! on ForwardDiff cache SciML/NonlinearSolve.jl#391" tags=[:core] begin
using ForwardDiff

function multiple_solves(ps::Vector)
res = similar(ps, 4, length(ps))
for (i, p) in enumerate(ps)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, rand(4), ps[i])
sol = solve(prob)
res[:, i] .= sol.u
end
return sum(abs2, res)
end

function multiple_solves_cached(ps::Vector)
res = similar(ps, 4, length(ps))
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, rand(4), ps[1])
cache = init(prob, NewtonRaphson())
for (i, p) in enumerate(ps)
reinit!(cache; p)
sol = solve!(cache)
res[:, i] .= sol.u
end
return sum(abs2, res)
end

ps = collect(1.0:5.0)

@test ForwardDiff.gradient(multiple_solves, ps)
ForwardDiff.gradient(multiple_solves_cached, ps)
end

0 comments on commit 3ef3ace

Please sign in to comment.