Skip to content

Commit

Permalink
Add caching for solvers without init
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 20, 2024
1 parent bf072d2 commit c756d35
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.6.0"
version = "3.7.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 4 additions & 0 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ end
kwargs
end

function Base.show(io::IO, cache::LeastSquaresOptimJLCache)
print(io, "LeastSquaresOptimJLCache()")
end

function SciMLBase.reinit!(cache::LeastSquaresOptimJLCache, args...; kwargs...)
error("Reinitialization not supported for LeastSquaresOptimJL.")
end
Expand Down
21 changes: 11 additions & 10 deletions ext/NonlinearSolveNLSolversExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using ADTypes, FastClosures, NonlinearSolve, NLSolvers, SciMLBase, LinearAlgebra
using FiniteDiff, ForwardDiff

function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false,
termination_condition = nothing, kwargs...)
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0::Bool = false, termination_condition = nothing, kwargs...)
NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL)

abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0))
Expand Down Expand Up @@ -50,12 +50,13 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
prob_nlsolver = NEqProblem(prob_obj; inplace = false)
res = NLSolvers.solve(prob_nlsolver, prob.u0, alg.method, options)

retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
ReturnCode.MaxIters)
retcode = ifelse(norm(res.info.best_residual, Inf) abstol,
ReturnCode.Success, ReturnCode.MaxIters)
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)

return SciMLBase.build_solution(prob, alg, res.info.solution,
res.info.best_residual; retcode, original = res, stats)
return SciMLBase.build_solution(
prob, alg, res.info.solution, res.info.best_residual;
retcode, original = res, stats)
end

f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
Expand All @@ -73,12 +74,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;

res = NLSolvers.solve(prob_nlsolver, u0, alg.method, options)

retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
ReturnCode.MaxIters)
retcode = ifelse(
norm(res.info.best_residual, Inf) abstol, ReturnCode.Success, ReturnCode.MaxIters)
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)

return SciMLBase.build_solution(prob, alg, res.info.solution,
res.info.best_residual; retcode, original = res, stats)
return SciMLBase.build_solution(prob, alg, res.info.solution, res.info.best_residual;
retcode, original = res, stats)
end

end
1 change: 1 addition & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ include("core/generic.jl")
include("core/approximate_jacobian.jl")
include("core/generalized_first_order.jl")
include("core/spectral_methods.jl")
include("core/noinit.jl")

include("algorithms/raphson.jl")
include("algorithms/pseudo_transient.jl")
Expand Down
3 changes: 3 additions & 0 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ get_u(cache::AbstractNonlinearSolveCache) = cache.u
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu)
SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)

function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache; kwargs...)
return reinit_cache!(cache; kwargs...)
end
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache, u0; kwargs...)
return reinit_cache!(cache; u0, kwargs...)
end
Expand Down
37 changes: 37 additions & 0 deletions src/core/noinit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Some algorithms don't support creating a cache and doing `solve!`, this unfortunately
# makes it difficult to write generic code that supports caching. For the algorithms that
# don't have a `__init` function defined, we create a "Fake Cache", which just calls
# `__solve` from `solve!`
@concrete mutable struct NonlinearSolveNoInitCache{iip, timeit} <:
AbstractNonlinearSolveCache{iip, timeit}
prob
alg
args
kwargs::Any
end

function SciMLBase.reinit!(cache::NonlinearSolveNoInitCache, u0 = cache.prob.u0;
p = cache.prob.p, kwargs...)
prob = remake(cache.prob; u0, p)
cache.prob = prob
cache.kwargs = merge(cache.kwargs, kwargs)
return cache
end

function Base.show(io::IO, cache::NonlinearSolveNoInitCache)
print(io, "NonlinearSolveNoInitCache(alg = $(cache.alg))")
end

function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
alg::Union{AbstractNonlinearSolveAlgorithm,
SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm},
args...;
maxtime = nothing,
kwargs...) where {uType, iip}
return NonlinearSolveNoInitCache{iip, maxtime !== nothing}(
prob, alg, args, merge((; maxtime), kwargs))
end

function SciMLBase.solve!(cache::NonlinearSolveNoInitCache)
return solve(cache.prob, cache.alg, cache.args...; cache.kwargs...)
end
23 changes: 23 additions & 0 deletions test/misc/noinit_caching_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
@testitem "NoInit Caching" begin
using LinearAlgebra
import NLsolve, NLSolvers

solvers = [SimpleNewtonRaphson(), SimpleTrustRegion(), SimpleDFSane(), NLsolveJL(),
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking()))]

prob = NonlinearProblem((u, p) -> u .^ 2 .- p, [0.1, 0.3], 2.0)

for alg in solvers
cache = init(prob, alg)
sol = solve!(cache)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid, Inf) 1e-6

reinit!(cache; p = 5.0)
@test cache.prob.p == 5.0
sol = solve!(cache)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid, Inf) 1e-6
@test norm(sol.u .^ 2 .- 5.0, Inf) 1e-6
end
end
15 changes: 9 additions & 6 deletions test/wrappers/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ end
prob_iip = SteadyStateProblem(f_iip, u0)

for alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
sol = solve(prob_iip, alg)
@test SciMLBase.successful_retcode(sol.retcode)
@test maximum(abs, sol.resid) < 1e-6
Expand All @@ -28,7 +29,8 @@ end
prob_oop = SteadyStateProblem(f_oop, u0)

for alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
sol = solve(prob_oop, alg)
@test SciMLBase.successful_retcode(sol.retcode)
@test maximum(abs, sol.resid) < 1e-6
Expand All @@ -45,7 +47,8 @@ end
prob_iip = NonlinearProblem{true}(f_iip, u0)

for alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
local sol
sol = solve(prob_iip, alg)
@test SciMLBase.successful_retcode(sol.retcode)
Expand All @@ -57,7 +60,8 @@ end
u0 = zeros(2)
prob_oop = NonlinearProblem{false}(f_oop, u0)
for alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
local sol
sol = solve(prob_oop, alg)
@test SciMLBase.successful_retcode(sol.retcode)
Expand All @@ -70,8 +74,7 @@ end
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15],
alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(),
CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
SIAMFANLEquationsJL(; method = :pseudotransient),
SIAMFANLEquationsJL(; method = :secant)]

Expand Down

0 comments on commit c756d35

Please sign in to comment.