Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 18, 2024
1 parent c7409c8 commit bd5d988
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 28 deletions.
21 changes: 11 additions & 10 deletions ext/NonlinearSolveNLSolversExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using ADTypes, FastClosures, NonlinearSolve, NLSolvers, SciMLBase, LinearAlgebra
using FiniteDiff, ForwardDiff

function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false,
termination_condition = nothing, kwargs...)
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0::Bool = false, termination_condition = nothing, kwargs...)
NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL)

abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0))
Expand Down Expand Up @@ -50,12 +50,13 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
prob_nlsolver = NEqProblem(prob_obj; inplace = false)
res = NLSolvers.solve(prob_nlsolver, prob.u0, alg.method, options)

retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
ReturnCode.MaxIters)
retcode = ifelse(norm(res.info.best_residual, Inf) abstol,
ReturnCode.Success, ReturnCode.MaxIters)
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)

return SciMLBase.build_solution(prob, alg, res.info.solution,
res.info.best_residual; retcode, original = res, stats)
return SciMLBase.build_solution(
prob, alg, res.info.solution, res.info.best_residual;
retcode, original = res, stats)
end

f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
Expand All @@ -73,12 +74,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;

res = NLSolvers.solve(prob_nlsolver, u0, alg.method, options)

retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
ReturnCode.MaxIters)
retcode = ifelse(
norm(res.info.best_residual, Inf) abstol, ReturnCode.Success, ReturnCode.MaxIters)
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)

return SciMLBase.build_solution(prob, alg, res.info.solution,
res.info.best_residual; retcode, original = res, stats)
return SciMLBase.build_solution(prob, alg, res.info.solution, res.info.best_residual;
retcode, original = res, stats)
end

end
3 changes: 1 addition & 2 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ end
function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
println(io, "$(nameof(typeof(cache)))(")
__show_algorithm(io, cache.alg,
(" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent +
4)
(" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4)
println(io, ",")
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
Expand Down
4 changes: 2 additions & 2 deletions src/core/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ Performs one step of the nonlinear solver.
respectively. For algorithms that don't use jacobian information, this keyword is
ignored with a one-time warning.
"""
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args...;
kwargs...) where {iip, timeit}
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
args...; kwargs...) where {iip, timeit}
not_terminated(cache) || return
timeit && (time_start = time())
res = @static_timeit cache.timer "solve" begin
Expand Down
25 changes: 17 additions & 8 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ end
maxiters::Int
end

function Base.show(io::IO,
cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N}
function Base.show(

Check warning on line 61 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L61

Added line #L61 was not covered by tests
io::IO, cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N}
problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem")
println(io, "NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms")
best_alg = ifelse(cache.best == -1, "nothing", cache.best)
Expand All @@ -84,8 +84,16 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
maxtime = nothing, maxiters = 1000, kwargs...) where {N}
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
alg.algs), alg, -1, 1, 0, 0.0, maxtime,
ReturnCode.Default, false, maxiters)
alg.algs),
alg,
-1,
1,
0,
0.0,
maxtime,
ReturnCode.Default,
false,
maxiters)
end
end
end
Expand All @@ -109,8 +117,9 @@ end
stats = $(sol_syms[i]).stats
u = $(sol_syms[i]).u
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
fu; retcode = $(sol_syms[i]).retcode, stats,
return SciMLBase.build_solution(
$(sol_syms[i]).prob, cache.alg, u, fu;
retcode = $(sol_syms[i]).retcode, stats,
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
end
cache.current = $(i + 1)
Expand All @@ -137,8 +146,8 @@ end
return Expr(:block, calls...)
end

@generated function __step!(cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...;
kwargs...) where {iip, N}
@generated function __step!(
cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N}
calls = []
cache_syms = [gensym("cache") for i in 1:N]
for i in 1:N
Expand Down
15 changes: 9 additions & 6 deletions test/wrappers/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ end
prob_iip = SteadyStateProblem(f_iip, u0)

for alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
sol = solve(prob_iip, alg)
@test SciMLBase.successful_retcode(sol.retcode)
@test maximum(abs, sol.resid) < 1e-6
Expand All @@ -28,7 +29,8 @@ end
prob_oop = SteadyStateProblem(f_oop, u0)

for alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
sol = solve(prob_oop, alg)
@test SciMLBase.successful_retcode(sol.retcode)
@test maximum(abs, sol.resid) < 1e-6
Expand All @@ -45,7 +47,8 @@ end
prob_iip = NonlinearProblem{true}(f_iip, u0)

for alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
local sol
sol = solve(prob_iip, alg)
@test SciMLBase.successful_retcode(sol.retcode)
Expand All @@ -57,7 +60,8 @@ end
u0 = zeros(2)
prob_oop = NonlinearProblem{false}(f_oop, u0)
for alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
local sol
sol = solve(prob_oop, alg)
@test SciMLBase.successful_retcode(sol.retcode)
Expand All @@ -70,8 +74,7 @@ end
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15],
alg in [
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
NLsolveJL(),
CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
SIAMFANLEquationsJL(; method = :pseudotransient),
SIAMFANLEquationsJL(; method = :secant)]

Expand Down

0 comments on commit bd5d988

Please sign in to comment.