From 0ccb0ea1e8aed5541e09593d438bce1f83bfcd6b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 1 Nov 2024 15:14:59 -0400 Subject: [PATCH] fix: reinit! on forwarddiff cache --- src/forward_diff.jl | 3 +-- test/forward_ad_tests.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/forward_diff.jl b/src/forward_diff.jl index d34bca877..5bb98561c 100644 --- a/src/forward_diff.jl +++ b/src/forward_diff.jl @@ -37,10 +37,9 @@ function InternalAPI.reinit!( cache::NonlinearSolveForwardDiffCache, args...; p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs... ) - inner_cache = InternalAPI.reinit!( + InternalAPI.reinit!( cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs... ) - cache.cache = inner_cache cache.p = p cache.values_p = nodual_value(p) cache.partials_p = ForwardDiff.partials(p) diff --git a/test/forward_ad_tests.jl b/test/forward_ad_tests.jl index 942d66c9b..f3cf74bae 100644 --- a/test/forward_ad_tests.jl +++ b/test/forward_ad_tests.jl @@ -218,3 +218,34 @@ end @test hess1≈hess2 atol=1e-3 end + +@testitem "reinit! on ForwardDiff cache SciML/NonlinearSolve.jl#391" tags=[:core] begin + using ForwardDiff + + function multiple_solves(ps::Vector) + res = similar(ps, 4, length(ps)) + for (i, p) in enumerate(ps) + prob = NonlinearProblem{false}((u, p) -> u .* u .- p, rand(4), ps[i]) + sol = solve(prob) + res[:, i] .= sol.u + end + return sum(abs2, res) + end + + function multiple_solves_cached(ps::Vector) + res = similar(ps, 4, length(ps)) + prob = NonlinearProblem{false}((u, p) -> u .* u .- p, rand(4), ps[1]) + cache = init(prob, NewtonRaphson()) + for (i, p) in enumerate(ps) + reinit!(cache; p) + sol = solve!(cache) + res[:, i] .= sol.u + end + return sum(abs2, res) + end + + ps = collect(1.0:5.0) + + @test ForwardDiff.gradient(multiple_solves, ps) ≈ + ForwardDiff.gradient(multiple_solves_cached, ps) +end