Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate Stats Handling #437

Merged
merged 4 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
Enlsip = "d5306a6b-d590-428d-a53a-eb3bb2d36f2d"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
Expand All @@ -46,7 +45,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
NonlinearSolveEnlsipExt = "Enlsip"
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
Expand All @@ -67,7 +65,6 @@ BenchmarkTools = "1.4"
CUDA = "5.2"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149.0"
Enlsip = "0.9"
Enzyme = "0.12"
ExplicitImports = "1.4.4"
FastBroadcast = "0.2.8"
Expand Down Expand Up @@ -119,7 +116,6 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enlsip = "d5306a6b-d590-428d-a53a-eb3bb2d36f2d"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
Expand All @@ -145,4 +141,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enlsip", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]
1 change: 0 additions & 1 deletion docs/src/basics/nonlinear_solution.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ SciMLBase.NonlinearSolution

```@docs
SciMLBase.NLStats
NonlinearSolve.ImmutableNLStats
```

## Return Code
Expand Down
5 changes: 3 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ using PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workl
istril, istriu, lu, mul!, norm, pinv, tril!, triu!
using LineSearches: LineSearches
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
InvPreconditioner, needs_concrete_A
InvPreconditioner, needs_concrete_A, AbstractFactorization,
DefaultAlgorithmChoice, DefaultLinearSolver
using MaybeInplace: @bb
using Printf: @printf
using Preferences: Preferences, @load_preference, @set_preferences!
using RecursiveArrayTools: recursivecopy!, recursivefill!
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace, NLStats
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
using SparseDiffTools: SparseDiffTools, AbstractSparsityDetection,
ApproximateJacobianSparsity, JacPrototypeSparsityDetection,
Expand Down
3 changes: 1 addition & 2 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ abstract type AbstractNonlinearSolveLineSearchCache end

function reinit_cache!(
cache::AbstractNonlinearSolveLineSearchCache, args...; p = cache.p, kwargs...)
cache.nf[] = 0
cache.p = p
end

Expand Down Expand Up @@ -235,7 +234,7 @@ function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
print(io, " "^(indent) * ")")
end
Expand Down
28 changes: 14 additions & 14 deletions src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ end
inv_workspace

# Counters
nf::Int
stats::NLStats
nsteps::Int
nresets::Int
max_resets::Int
Expand Down Expand Up @@ -131,7 +131,7 @@ function __reinit_internal!(cache::ApproximateJacobianSolveCache{INV, GB, iip},
end
cache.p = p

cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.nresets = 0
cache.steps_since_last_reset = 0
Expand All @@ -151,8 +151,9 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem{uType, iip}, alg::ApproximateJacobianSolveAlgorithm,
args...; alias_u0 = false, maxtime = nothing, maxiters = 1000, abstol = nothing,
reltol = nothing, linsolve_kwargs = (;), termination_condition = nothing,
args...; stats = empty_nlstats(), alias_u0 = false, maxtime = nothing,
maxiters = 1000, abstol = nothing, reltol = nothing,
linsolve_kwargs = (;), termination_condition = nothing,
internalnorm::F = DEFAULT_NORM, kwargs...) where {uType, iip, F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand All @@ -164,18 +165,17 @@ function SciMLBase.__init(
INV = store_inverse_jacobian(alg.update_rule)

linsolve = get_linear_solver(alg.descent)
initialization_cache = __internal_init(
prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm)
initialization_cache = __internal_init(prob, alg.initialization, alg, f, fu, u, p;
stats, linsolve, maxiters, internalnorm)

abstol, reltol, termination_cache = init_termination_cache(
prob, abstol, reltol, fu, u, termination_condition)
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)

J = initialization_cache(nothing)
inv_workspace, J = INV ? __safe_inv_workspace(J) : (nothing, J)
descent_cache = __internal_init(
prob, alg.descent, J, fu, u; abstol, reltol, internalnorm,
linsolve_kwargs, pre_inverted = Val(INV), timer)
descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol, reltol,
internalnorm, linsolve_kwargs, pre_inverted = Val(INV), timer)
du = get_du(descent_cache)

reinit_rule_cache = __internal_init(alg.reinit_rule, J, fu, u, du)
Expand All @@ -192,28 +192,28 @@ function SciMLBase.__init(
supports_trust_region(alg.descent) || error("Trust Region not supported by \
$(alg.descent).")
trustregion_cache = __internal_init(
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :TrustRegion
end

if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :LineSearch
end

update_rule_cache = __internal_init(
prob, alg.update_rule, J, fu, u, du; internalnorm)
prob, alg.update_rule, J, fu, u, du; stats, internalnorm)

trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du;
uses_jacobian_inverse = Val(INV), kwargs...)

return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(
fu, u, u_cache, p, du, J, alg, prob, initialization_cache,
descent_cache, linesearch_cache, trustregion_cache, update_rule_cache,
reinit_rule_cache, inv_workspace, 0, 0, 0, alg.max_resets,
reinit_rule_cache, inv_workspace, stats, 0, 0, alg.max_resets,
maxiters, maxtime, alg.max_shrink_times, 0, timer, 0.0,
termination_cache, trace, ReturnCode.Default, false, false, kwargs)
end
Expand All @@ -223,7 +223,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
recompute_jacobian::Union{Nothing, Bool} = nothing) where {INV, GB, iip}
new_jacobian = true
@static_timeit cache.timer "jacobian init/reinit" begin
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
if cache.nsteps == 0 # First Step is special ignore kwargs
J_init = __internal_solve!(
cache.initialization_cache, cache.fu, cache.u, Val(false))
if INV
Expand Down
23 changes: 12 additions & 11 deletions src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
trustregion_cache

# Counters
nf::Int
stats::NLStats
nsteps::Int
maxiters::Int
maxtime
Expand Down Expand Up @@ -135,7 +135,7 @@ function __reinit_internal!(
end
cache.p = p

cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.maxiters = maxiters
cache.maxtime = maxtime
Expand All @@ -153,9 +153,10 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
reltol = nothing, maxtime = nothing, termination_condition = nothing,
internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip}
args...; stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, maxtime = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
(; f, u0, p) = prob
Expand All @@ -170,11 +171,11 @@ function SciMLBase.__init(
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)

jac_cache = JacobianCache(
prob, alg, f, fu, u, p; autodiff = alg.jacobian_ad, linsolve,
prob, alg, f, fu, u, p; stats, autodiff = alg.jacobian_ad, linsolve,
jvp_autodiff = alg.forward_ad, vjp_autodiff = alg.reverse_ad)
J = jac_cache(nothing)
descent_cache = __internal_init(prob, alg.descent, J, fu, u; abstol, reltol,
internalnorm, linsolve_kwargs, timer)
descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol,
reltol, internalnorm, linsolve_kwargs, timer)
du = get_du(descent_cache)

if alg.trustregion !== missing && alg.linesearch !== missing
Expand All @@ -189,15 +190,15 @@ function SciMLBase.__init(
supports_trust_region(alg.descent) || error("Trust Region not supported by \
$(alg.descent).")
trustregion_cache = __internal_init(
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :TrustRegion
end

if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
GB = :LineSearch
end

Expand All @@ -206,7 +207,7 @@ function SciMLBase.__init(

return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
trustregion_cache, 0, 0, maxiters, maxtime, alg.max_shrink_times, timer,
trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs)
end
end
Expand Down
20 changes: 8 additions & 12 deletions src/core/generic.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
alg::AbstractNonlinearSolveAlgorithm, args...; stats = empty_nlstats(), kwargs...)
cache = SciMLBase.__init(prob, alg, args...; stats, kwargs...)
return solve!(cache)
end

function not_terminated(cache::AbstractNonlinearSolveCache)
return !cache.force_stop && get_nsteps(cache) < cache.maxiters
return !cache.force_stop && cache.nsteps < cache.maxiters
end

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
Expand All @@ -16,21 +16,16 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
# The solver might have set a different `retcode`
if cache.retcode == ReturnCode.Default
cache.retcode = ifelse(
get_nsteps(cache) ≥ cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
cache.nsteps ≥ cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
end

update_from_termination_cache!(cache.termination_cache, cache)

update_trace!(cache.trace, get_nsteps(cache), get_u(cache),
update_trace!(cache.trace, cache.nsteps, get_u(cache),
get_fu(cache), nothing, nothing, nothing; last = True)

return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, stats = __compile_stats(cache), cache.trace)
end

function __compile_stats(cache::AbstractNonlinearSolveCache)
return ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
get_nsolve(cache), get_nsteps(cache))
cache.retcode, cache.stats, cache.trace)
end

"""
Expand All @@ -55,7 +50,8 @@ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
__step!(cache, args...; kwargs...)
end

hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)
cache.stats.nsteps += 1
cache.nsteps += 1

if timeit
cache.total_time += time() - time_start
Expand Down
16 changes: 8 additions & 8 deletions src/core/spectral_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ concrete_jac(::GeneralizedDFSane) = nothing
linesearch_cache

# Counters
nf::Int
stats::NLStats
nsteps::Int
maxiters::Int
maxtime
Expand Down Expand Up @@ -106,7 +106,7 @@ function __reinit_internal!(

reset!(cache.trace)
reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
cache.nf = 1
__reinit_internal!(cache.stats)
cache.nsteps = 0
cache.maxiters = maxiters
cache.maxtime = maxtime
Expand All @@ -116,9 +116,9 @@ end

@internal_caches GeneralizedDFSaneCache :linesearch_cache

function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
reltol = nothing, termination_condition = nothing,
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...;
stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, termination_condition = nothing,
internalnorm::F = DEFAULT_NORM, maxtime = nothing, kwargs...) where {F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand All @@ -130,8 +130,8 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
fu = evaluate_f(prob, u)
@bb fu_cache = copy(fu)

linesearch_cache = __internal_init(
prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...)
linesearch_cache = __internal_init(prob, alg.linesearch, prob.f, fu, u, prob.p;
stats, maxiters, internalnorm, kwargs...)

abstol, reltol, tc_cache = init_termination_cache(
prob, abstol, reltol, fu, u_cache, termination_condition)
Expand All @@ -150,7 +150,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane

return GeneralizedDFSaneCache{isinplace(prob), maxtime !== nothing}(
fu, fu_cache, u, u_cache, prob.p, du, alg, prob, σ_n, T(alg.σ_min),
T(alg.σ_max), linesearch_cache, 0, 0, maxiters, maxtime,
T(alg.σ_max), linesearch_cache, stats, 0, maxiters, maxtime,
timer, 0.0, tc_cache, trace, ReturnCode.Default, false, kwargs)
end
end
Expand Down
Loading
Loading