From a67d6edad5a5c66e2647ebea5f97917c61525051 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 20 Feb 2024 11:05:18 -0500 Subject: [PATCH] Add caching for solvers without init --- Project.toml | 2 +- ext/NonlinearSolveLeastSquaresOptimExt.jl | 4 +++ ext/NonlinearSolveNLSolversExt.jl | 21 +++++++------ src/NonlinearSolve.jl | 1 + src/abstract_types.jl | 3 ++ src/core/noinit.jl | 37 +++++++++++++++++++++++ test/misc/noinit_caching_tests.jl | 23 ++++++++++++++ test/wrappers/rootfind_tests.jl | 15 +++++---- 8 files changed, 89 insertions(+), 17 deletions(-) create mode 100644 src/core/noinit.jl create mode 100644 test/misc/noinit_caching_tests.jl diff --git a/Project.toml b/Project.toml index e5afe2c2e..20dd03f68 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.6.0" +version = "3.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/NonlinearSolveLeastSquaresOptimExt.jl b/ext/NonlinearSolveLeastSquaresOptimExt.jl index 42ff7c9d3..b5c1a7426 100644 --- a/ext/NonlinearSolveLeastSquaresOptimExt.jl +++ b/ext/NonlinearSolveLeastSquaresOptimExt.jl @@ -23,6 +23,10 @@ end kwargs end +function Base.show(io::IO, cache::LeastSquaresOptimJLCache) + print(io, "LeastSquaresOptimJLCache()") +end + function SciMLBase.reinit!(cache::LeastSquaresOptimJLCache, args...; kwargs...) error("Reinitialization not supported for LeastSquaresOptimJL.") end diff --git a/ext/NonlinearSolveNLSolversExt.jl b/ext/NonlinearSolveNLSolversExt.jl index fd75095c0..b480578d0 100644 --- a/ext/NonlinearSolveNLSolversExt.jl +++ b/ext/NonlinearSolveNLSolversExt.jl @@ -4,8 +4,8 @@ using ADTypes, FastClosures, NonlinearSolve, NLSolvers, SciMLBase, LinearAlgebra using FiniteDiff, ForwardDiff function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...; - abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false, - termination_condition = nothing, kwargs...) + abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0::Bool = false, termination_condition = nothing, kwargs...) NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL) abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0)) @@ -50,12 +50,13 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...; prob_nlsolver = NEqProblem(prob_obj; inplace = false) res = NLSolvers.solve(prob_nlsolver, prob.u0, alg.method, options) - retcode = ifelse(norm(res.info.best_residual, Inf) ≤ abstol, ReturnCode.Success, - ReturnCode.MaxIters) + retcode = ifelse(norm(res.info.best_residual, Inf) ≤ abstol, + ReturnCode.Success, ReturnCode.MaxIters) stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter) - return SciMLBase.build_solution(prob, alg, res.info.solution, - res.info.best_residual; retcode, original = res, stats) + return SciMLBase.build_solution( + prob, alg, res.info.solution, res.info.best_residual; + retcode, original = res, stats) end f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0) @@ -73,12 +74,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...; res = NLSolvers.solve(prob_nlsolver, u0, alg.method, options) - retcode = ifelse(norm(res.info.best_residual, Inf) ≤ abstol, ReturnCode.Success, - ReturnCode.MaxIters) + retcode = ifelse( + norm(res.info.best_residual, Inf) ≤ abstol, ReturnCode.Success, ReturnCode.MaxIters) stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter) - return SciMLBase.build_solution(prob, alg, res.info.solution, - res.info.best_residual; retcode, original = res, stats) + return SciMLBase.build_solution(prob, alg, res.info.solution, res.info.best_residual; + retcode, original = res, stats) end end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 10645fad7..991783672 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -66,6 +66,7 @@ include("core/generic.jl") include("core/approximate_jacobian.jl") include("core/generalized_first_order.jl") include("core/spectral_methods.jl") +include("core/noinit.jl") include("algorithms/raphson.jl") include("algorithms/pseudo_transient.jl") diff --git a/src/abstract_types.jl b/src/abstract_types.jl index b7127cb2c..a08f53317 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -214,6 +214,9 @@ get_u(cache::AbstractNonlinearSolveCache) = cache.u set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu) SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u) +function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache; kwargs...) + return reinit_cache!(cache; kwargs...) +end function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache, u0; kwargs...) return reinit_cache!(cache; u0, kwargs...) end diff --git a/src/core/noinit.jl b/src/core/noinit.jl new file mode 100644 index 000000000..b51c09c23 --- /dev/null +++ b/src/core/noinit.jl @@ -0,0 +1,37 @@ +# Some algorithms don't support creating a cache and doing `solve!`, this unfortunately +# makes it difficult to write generic code that supports caching. For the algorithms that +# don't have a `__init` function defined, we create a "Fake Cache", which just calls +# `__solve` from `solve!` +@concrete mutable struct NonlinearSolveNoInitCache{iip, timeit} <: + AbstractNonlinearSolveCache{iip, timeit} + prob + alg + args + kwargs::Any +end + +function SciMLBase.reinit!( + cache::NonlinearSolveNoInitCache, u0 = cache.prob.u0; p = cache.prob.p, kwargs...) + prob = remake(cache.prob; u0, p) + cache.prob = prob + cache.kwargs = merge(cache.kwargs, kwargs) + return cache +end + +function Base.show(io::IO, cache::NonlinearSolveNoInitCache) + print(io, "NonlinearSolveNoInitCache(alg = $(cache.alg))") +end + +function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip}, + alg::Union{AbstractNonlinearSolveAlgorithm, + SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm}, + args...; + maxtime = nothing, + kwargs...) where {uType, iip} + return NonlinearSolveNoInitCache{iip, maxtime !== nothing}( + prob, alg, args, merge((; maxtime), kwargs)) +end + +function SciMLBase.solve!(cache::NonlinearSolveNoInitCache) + return solve(cache.prob, cache.alg, cache.args...; cache.kwargs...) +end diff --git a/test/misc/noinit_caching_tests.jl b/test/misc/noinit_caching_tests.jl new file mode 100644 index 000000000..f9207d82f --- /dev/null +++ b/test/misc/noinit_caching_tests.jl @@ -0,0 +1,23 @@ +@testitem "NoInit Caching" begin + using LinearAlgebra + import NLsolve, NLSolvers + + solvers = [SimpleNewtonRaphson(), SimpleTrustRegion(), SimpleDFSane(), NLsolveJL(), + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking()))] + + prob = NonlinearProblem((u, p) -> u .^ 2 .- p, [0.1, 0.3], 2.0) + + for alg in solvers + cache = init(prob, alg) + sol = solve!(cache) + @test SciMLBase.successful_retcode(sol) + @test norm(sol.resid, Inf) ≤ 1e-6 + + reinit!(cache; p = 5.0) + @test cache.prob.p == 5.0 + sol = solve!(cache) + @test SciMLBase.successful_retcode(sol) + @test norm(sol.resid, Inf) ≤ 1e-6 + @test norm(sol.u .^ 2 .- 5.0, Inf) ≤ 1e-6 + end +end diff --git a/test/wrappers/rootfind_tests.jl b/test/wrappers/rootfind_tests.jl index 0fa56d690..dcee9ceba 100644 --- a/test/wrappers/rootfind_tests.jl +++ b/test/wrappers/rootfind_tests.jl @@ -16,7 +16,8 @@ end prob_iip = SteadyStateProblem(f_iip, u0) for alg in [ - NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] sol = solve(prob_iip, alg) @test SciMLBase.successful_retcode(sol.retcode) @test maximum(abs, sol.resid) < 1e-6 @@ -28,7 +29,8 @@ end prob_oop = SteadyStateProblem(f_oop, u0) for alg in [ - NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] sol = solve(prob_oop, alg) @test SciMLBase.successful_retcode(sol.retcode) @test maximum(abs, sol.resid) < 1e-6 @@ -45,7 +47,8 @@ end prob_iip = NonlinearProblem{true}(f_iip, u0) for alg in [ - NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] local sol sol = solve(prob_iip, alg) @test SciMLBase.successful_retcode(sol.retcode) @@ -57,7 +60,8 @@ end u0 = zeros(2) prob_oop = NonlinearProblem{false}(f_oop, u0) for alg in [ - NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] local sol sol = solve(prob_oop, alg) @test SciMLBase.successful_retcode(sol.retcode) @@ -70,8 +74,7 @@ end for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15], alg in [ NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), - NLsolveJL(), - CMINPACK(), SIAMFANLEquationsJL(; method = :newton), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(; method = :newton), SIAMFANLEquationsJL(; method = :pseudotransient), SIAMFANLEquationsJL(; method = :secant)]