From d643f3c8c703aa86023b27d33198e398bad9c33a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 16 Feb 2024 21:07:07 -0500 Subject: [PATCH] Add pretty printing for the cache --- src/abstract_types.jl | 22 ++++++++++++++++++++-- src/core/generic.jl | 1 + src/default.jl | 25 +++++++++++++++++++------ 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 1b30f2b9f..ef03b6a39 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -173,8 +173,8 @@ not applicable. Else a boolean value is returned. """ concrete_jac(::AbstractNonlinearSolveAlgorithm) = nothing -function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm{name}) where {name} - __show_algorithm(io, alg, name, 0) +function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm) + __show_algorithm(io, alg, get_name(alg), 0) end get_name(::AbstractNonlinearSolveAlgorithm{name}) where {name} = name @@ -207,6 +207,24 @@ Abstract Type for all NonlinearSolve.jl Caches. """ abstract type AbstractNonlinearSolveCache{iip, timeit} end +function Base.show(io::IO, cache::AbstractNonlinearSolveCache) + __show_cache(io, cache, 0) +end + +function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0) + println(io, "$(nameof(typeof(cache)))(") + __show_algorithm(io, cache.alg, + (" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + + 4) + println(io, ",") + println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",") + println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",") + println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",") + println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",") + println(io, " "^(indent + 4) * "retcode = ", cache.retcode) + print(io, " "^(indent) * ")") +end + SciMLBase.isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip get_fu(cache::AbstractNonlinearSolveCache) = cache.fu diff --git a/src/core/generic.jl b/src/core/generic.jl index 70f7badbc..33408e2cc 100644 --- a/src/core/generic.jl +++ b/src/core/generic.jl @@ -47,6 +47,7 @@ Performs one step of the nonlinear solver. """ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args...; kwargs...) where {iip, timeit} + not_terminated(cache) || return timeit && (time_start = time()) res = @static_timeit cache.timer "solve" begin __step!(cache, args...; kwargs...) diff --git a/src/default.jl b/src/default.jl index 0f9d920b1..35ca50749 100644 --- a/src/default.jl +++ b/src/default.jl @@ -55,6 +55,19 @@ end maxtime retcode::ReturnCode.T force_stop::Bool + maxiters::Int +end + +function Base.show(io::IO, + cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N} + problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem") + println(io, "NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms") + best_alg = ifelse(cache.best == -1, "nothing", cache.best) + println(io, "Best algorithm: $(best_alg)") + println(io, "Current algorithm: $(cache.current)") + println(io, "nsteps: $(cache.nsteps)") + println(io, "retcode: $(cache.retcode)") + __show_cache(io, cache.caches[cache.current], 0) end function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...) @@ -68,11 +81,11 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb algType = NonlinearSolvePolyAlgorithm{pType} @eval begin function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...; - maxtime = nothing, kwargs...) where {N} + maxtime = nothing, maxiters = 1000, kwargs...) where {N} return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}( map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...), alg.algs), alg, -1, 1, 0, 0.0, maxtime, - ReturnCode.Default, false) + ReturnCode.Default, false, maxiters) end end end @@ -126,8 +139,8 @@ end return Expr(:block, calls...) end -@generated function __step!( - cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N} +@generated function __step!(cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; + kwargs...) where {iip, N} calls = [] cache_syms = [gensym("cache") for i in 1:N] for i in 1:N @@ -136,6 +149,7 @@ end $(cache_syms[i]) = cache.caches[$(i)] if $(i) == cache.current __step!($(cache_syms[i]), args...; kwargs...) + $(cache_syms[i]).nsteps += 1 if !not_terminated($(cache_syms[i])) if SciMLBase.successful_retcode($(cache_syms[i]).retcode) cache.best = $(i) @@ -159,8 +173,7 @@ end cache.force_stop = true return end - end - ) + end) return Expr(:block, calls...) end