diff --git a/ext/NonlinearSolveNLSolversExt.jl b/ext/NonlinearSolveNLSolversExt.jl index fd75095c0..b480578d0 100644 --- a/ext/NonlinearSolveNLSolversExt.jl +++ b/ext/NonlinearSolveNLSolversExt.jl @@ -4,8 +4,8 @@ using ADTypes, FastClosures, NonlinearSolve, NLSolvers, SciMLBase, LinearAlgebra using FiniteDiff, ForwardDiff function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...; - abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false, - termination_condition = nothing, kwargs...) + abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0::Bool = false, termination_condition = nothing, kwargs...) NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL) abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0)) @@ -50,12 +50,13 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...; prob_nlsolver = NEqProblem(prob_obj; inplace = false) res = NLSolvers.solve(prob_nlsolver, prob.u0, alg.method, options) - retcode = ifelse(norm(res.info.best_residual, Inf) ≤ abstol, ReturnCode.Success, - ReturnCode.MaxIters) + retcode = ifelse(norm(res.info.best_residual, Inf) ≤ abstol, + ReturnCode.Success, ReturnCode.MaxIters) stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter) - return SciMLBase.build_solution(prob, alg, res.info.solution, - res.info.best_residual; retcode, original = res, stats) + return SciMLBase.build_solution( + prob, alg, res.info.solution, res.info.best_residual; + retcode, original = res, stats) end f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0) @@ -73,12 +74,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...; res = NLSolvers.solve(prob_nlsolver, u0, alg.method, options) - retcode = ifelse(norm(res.info.best_residual, Inf) ≤ abstol, ReturnCode.Success, - ReturnCode.MaxIters) + retcode = ifelse( + norm(res.info.best_residual, Inf) ≤ abstol, ReturnCode.Success, ReturnCode.MaxIters) stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter) - return SciMLBase.build_solution(prob, alg, res.info.solution, - res.info.best_residual; retcode, original = res, stats) + return SciMLBase.build_solution(prob, alg, res.info.solution, res.info.best_residual; + retcode, original = res, stats) end end diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 78063c19d..e9c2f4d7f 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -214,8 +214,7 @@ end function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0) println(io, "$(nameof(typeof(cache)))(") __show_algorithm(io, cache.alg, - (" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + - 4) + (" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4) println(io, ",") println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",") println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",") diff --git a/src/core/generic.jl b/src/core/generic.jl index 7f3dd3a8e..9aafba49b 100644 --- a/src/core/generic.jl +++ b/src/core/generic.jl @@ -45,8 +45,8 @@ Performs one step of the nonlinear solver. respectively. For algorithms that don't use jacobian information, this keyword is ignored with a one-time warning. """ -function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args...; - kwargs...) where {iip, timeit} +function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, + args...; kwargs...) where {iip, timeit} not_terminated(cache) || return timeit && (time_start = time()) res = @static_timeit cache.timer "solve" begin diff --git a/src/default.jl b/src/default.jl index 8544998f7..da2420656 100644 --- a/src/default.jl +++ b/src/default.jl @@ -58,8 +58,8 @@ end maxiters::Int end -function Base.show(io::IO, - cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N} +function Base.show( + io::IO, cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N} problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem") println(io, "NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms") best_alg = ifelse(cache.best == -1, "nothing", cache.best) @@ -84,8 +84,16 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb maxtime = nothing, maxiters = 1000, kwargs...) where {N} return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}( map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...), - alg.algs), alg, -1, 1, 0, 0.0, maxtime, - ReturnCode.Default, false, maxiters) + alg.algs), + alg, + -1, + 1, + 0, + 0.0, + maxtime, + ReturnCode.Default, + false, + maxiters) end end end @@ -109,8 +117,9 @@ end stats = $(sol_syms[i]).stats u = $(sol_syms[i]).u fu = get_fu($(cache_syms[i])) - return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u, - fu; retcode = $(sol_syms[i]).retcode, stats, + return SciMLBase.build_solution( + $(sol_syms[i]).prob, cache.alg, u, fu; + retcode = $(sol_syms[i]).retcode, stats, original = $(sol_syms[i]), trace = $(sol_syms[i]).trace) end cache.current = $(i + 1) @@ -137,8 +146,8 @@ end return Expr(:block, calls...) end -@generated function __step!(cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; - kwargs...) where {iip, N} +@generated function __step!( + cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N} calls = [] cache_syms = [gensym("cache") for i in 1:N] for i in 1:N diff --git a/test/wrappers/rootfind_tests.jl b/test/wrappers/rootfind_tests.jl index 0fa56d690..dcee9ceba 100644 --- a/test/wrappers/rootfind_tests.jl +++ b/test/wrappers/rootfind_tests.jl @@ -16,7 +16,8 @@ end prob_iip = SteadyStateProblem(f_iip, u0) for alg in [ - NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] sol = solve(prob_iip, alg) @test SciMLBase.successful_retcode(sol.retcode) @test maximum(abs, sol.resid) < 1e-6 @@ -28,7 +29,8 @@ end prob_oop = SteadyStateProblem(f_oop, u0) for alg in [ - NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] sol = solve(prob_oop, alg) @test SciMLBase.successful_retcode(sol.retcode) @test maximum(abs, sol.resid) < 1e-6 @@ -45,7 +47,8 @@ end prob_iip = NonlinearProblem{true}(f_iip, u0) for alg in [ - NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] local sol sol = solve(prob_iip, alg) @test SciMLBase.successful_retcode(sol.retcode) @@ -57,7 +60,8 @@ end u0 = zeros(2) prob_oop = NonlinearProblem{false}(f_oop, u0) for alg in [ - NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] + NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()] local sol sol = solve(prob_oop, alg) @test SciMLBase.successful_retcode(sol.retcode) @@ -70,8 +74,7 @@ end for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15], alg in [ NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), - NLsolveJL(), - CMINPACK(), SIAMFANLEquationsJL(; method = :newton), + NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(; method = :newton), SIAMFANLEquationsJL(; method = :pseudotransient), SIAMFANLEquationsJL(; method = :secant)]