Skip to content

Commit

Permalink
Consolidate handling of Stats
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 25, 2024
1 parent b06cb09 commit 24793ed
Show file tree
Hide file tree
Showing 18 changed files with 139 additions and 185 deletions.
5 changes: 3 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ using PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workl
istril, istriu, lu, mul!, norm, pinv, tril!, triu!
using LineSearches: LineSearches
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
InvPreconditioner, needs_concrete_A
InvPreconditioner, needs_concrete_A, AbstractFactorization,
DefaultAlgorithmChoice, DefaultLinearSolver
using MaybeInplace: @bb
using Printf: @printf
using Preferences: Preferences, @load_preference, @set_preferences!
using RecursiveArrayTools: recursivecopy!, recursivefill!
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace, NLStats
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
using SparseDiffTools: SparseDiffTools, AbstractSparsityDetection,
ApproximateJacobianSparsity, JacPrototypeSparsityDetection,
Expand Down
3 changes: 1 addition & 2 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ abstract type AbstractNonlinearSolveLineSearchCache end

function reinit_cache!(
cache::AbstractNonlinearSolveLineSearchCache, args...; p = cache.p, kwargs...)
cache.nf[] = 0
cache.p = p
end

Expand Down Expand Up @@ -235,7 +234,7 @@ function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
print(io, " "^(indent) * ")")
end
Expand Down
23 changes: 12 additions & 11 deletions src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ end
inv_workspace

# Counters
nf::Int
stats::NLStats
nsteps::Int
nresets::Int
max_resets::Int
Expand Down Expand Up @@ -131,7 +131,7 @@ function __reinit_internal!(cache::ApproximateJacobianSolveCache{INV, GB, iip},
end
cache.p = p

cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.nresets = 0
cache.steps_since_last_reset = 0
Expand All @@ -151,8 +151,9 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem{uType, iip}, alg::ApproximateJacobianSolveAlgorithm,
args...; alias_u0 = false, maxtime = nothing, maxiters = 1000, abstol = nothing,
reltol = nothing, linsolve_kwargs = (;), termination_condition = nothing,
args...; stats = empty_nlstats(), alias_u0 = false, maxtime = nothing,
maxiters = 1000, abstol = nothing, reltol = nothing,
linsolve_kwargs = (;), termination_condition = nothing,
internalnorm::F = DEFAULT_NORM, kwargs...) where {uType, iip, F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand All @@ -165,7 +166,7 @@ function SciMLBase.__init(

linsolve = get_linear_solver(alg.descent)
initialization_cache = __internal_init(
prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm)
prob, alg.initialization, alg, f, fu, u, p; stats, linsolve, maxiters, internalnorm)

abstol, reltol, termination_cache = init_termination_cache(
prob, abstol, reltol, fu, u, termination_condition)
Expand All @@ -174,7 +175,7 @@ function SciMLBase.__init(
J = initialization_cache(nothing)
inv_workspace, J = INV ? __safe_inv_workspace(J) : (nothing, J)
descent_cache = __internal_init(
prob, alg.descent, J, fu, u; abstol, reltol, internalnorm,
prob, alg.descent, J, fu, u; stats, abstol, reltol, internalnorm,
linsolve_kwargs, pre_inverted = Val(INV), timer)
du = get_du(descent_cache)

Expand All @@ -192,28 +193,28 @@ function SciMLBase.__init(
supports_trust_region(alg.descent) || error("Trust Region not supported by \
$(alg.descent).")
trustregion_cache = __internal_init(
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :TrustRegion
end

if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :LineSearch
end

update_rule_cache = __internal_init(
prob, alg.update_rule, J, fu, u, du; internalnorm)
prob, alg.update_rule, J, fu, u, du; stats, internalnorm)

trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du;
uses_jacobian_inverse = Val(INV), kwargs...)

return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(
fu, u, u_cache, p, du, J, alg, prob, initialization_cache,
descent_cache, linesearch_cache, trustregion_cache, update_rule_cache,
reinit_rule_cache, inv_workspace, 0, 0, 0, alg.max_resets,
reinit_rule_cache, inv_workspace, stats, 0, 0, alg.max_resets,
maxiters, maxtime, alg.max_shrink_times, 0, timer, 0.0,
termination_cache, trace, ReturnCode.Default, false, false, kwargs)
end
Expand All @@ -223,7 +224,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
recompute_jacobian::Union{Nothing, Bool} = nothing) where {INV, GB, iip}
new_jacobian = true
@static_timeit cache.timer "jacobian init/reinit" begin
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
if cache.nsteps == 0 # First Step is special ignore kwargs
J_init = __internal_solve!(
cache.initialization_cache, cache.fu, cache.u, Val(false))
if INV
Expand Down
16 changes: 8 additions & 8 deletions src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
trustregion_cache

# Counters
nf::Int
stats::NLStats
nsteps::Int
maxiters::Int
maxtime
Expand Down Expand Up @@ -135,7 +135,7 @@ function __reinit_internal!(
end
cache.p = p

cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.maxiters = maxiters
cache.maxtime = maxtime
Expand All @@ -153,7 +153,7 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
args...; stats=empty_nlstats(), alias_u0 = false, maxiters = 1000, abstol = nothing,
reltol = nothing, maxtime = nothing, termination_condition = nothing,
internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip}
timer = get_timer_output()
Expand All @@ -170,10 +170,10 @@ function SciMLBase.__init(
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)

jac_cache = JacobianCache(
prob, alg, f, fu, u, p; autodiff = alg.jacobian_ad, linsolve,
prob, alg, f, fu, u, p; stats, autodiff = alg.jacobian_ad, linsolve,
jvp_autodiff = alg.forward_ad, vjp_autodiff = alg.reverse_ad)
J = jac_cache(nothing)
descent_cache = __internal_init(prob, alg.descent, J, fu, u; abstol, reltol,
descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol, reltol,
internalnorm, linsolve_kwargs, timer)
du = get_du(descent_cache)

Expand All @@ -189,15 +189,15 @@ function SciMLBase.__init(
supports_trust_region(alg.descent) || error("Trust Region not supported by \
$(alg.descent).")
trustregion_cache = __internal_init(
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :TrustRegion
end

if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :LineSearch
end

Expand All @@ -206,7 +206,7 @@ function SciMLBase.__init(

return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
trustregion_cache, 0, 0, maxiters, maxtime, alg.max_shrink_times, timer,
trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs)
end
end
Expand Down
20 changes: 8 additions & 12 deletions src/core/generic.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
alg::AbstractNonlinearSolveAlgorithm, args...; stats=empty_nlstats(), kwargs...)
cache = SciMLBase.__init(prob, alg, args...; stats, kwargs...)
return solve!(cache)
end

function not_terminated(cache::AbstractNonlinearSolveCache)
return !cache.force_stop && get_nsteps(cache) < cache.maxiters
return !cache.force_stop && cache.nsteps < cache.maxiters
end

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
Expand All @@ -16,21 +16,16 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
# The solver might have set a different `retcode`
if cache.retcode == ReturnCode.Default
cache.retcode = ifelse(
get_nsteps(cache) cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
cache.nsteps cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
end

update_from_termination_cache!(cache.termination_cache, cache)

update_trace!(cache.trace, get_nsteps(cache), get_u(cache),
update_trace!(cache.trace, cache.nsteps, get_u(cache),
get_fu(cache), nothing, nothing, nothing; last = True)

return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, stats = __compile_stats(cache), cache.trace)
end

function __compile_stats(cache::AbstractNonlinearSolveCache)
return SciMLBase.NLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
get_nsolve(cache), get_nsteps(cache))
cache.retcode, cache.stats, cache.trace)
end

"""
Expand All @@ -55,7 +50,8 @@ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
__step!(cache, args...; kwargs...)
end

hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)
cache.stats.nsteps += 1
cache.nsteps += 1

if timeit
cache.total_time += time() - time_start
Expand Down
16 changes: 8 additions & 8 deletions src/core/spectral_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ concrete_jac(::GeneralizedDFSane) = nothing
linesearch_cache

# Counters
nf::Int
stats::NLStats
nsteps::Int
maxiters::Int
maxtime
Expand Down Expand Up @@ -106,7 +106,7 @@ function __reinit_internal!(

reset!(cache.trace)
reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.maxiters = maxiters
cache.maxtime = maxtime
Expand All @@ -116,9 +116,9 @@ end

@internal_caches GeneralizedDFSaneCache :linesearch_cache

function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
reltol = nothing, termination_condition = nothing,
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...;
stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, termination_condition = nothing,
internalnorm::F = DEFAULT_NORM, maxtime = nothing, kwargs...) where {F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand All @@ -130,8 +130,8 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
fu = evaluate_f(prob, u)
@bb fu_cache = copy(fu)

linesearch_cache = __internal_init(
prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...)
linesearch_cache = __internal_init(prob, alg.linesearch, prob.f, fu, u, prob.p;
stats, maxiters, internalnorm, kwargs...)

abstol, reltol, tc_cache = init_termination_cache(
prob, abstol, reltol, fu, u_cache, termination_condition)
Expand All @@ -150,7 +150,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane

return GeneralizedDFSaneCache{isinplace(prob), maxtime !== nothing}(
fu, fu_cache, u, u_cache, prob.p, du, alg, prob, σ_n, T(alg.σ_min),
T(alg.σ_max), linesearch_cache, 0, 0, maxiters, maxtime,
T(alg.σ_max), linesearch_cache, stats, 0, maxiters, maxtime,
timer, 0.0, tc_cache, trace, ReturnCode.Default, false, kwargs)
end
end
Expand Down
20 changes: 12 additions & 8 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ end
best::Int
current::Int
nsteps::Int
stats::NLStats
total_time::Float64
maxtime
retcode::ReturnCode.T
Expand Down Expand Up @@ -90,6 +91,7 @@ end
function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches)
cache.current = cache.alg.start_index
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.total_time = 0.0
end
Expand All @@ -98,8 +100,8 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
function SciMLBase.__init(
prob::$probType, alg::$algType{N}, args...; maxtime = nothing,
maxiters = 1000, internalnorm = DEFAULT_NORM,
prob::$probType, alg::$algType{N}, args...; stats = empty_nlstats(),
maxtime = nothing, maxiters = 1000, internalnorm = DEFAULT_NORM,
alias_u0 = false, verbose = true, kwargs...) where {N}
if (alias_u0 && !ismutable(prob.u0))
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
Expand All @@ -115,13 +117,14 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
alias_u0 && (prob = remake(prob; u0 = u0_aliased))
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(
solver -> SciMLBase.__init(prob, solver, args...; maxtime,
solver -> SciMLBase.__init(prob, solver, args...; stats, maxtime,
internalnorm, alias_u0, verbose, kwargs...),
alg.algs),
alg,
-1,
alg.start_index,
0,
stats,
0.0,
maxtime,
ReturnCode.Default,
Expand Down Expand Up @@ -181,7 +184,6 @@ end
push!(calls, quote
fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.internalnorm, fus)
stats = __compile_stats(cache.caches[idx])
end)
for i in 1:N
push!(calls, quote
Expand All @@ -203,7 +205,7 @@ end
end
return __build_solution_less_specialize(
cache.caches[idx].prob, cache.alg, u, fus[idx];
retcode, stats, cache.caches[idx].trace)
retcode, stats = cache.stats, cache.caches[idx].trace)
end)

return Expr(:block, calls...)
Expand Down Expand Up @@ -250,7 +252,8 @@ end
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
@generated function SciMLBase.__solve(prob::$probType, alg::$algType{N}, args...;
@generated function SciMLBase.__solve(
prob::$probType, alg::$algType{N}, args...; stats = empty_nlstats(),
alias_u0 = false, verbose = true, kwargs...) where {N}
sol_syms = [gensym("sol") for _ in 1:N]
prob_syms = [gensym("prob") for _ in 1:N]
Expand Down Expand Up @@ -280,8 +283,9 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
else
$(prob_syms[i]) = prob
end
$(cur_sol) = SciMLBase.__solve($(prob_syms[i]), alg.algs[$(i)],
args...; alias_u0, verbose, kwargs...)
$(cur_sol) = SciMLBase.__solve(
$(prob_syms[i]), alg.algs[$(i)], args...;
stats, alias_u0, verbose, kwargs...)
if SciMLBase.successful_retcode($(cur_sol))
if alias_u0
copyto!(u0, $(cur_sol).u)
Expand Down
Loading

0 comments on commit 24793ed

Please sign in to comment.