Skip to content

Commit

Permalink
Add pretty printing for the cache
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 18, 2024
1 parent 84a08dc commit c7409c8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
22 changes: 20 additions & 2 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ not applicable. Else a boolean value is returned.
"""
concrete_jac(::AbstractNonlinearSolveAlgorithm) = nothing

function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm{name}) where {name}
__show_algorithm(io, alg, name, 0)
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
__show_algorithm(io, alg, get_name(alg), 0)
end

get_name(::AbstractNonlinearSolveAlgorithm{name}) where {name} = name
Expand Down Expand Up @@ -207,6 +207,24 @@ Abstract Type for all NonlinearSolve.jl Caches.
"""
abstract type AbstractNonlinearSolveCache{iip, timeit} end

function Base.show(io::IO, cache::AbstractNonlinearSolveCache)
__show_cache(io, cache, 0)
end

function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
println(io, "$(nameof(typeof(cache)))(")
__show_algorithm(io, cache.alg,
(" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent +
4)
println(io, ",")
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
print(io, " "^(indent) * ")")
end

SciMLBase.isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

get_fu(cache::AbstractNonlinearSolveCache) = cache.fu
Expand Down
5 changes: 3 additions & 2 deletions src/core/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ Performs one step of the nonlinear solver.
respectively. For algorithms that don't use jacobian information, this keyword is
ignored with a one-time warning.
"""
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
args...; kwargs...) where {iip, timeit}
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args...;
kwargs...) where {iip, timeit}
not_terminated(cache) || return
timeit && (time_start = time())
res = @static_timeit cache.timer "solve" begin
__step!(cache, args...; kwargs...)
Expand Down
25 changes: 19 additions & 6 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ end
maxtime
retcode::ReturnCode.T
force_stop::Bool
maxiters::Int
end

function Base.show(io::IO,
cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N}
problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem")
println(io, "NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms")
best_alg = ifelse(cache.best == -1, "nothing", cache.best)
println(io, "Best algorithm: $(best_alg)")
println(io, "Current algorithm: $(cache.current)")
println(io, "nsteps: $(cache.nsteps)")
println(io, "retcode: $(cache.retcode)")
__show_cache(io, cache.caches[cache.current], 0)
end

function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
Expand All @@ -68,11 +81,11 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...;
maxtime = nothing, kwargs...) where {N}
maxtime = nothing, maxiters = 1000, kwargs...) where {N}
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
alg.algs), alg, -1, 1, 0, 0.0, maxtime,
ReturnCode.Default, false)
ReturnCode.Default, false, maxiters)
end
end
end
Expand Down Expand Up @@ -124,8 +137,8 @@ end
return Expr(:block, calls...)
end

@generated function __step!(
cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N}
@generated function __step!(cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...;
kwargs...) where {iip, N}
calls = []
cache_syms = [gensym("cache") for i in 1:N]
for i in 1:N
Expand All @@ -134,6 +147,7 @@ end
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
__step!($(cache_syms[i]), args...; kwargs...)
$(cache_syms[i]).nsteps += 1
if !not_terminated($(cache_syms[i]))
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
cache.best = $(i)
Expand All @@ -157,8 +171,7 @@ end
cache.force_stop = true
return
end
end
)
end)

return Expr(:block, calls...)
end
Expand Down

0 comments on commit c7409c8

Please sign in to comment.