Skip to content

Commit

Permalink
Merge pull request #437 from SciML/ap/consolidate_stats
Browse files Browse the repository at this point in the history
Consolidate Stats Handling
  • Loading branch information
avik-pal authored May 25, 2024
2 parents 6f65ac4 + 50aa5c4 commit 630aad4
Show file tree
Hide file tree
Showing 21 changed files with 161 additions and 247 deletions.
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
Enlsip = "d5306a6b-d590-428d-a53a-eb3bb2d36f2d"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
Expand All @@ -46,7 +45,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
NonlinearSolveEnlsipExt = "Enlsip"
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
Expand All @@ -67,7 +65,6 @@ BenchmarkTools = "1.4"
CUDA = "5.2"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149.0"
Enlsip = "0.9"
Enzyme = "0.12"
ExplicitImports = "1.4.4"
FastBroadcast = "0.2.8, 0.3"
Expand Down Expand Up @@ -119,7 +116,6 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enlsip = "d5306a6b-d590-428d-a53a-eb3bb2d36f2d"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
Expand All @@ -145,4 +141,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enlsip", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]
1 change: 0 additions & 1 deletion docs/src/basics/nonlinear_solution.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ SciMLBase.NonlinearSolution

```@docs
SciMLBase.NLStats
NonlinearSolve.ImmutableNLStats
```

## Return Code
Expand Down
5 changes: 3 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ using PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workl
istril, istriu, lu, mul!, norm, pinv, tril!, triu!
using LineSearches: LineSearches
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
InvPreconditioner, needs_concrete_A
InvPreconditioner, needs_concrete_A, AbstractFactorization,
DefaultAlgorithmChoice, DefaultLinearSolver
using MaybeInplace: @bb
using Printf: @printf
using Preferences: Preferences, @load_preference, @set_preferences!
using RecursiveArrayTools: recursivecopy!, recursivefill!
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace, NLStats
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
using SparseDiffTools: SparseDiffTools, AbstractSparsityDetection,
ApproximateJacobianSparsity, JacPrototypeSparsityDetection,
Expand Down
3 changes: 1 addition & 2 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ abstract type AbstractNonlinearSolveLineSearchCache end

function reinit_cache!(
cache::AbstractNonlinearSolveLineSearchCache, args...; p = cache.p, kwargs...)
cache.nf[] = 0
cache.p = p
end

Expand Down Expand Up @@ -235,7 +234,7 @@ function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
print(io, " "^(indent) * ")")
end
Expand Down
28 changes: 14 additions & 14 deletions src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ end
inv_workspace

# Counters
nf::Int
stats::NLStats
nsteps::Int
nresets::Int
max_resets::Int
Expand Down Expand Up @@ -131,7 +131,7 @@ function __reinit_internal!(cache::ApproximateJacobianSolveCache{INV, GB, iip},
end
cache.p = p

cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.nresets = 0
cache.steps_since_last_reset = 0
Expand All @@ -151,8 +151,9 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem{uType, iip}, alg::ApproximateJacobianSolveAlgorithm,
args...; alias_u0 = false, maxtime = nothing, maxiters = 1000, abstol = nothing,
reltol = nothing, linsolve_kwargs = (;), termination_condition = nothing,
args...; stats = empty_nlstats(), alias_u0 = false, maxtime = nothing,
maxiters = 1000, abstol = nothing, reltol = nothing,
linsolve_kwargs = (;), termination_condition = nothing,
internalnorm::F = DEFAULT_NORM, kwargs...) where {uType, iip, F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand All @@ -164,18 +165,17 @@ function SciMLBase.__init(
INV = store_inverse_jacobian(alg.update_rule)

linsolve = get_linear_solver(alg.descent)
initialization_cache = __internal_init(
prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm)
initialization_cache = __internal_init(prob, alg.initialization, alg, f, fu, u, p;
stats, linsolve, maxiters, internalnorm)

abstol, reltol, termination_cache = init_termination_cache(
prob, abstol, reltol, fu, u, termination_condition)
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)

J = initialization_cache(nothing)
inv_workspace, J = INV ? __safe_inv_workspace(J) : (nothing, J)
descent_cache = __internal_init(
prob, alg.descent, J, fu, u; abstol, reltol, internalnorm,
linsolve_kwargs, pre_inverted = Val(INV), timer)
descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol, reltol,
internalnorm, linsolve_kwargs, pre_inverted = Val(INV), timer)
du = get_du(descent_cache)

reinit_rule_cache = __internal_init(alg.reinit_rule, J, fu, u, du)
Expand All @@ -192,28 +192,28 @@ function SciMLBase.__init(
supports_trust_region(alg.descent) || error("Trust Region not supported by \
$(alg.descent).")
trustregion_cache = __internal_init(
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :TrustRegion
end

if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :LineSearch
end

update_rule_cache = __internal_init(
prob, alg.update_rule, J, fu, u, du; internalnorm)
prob, alg.update_rule, J, fu, u, du; stats, internalnorm)

trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du;
uses_jacobian_inverse = Val(INV), kwargs...)

return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(
fu, u, u_cache, p, du, J, alg, prob, initialization_cache,
descent_cache, linesearch_cache, trustregion_cache, update_rule_cache,
reinit_rule_cache, inv_workspace, 0, 0, 0, alg.max_resets,
reinit_rule_cache, inv_workspace, stats, 0, 0, alg.max_resets,
maxiters, maxtime, alg.max_shrink_times, 0, timer, 0.0,
termination_cache, trace, ReturnCode.Default, false, false, kwargs)
end
Expand All @@ -223,7 +223,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
recompute_jacobian::Union{Nothing, Bool} = nothing) where {INV, GB, iip}
new_jacobian = true
@static_timeit cache.timer "jacobian init/reinit" begin
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
if cache.nsteps == 0 # First Step is special ignore kwargs
J_init = __internal_solve!(
cache.initialization_cache, cache.fu, cache.u, Val(false))
if INV
Expand Down
23 changes: 12 additions & 11 deletions src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
trustregion_cache

# Counters
nf::Int
stats::NLStats
nsteps::Int
maxiters::Int
maxtime
Expand Down Expand Up @@ -135,7 +135,7 @@ function __reinit_internal!(
end
cache.p = p

cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.maxiters = maxiters
cache.maxtime = maxtime
Expand All @@ -153,9 +153,10 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
reltol = nothing, maxtime = nothing, termination_condition = nothing,
internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip}
args...; stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, maxtime = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
(; f, u0, p) = prob
Expand All @@ -170,11 +171,11 @@ function SciMLBase.__init(
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)

jac_cache = JacobianCache(
prob, alg, f, fu, u, p; autodiff = alg.jacobian_ad, linsolve,
prob, alg, f, fu, u, p; stats, autodiff = alg.jacobian_ad, linsolve,
jvp_autodiff = alg.forward_ad, vjp_autodiff = alg.reverse_ad)
J = jac_cache(nothing)
descent_cache = __internal_init(prob, alg.descent, J, fu, u; abstol, reltol,
internalnorm, linsolve_kwargs, timer)
descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol,
reltol, internalnorm, linsolve_kwargs, timer)
du = get_du(descent_cache)

if alg.trustregion !== missing && alg.linesearch !== missing
Expand All @@ -189,15 +190,15 @@ function SciMLBase.__init(
supports_trust_region(alg.descent) || error("Trust Region not supported by \
$(alg.descent).")
trustregion_cache = __internal_init(
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :TrustRegion
end

if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :LineSearch
end

Expand All @@ -206,7 +207,7 @@ function SciMLBase.__init(

return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
trustregion_cache, 0, 0, maxiters, maxtime, alg.max_shrink_times, timer,
trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs)
end
end
Expand Down
20 changes: 8 additions & 12 deletions src/core/generic.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
alg::AbstractNonlinearSolveAlgorithm, args...; stats = empty_nlstats(), kwargs...)
cache = SciMLBase.__init(prob, alg, args...; stats, kwargs...)
return solve!(cache)
end

function not_terminated(cache::AbstractNonlinearSolveCache)
return !cache.force_stop && get_nsteps(cache) < cache.maxiters
return !cache.force_stop && cache.nsteps < cache.maxiters
end

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
Expand All @@ -16,21 +16,16 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
# The solver might have set a different `retcode`
if cache.retcode == ReturnCode.Default
cache.retcode = ifelse(
get_nsteps(cache) cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
cache.nsteps cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
end

update_from_termination_cache!(cache.termination_cache, cache)

update_trace!(cache.trace, get_nsteps(cache), get_u(cache),
update_trace!(cache.trace, cache.nsteps, get_u(cache),
get_fu(cache), nothing, nothing, nothing; last = True)

return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, stats = __compile_stats(cache), cache.trace)
end

function __compile_stats(cache::AbstractNonlinearSolveCache)
return ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
get_nsolve(cache), get_nsteps(cache))
cache.retcode, cache.stats, cache.trace)
end

"""
Expand All @@ -55,7 +50,8 @@ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
__step!(cache, args...; kwargs...)
end

hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)
cache.stats.nsteps += 1
cache.nsteps += 1

if timeit
cache.total_time += time() - time_start
Expand Down
16 changes: 8 additions & 8 deletions src/core/spectral_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ concrete_jac(::GeneralizedDFSane) = nothing
linesearch_cache

# Counters
nf::Int
stats::NLStats
nsteps::Int
maxiters::Int
maxtime
Expand Down Expand Up @@ -106,7 +106,7 @@ function __reinit_internal!(

reset!(cache.trace)
reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.maxiters = maxiters
cache.maxtime = maxtime
Expand All @@ -116,9 +116,9 @@ end

@internal_caches GeneralizedDFSaneCache :linesearch_cache

function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
reltol = nothing, termination_condition = nothing,
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...;
stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, termination_condition = nothing,
internalnorm::F = DEFAULT_NORM, maxtime = nothing, kwargs...) where {F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand All @@ -130,8 +130,8 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
fu = evaluate_f(prob, u)
@bb fu_cache = copy(fu)

linesearch_cache = __internal_init(
prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...)
linesearch_cache = __internal_init(prob, alg.linesearch, prob.f, fu, u, prob.p;
stats, maxiters, internalnorm, kwargs...)

abstol, reltol, tc_cache = init_termination_cache(
prob, abstol, reltol, fu, u_cache, termination_condition)
Expand All @@ -150,7 +150,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane

return GeneralizedDFSaneCache{isinplace(prob), maxtime !== nothing}(
fu, fu_cache, u, u_cache, prob.p, du, alg, prob, σ_n, T(alg.σ_min),
T(alg.σ_max), linesearch_cache, 0, 0, maxiters, maxtime,
T(alg.σ_max), linesearch_cache, stats, 0, maxiters, maxtime,
timer, 0.0, tc_cache, trace, ReturnCode.Default, false, kwargs)
end
end
Expand Down
Loading

0 comments on commit 630aad4

Please sign in to comment.