Skip to content

Commit

Permalink
Make solve for FastShortcutNonlinearPolyalg type stable
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 17, 2023
1 parent a1cb64c commit dfe759a
Showing 1 changed file with 44 additions and 22 deletions.
66 changes: 44 additions & 22 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::RobustMultiNe
TrustRegion(; linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Fan, adkwargs...))

# Partially Type Unstable but can't do much since some upstream caches -- LineSearches
# and SparseDiffTools cause the instability
return RobustMultiNewtonCache{iip}(map(solver -> SciMLBase.__init(prob, solver, args...;
kwargs...), algs), alg, 1)
end
Expand Down Expand Up @@ -139,36 +141,56 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip},
end

# This version doesn't allocate all the caches!
function SciMLBase.__solve(prob::NonlinearProblem{uType, iip},
@generated function SciMLBase.__solve(prob::NonlinearProblem{uType, iip},
alg::FastShortcutNonlinearPolyalg, args...; kwargs...) where {uType, iip}
@unpack adkwargs, linsolve, precs = alg
calls = [:(@unpack adkwargs, linsolve, precs = alg)]

algs = [
iip ? Klement() : nothing, # Klement not yet implemented for IIP
iip ? Broyden() : nothing, # Broyden not yet implemented for IIP
NewtonRaphson(; linsolve, precs, adkwargs...),
NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...),
TrustRegion(; linsolve, precs, adkwargs...),
TrustRegion(; linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, adkwargs...),
!iip ? :(Klement()) : nothing, # Klement not yet implemented for IIP
!iip ? :(Broyden()) : nothing, # Broyden not yet implemented for IIP
:(NewtonRaphson(; linsolve, precs, adkwargs...)),
:(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)),
:(TrustRegion(; linsolve, precs, adkwargs...)),
:(TrustRegion(; linsolve, precs,
radius_update_scheme = RadiusUpdateSchemes.Bastin, adkwargs...)),
]
filter!(!isnothing, algs)

sols = Vector{SciMLBase.NonlinearSolution}(undef, length(algs))

for (i, solver) in enumerate(algs)
sols[i] = SciMLBase.__solve(prob, solver, args...; kwargs...)
if SciMLBase.successful_retcode(sols[i])
return SciMLBase.build_solution(prob, alg, sols[i].u, sols[i].resid;
sols[i].retcode, sols[i].stats, original = sols[i])
end
counter = 1
sol_syms = [gensym("sol") for i in 1:length(algs)]
for i in 1:length(algs)
cur_sol = sol_syms[i]
push!(calls,
quote
$(cur_sol) = SciMLBase.__solve(prob, $(algs[i]), args...; kwargs...)
if SciMLBase.successful_retcode($(cur_sol))
return SciMLBase.build_solution(prob, alg, $(cur_sol).u,
$(cur_sol).resid; $(cur_sol).retcode, $(cur_sol).stats,
original = $(cur_sol))
end
end)
end

resids = map(Base.Fix2(getproperty, resid), sols)
minfu, idx = findmin(DEFAULT_NORM, resids)
resids = map(x -> "$x.resid", sol_syms)

push!(calls,
quote
resids = $(Tuple(resids))
minfu, idx = findmin(DEFAULT_NORM, resids)

Check warning on line 178 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L177-L178

Added lines #L177 - L178 were not covered by tests
end)

for i in 1:length(algs)
push!(calls,
quote
if idx == $i
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,

Check warning on line 185 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L184-L185

Added lines #L184 - L185 were not covered by tests
$(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats,
original = $(sol_syms[i]))
end
end)
end
push!(calls, :(error("Current choices shouldn't get here!")))

return SciMLBase.build_solution(prob, alg, sols[idx].u, sols[idx].resid;
sols[idx].retcode, sols[idx].stats, original = sols[idx])
return Expr(:block, calls...)
end

## General shared polyalg functions
Expand Down

0 comments on commit dfe759a

Please sign in to comment.