Skip to content

Commit

Permalink
Support tracing for all and add an example
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 23, 2023
1 parent be7b44e commit 354e080
Show file tree
Hide file tree
Showing 16 changed files with 397 additions and 84 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand Down Expand Up @@ -44,13 +45,14 @@ ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.136"
DiffEqBase = "6.141"
EnumX = "1"
Enzyme = "0.11"
FastBroadcast = "0.1.9, 0.2"
FastLevenbergMarquardt = "0.1"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LazyArrays = "1.8"
LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearAlgebra = "<0.0.1, 1"
Expand All @@ -60,7 +62,7 @@ PrecompileTools = "1"
Printf = "<0.0.1, 1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "2.8.2"
SciMLBase = "2.9"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "<0.0.1, 1"
Expand Down
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pages = ["index.md",
"basics/solve.md",
"basics/NonlinearSolution.md",
"basics/TerminationCondition.md",
"basics/Logging.md",
"basics/FAQ.md"],
"Solver Summaries and Recommendations" => Any["solvers/NonlinearSystemSolvers.md",
"solvers/BracketingSolvers.md",
Expand Down
58 changes: 58 additions & 0 deletions docs/src/basics/Logging.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# [Logging the Solve Process](@ logging_api)

All NonlinearSolve.jl native solvers allow storing and displaying the trace of the nonlinear
solve process. This is controlled by 3 keyword arguments to `solve`:

1. `show_trace`: Must be `Val(true)` or `Val(false)`. This controls whether the trace is
displayed to the console. (Defaults to `Val(false)`)
2. `trace_level`: Needs to be one of Trace Objects: [`TraceMinimal`](@ref),
[`TraceWithJacobianConditionNumber`](@ref), or [`TraceAll`](@ref). This controls the
level of detail of the trace. (Defaults to `TraceMinimal()`)
3. `store_trace`: Must be `Val(true)` or `Val(false)`. This controls whether the trace is
stored in the solution object. (Defaults to `Val(false)`)

## Example Usage

```@example tracing
using ModelingToolkit, NonlinearSolve
@variables x y z
@parameters σ ρ β
# Define a nonlinear system
eqs = [0 ~ σ * (y - x),
0 ~ x * (ρ - z) - y,
0 ~ x * y - β * z]
@named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β])
u0 = [x => 1.0, y => 0.0, z => 0.0]
ps = [σ => 10.0 ρ => 26.0 β => 8 / 3]
prob = NonlinearProblem(ns, u0, ps)
solve(prob)
```

This produced the output, but it is hard to diagnose what is going on. We can turn on
the trace to see what is happening:

```@example tracing
solve(prob; show_trace = Val(true), trace_level = TraceAll(10))
```

You can also store the trace in the solution object:

```@example tracing
sol = solve(prob; trace_level = TraceAll(), store_trace = Val(true));
sol.trace
```

## API

```@docs
TraceMinimal
TraceWithJacobianConditionNumber
TraceAll
```
37 changes: 36 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import Reexport: @reexport
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload

@recompile_invalidations begin
using DiffEqBase, LinearAlgebra, LinearSolve, Printf, SparseArrays, SparseDiffTools
using DiffEqBase,
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
SparseDiffTools
using FastBroadcast: @..
import ArrayInterface: restructure

Expand Down Expand Up @@ -50,6 +52,32 @@ abstract type AbstractNonlinearSolveCache{iip} end

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
str = "$(nameof(typeof(alg)))("
modifiers = String[]
if _getproperty(alg, Val(:ad)) !== nothing
push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()")
end
if _getproperty(alg, Val(:linsolve)) !== nothing
push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()")
end
if _getproperty(alg, Val(:linesearch)) !== nothing
ls = alg.linesearch
if ls isa LineSearch
ls.method !== nothing &&
push!(modifiers, "linesearch = $(nameof(typeof(ls.method)))()")
else
push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()")
end
end
if _getproperty(alg, Val(:radius_update_scheme)) !== nothing
push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)")
end
str = str * join(modifiers, ", ")
print(io, "$(str))")
return nothing
end

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
Expand Down Expand Up @@ -80,6 +108,10 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
end

trace = _getproperty(cache, Val{:trace}())
if trace !== nothing
update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
nothing, nothing; last = Val(true))
end

return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, cache.stats, trace)
Expand Down Expand Up @@ -165,4 +197,7 @@ export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode,
AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode,
RelSafeBestTerminationMode, AbsSafeBestTerminationMode

# Tracing Functionality
export TraceAll, TraceMinimal, TraceWithJacobianConditionNumber

end # module
19 changes: 15 additions & 4 deletions src/broyden.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Sadly `Broyden` is taken up by SimpleNonlinearSolve.jl
"""
GeneralBroyden(; max_resets = 3, linesearch = LineSearch(), reset_tolerance = nothing)
GeneralBroyden(; max_resets = 3, linesearch = nothing, reset_tolerance = nothing)
An implementation of `Broyden` with reseting and line search.
Expand All @@ -21,7 +21,7 @@ An implementation of `Broyden` with reseting and line search.
linesearch
end

function GeneralBroyden(; max_resets = 3, linesearch = LineSearch(),
function GeneralBroyden(; max_resets = 3, linesearch = nothing,
reset_tolerance = nothing)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GeneralBroyden(max_resets, reset_tolerance, linesearch)
Expand Down Expand Up @@ -54,6 +54,7 @@ end
stats::NLStats
ls_cache
tc_cache
trace
end

get_fu(cache::GeneralBroydenCache) = cache.fu
Expand All @@ -66,19 +67,22 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
fu = evaluate_f(prob, u)
du = _mutable_zero(u)
J⁻¹ = __init_identity_jacobian(u, fu)
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
kwargs...)

return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache)
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
end

function perform_step!(cache::GeneralBroydenCache{true})
Expand All @@ -90,6 +94,9 @@ function perform_step!(cache::GeneralBroydenCache{true})
_axpy!(-α, du, u)
f(fu2, u, p)

update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
get_fu(cache), J⁻¹, du, α)

check_and_update!(cache, fu2, u, u_prev)
cache.stats.nf += 1

Expand Down Expand Up @@ -131,6 +138,9 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.u = cache.u .- α * cache.du
cache.fu2 = f(cache.u, p)

update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
get_fu(cache), cache.J⁻¹, cache.du, α)

check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
cache.stats.nf += 1

Expand Down Expand Up @@ -173,6 +183,7 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
cache.fu = cache.f(cache.u, p)
end

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
termination_condition)

Expand Down
8 changes: 4 additions & 4 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
fu; retcode = ReturnCode.Success, stats,
original = $(sol_syms[i]))
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
end
cache.current = $(i + 1)
end
Expand All @@ -103,7 +103,7 @@ end
u = cache.caches[idx].u

return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u,
fus[idx]; retcode, stats)
fus[idx]; retcode, stats, cache.caches[idx].trace)
end)

return Expr(:block, calls...)
Expand All @@ -125,7 +125,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
if SciMLBase.successful_retcode($(cur_sol))
return SciMLBase.build_solution(prob, alg, $(cur_sol).u,
$(cur_sol).resid; $(cur_sol).retcode, $(cur_sol).stats,
original = $(cur_sol))
original = $(cur_sol), trace = $(cur_sol).trace)
end
end)
end
Expand All @@ -147,7 +147,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
if idx == $i
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
$(sol_syms[i]).resid; $(sol_syms[i]).retcode,
$(sol_syms[i]).stats)
$(sol_syms[i]).stats, $(sol_syms[i]).trace)
end
end)
end
Expand Down
11 changes: 10 additions & 1 deletion src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ end
prob
stats::NLStats
tc_cache
trace
end

get_fu(cache::DFSaneCache) = cache.fu
Expand All @@ -113,11 +114,12 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, uprev,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)

return DFSaneCache{iip}(alg, u, uprev, fu, fuprev, du, history, f_norm, f_norm_0, alg.M,
T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), T(alg.τ_min),
T(alg.τ_max), alg.n_exp, prob.p, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache)
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
end

function perform_step!(cache::DFSaneCache{true})
Expand Down Expand Up @@ -164,6 +166,9 @@ function perform_step!(cache::DFSaneCache{true})
f_norm = cache.internalnorm(cache.fu)^n_exp
end

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
cache.du, α₊)

check_and_update!(cache, cache.fu, cache.u, cache.uprev)

# Update spectral parameter
Expand Down Expand Up @@ -236,6 +241,9 @@ function perform_step!(cache::DFSaneCache{false})
f_norm = cache.internalnorm(cache.fu)^n_exp
end

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
cache.du, α₊)

check_and_update!(cache, cache.fu, cache.u, cache.uprev)

# Update spectral parameter
Expand Down Expand Up @@ -288,6 +296,7 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.u; p = cache.p,
T = eltype(cache.u)
cache.σ_n = T(cache.alg.σ_1)

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
termination_condition)

Expand Down
15 changes: 12 additions & 3 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = nothing,
precs = DEFAULT_PRECS, adkwargs...)
An advanced GaussNewton implementation with support for efficient handling of sparse
Expand Down Expand Up @@ -47,7 +47,7 @@ function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, vjp_autodiff = nothing,
linesearch = nothing, precs = DEFAULT_PRECS, vjp_autodiff = nothing,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
Expand Down Expand Up @@ -82,6 +82,7 @@ end
tc_cache_1
tc_cache_2
ls_cache
trace
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
Expand All @@ -108,11 +109,12 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
abstol, reltol, tc_cache_1 = init_termination_cache(abstol, reltol, fu1, u,
termination_condition)
_, _, tc_cache_2 = init_termination_cache(abstol, reltol, fu1, u, termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu1, J, du; kwargs...)

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), trace)
end

function perform_step!(cache::GaussNewtonCache{true})
Expand All @@ -137,6 +139,9 @@ function perform_step!(cache::GaussNewtonCache{true})
_axpy!(-α, du, u)
f(cache.fu_new, u, p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), J,
cache.du, α)

check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
if !cache.force_stop
cache.fu1 .= cache.fu_new .- cache.fu1
Expand Down Expand Up @@ -179,6 +184,9 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.u = @. u - α * cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
cache.du, α)

check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
if !cache.force_stop
cache.fu1 = cache.fu_new .- cache.fu1
Expand Down Expand Up @@ -207,6 +215,7 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache
cache.fu1 = cache.f(cache.u, p)
end

reset!(cache.trace)
abstol, reltol, tc_cache_1 = init_termination_cache(abstol, reltol, cache.fu1, cache.u,
termination_condition)
_, _, tc_cache_2 = init_termination_cache(abstol, reltol, cache.fu1, cache.u,
Expand Down
Loading

0 comments on commit 354e080

Please sign in to comment.