StatProfilerHTML.jl report
Generated on Thu, 21 Dec 2023 12:59:22
File source code
Line Exclusive Inclusive Code
1 module NonlinearSolve
2
# Lower the module-wide `max_methods` inference setting to 1 when this Julia version
# provides `Base.Experimental.@max_methods`. The macro call is wrapped in `@eval` so that it
# is only expanded when the guard succeeds; presumably this keeps older Julia versions,
# which lack the macro, from failing at expansion time.
isdefined(Base, :Experimental) &&
    isdefined(Base.Experimental, Symbol("@max_methods")) &&
    @eval Base.Experimental.@max_methods 1
6
7 import Reexport: @reexport
8 import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload
9
10 @recompile_invalidations begin
11 using ADTypes, DiffEqBase, LazyArrays, LineSearches, LinearAlgebra, LinearSolve, Printf,
12 SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, StaticArrays
13
14 import ADTypes: AbstractFiniteDifferencesMode
15 import ArrayInterface: undefmatrix, restructure, can_setindex,
16 matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing
17 import ConcreteStructs: @concrete
18 import EnumX: @enumx
19 import FastBroadcast: @..
20 import FiniteDiff
21 import ForwardDiff
22 import ForwardDiff: Dual
23 import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
24 import MaybeInplace: setindex_trait, @bb, CanSetindex, CannotSetindex
25 import RecursiveArrayTools: ArrayPartition,
26 AbstractVectorOfArray, recursivecopy!, recursivefill!
27 import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
28 import SciMLOperators: FunctionOperator
29 import StaticArrays: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
30 import UnPack: @unpack
31 end
32
33 @reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
34 import DiffEqBase: AbstractNonlinearTerminationMode,
35 AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
36 NonlinearSafeTerminationReturnCode, get_termination_mode
37
"""
    AbstractSparseADType

Union of the three ADTypes abstract sparse AD backend families: sparse finite
differences, sparse forward mode and sparse reverse mode.
"""
const AbstractSparseADType = Union{
    ADTypes.AbstractSparseFiniteDifferences,
    ADTypes.AbstractSparseForwardMode,
    ADTypes.AbstractSparseReverseMode,
}
40
"""
    is_extension_loaded(::Val) -> Bool

Type-inference friendly check for whether a package extension is loaded: the answer is
decided by dispatch on a `Val` tag, not at runtime. This fallback returns `false` for
every tag. Presumably each extension adds a method for its own `Val(...)` tag that returns
`true` (TODO: confirm against the extension modules).
"""
function is_extension_loaded(::Val)
    return false
end
43
"""
Root abstract type for the line-search algorithms defined by NonlinearSolve.
"""
abstract type AbstractNonlinearSolveLineSearchAlgorithm end

"""
Root abstract type for the nonlinear solvers of this package; it subtypes
`SciMLBase.AbstractNonlinearAlgorithm`.
"""
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

"""
    AbstractNewtonAlgorithm{CJ, AD}

Abstract supertype of Newton-type solvers. The meaning of `CJ` and `AD` is not visible
here (presumably a concrete-Jacobian flag and the AD backend — TODO confirm).
"""
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end

"""
    AbstractNonlinearSolveCache{iip}

Abstract solver cache. `iip` is a type-level flag: when it is `true`, the problem function
is called in place as `f(fu, u, p)`; otherwise it is called as `fu = f(u, p)`.
"""
abstract type AbstractNonlinearSolveCache{iip} end

# Extend `SciMLBase.isinplace` so that a cache reports its type-level `iip` flag.
function isinplace(::AbstractNonlinearSolveCache{iip}) where {iip}
    return iip
end
52
"""
    SciMLBase.reinit!(cache::AbstractNonlinearSolveCache, u0 = get_u(cache);
        p = cache.p, abstol = cache.abstol, reltol = cache.reltol,
        maxiters = cache.maxiters, alias_u0 = false, termination_condition = missing,
        kwargs...)

Reset `cache` so it can be solved again from the initial guess `u0` with parameters `p`,
reusing the existing cache instead of building a new one.

Steps:
- Re-evaluate the residual at `u0`. An in-place cache copies `u0` into its own `u`; an
  out-of-place cache stores `u0` (unaliased unless `alias_u0 = true`).
- Reset the trace, if the cache has one.
- Rebuild the termination cache (`tc_cache`) and line-search cache (`ls_cache`) when the
  cache type has those fields. `termination_condition = missing` keeps the mode currently
  stored in `cache.tc_cache`.
- Update tolerances and `maxiters`, and reset all statistics and status flags.
- Call `__reinit_internal!` with every keyword argument, so that specific cache types can
  reset their own state.

Returns the mutated `cache`.
"""
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(cache);
        p = cache.p, abstol = cache.abstol, reltol = cache.reltol,
        maxiters = cache.maxiters, alias_u0 = false, termination_condition = missing,
        kwargs...) where {iip}
    cache.p = p
    if iip
        recursivecopy!(get_u(cache), u0)
        cache.f(get_fu(cache), get_u(cache), p)
    else
        cache.u = __maybe_unaliased(u0, alias_u0)
        set_fu!(cache, cache.f(cache.u, p))
    end

    # Not every cache type carries a trace (`solve!` also looks it up tolerantly through
    # `__getproperty`), so only reset one that is actually present.
    trace = __getproperty(cache, Val(:trace))
    trace !== nothing && reset!(trace)

    # Some algorithms store multiple termination caches
    if hasfield(typeof(cache), :tc_cache)
        # TODO: We need an efficient way to reset this upstream
        tc = termination_condition === missing ? get_termination_mode(cache.tc_cache) :
             termination_condition
        abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, get_fu(cache),
            get_u(cache), tc)
        cache.tc_cache = tc_cache
    end

    if hasfield(typeof(cache), :ls_cache)
        # TODO: A more efficient way to do this
        cache.ls_cache = init_linesearch_cache(cache.alg.linesearch, cache.f,
            get_u(cache), p, get_fu(cache), Val(iip))
    end

    # `uf`, when present, keeps its own copy of the parameters; keep it in sync with `p`.
    hasfield(typeof(cache), :uf) && cache.uf !== nothing && (cache.uf.p = p)

    cache.abstol = abstol
    cache.reltol = reltol
    cache.maxiters = maxiters

    # The residual has been evaluated exactly once above, and no step has been taken yet.
    # `solve!` counts completed steps in `nsteps` (it increments the counter after every
    # `perform_step!` and compares it with `maxiters`), so the counter must restart at 0.
    # Restarting at 1 would silently cost one iteration and over-report `nsteps`. The
    # remaining counters are cleared too, so statistics do not accumulate across solves.
    stats = cache.stats
    stats.nf = 1
    stats.njacs = 0
    stats.nfactors = 0
    stats.nsolve = 0
    stats.nsteps = 0
    cache.force_stop = false
    cache.retcode = ReturnCode.Default

    __reinit_internal!(cache; u0, p, abstol, reltol, maxiters, alias_u0,
        termination_condition, kwargs...)

    return cache
end
99
"""
    __reinit_internal!(cache::AbstractNonlinearSolveCache; kwargs...)

Hook that `SciMLBase.reinit!` calls as its last step, so that specific cache types can
reset their own extra state. It receives the keywords `u0`, `p`, `abstol`, `reltol`,
`maxiters`, `alias_u0` and `termination_condition`, plus any extra keywords passed to
`reinit!`. This generic fallback ignores them and returns `nothing`.
"""
function __reinit_internal!(::AbstractNonlinearSolveCache; kwargs...)
    return nothing
end
101
"""
    show(io::IO, alg::AbstractNonlinearSolveAlgorithm)

Print `alg` as `AlgName(key = Value(), ...)`, listing only the options that are set. The
options appear in this order:
- `ad`
- `linsolve`
- `linesearch` (for a `LineSearch` wrapper, its inner `method`, omitted when that is
  `nothing`)
- any algorithm-specific entries from `__alg_print_modifiers`
- `radius_update_scheme`

Options that are absent or `nothing` are skipped.
"""
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
    parts = String[]

    ad = __getproperty(alg, Val(:ad))
    ad === nothing || push!(parts, "ad = $(nameof(typeof(ad)))()")

    linsolve = __getproperty(alg, Val(:linsolve))
    linsolve === nothing || push!(parts, "linsolve = $(nameof(typeof(linsolve)))()")

    linesearch = __getproperty(alg, Val(:linesearch))
    if linesearch isa LineSearch
        # Show the wrapped method rather than the `LineSearch` wrapper itself.
        linesearch.method === nothing ||
            push!(parts, "linesearch = $(nameof(typeof(linesearch.method)))()")
    elseif linesearch !== nothing
        push!(parts, "linesearch = $(nameof(typeof(linesearch)))()")
    end

    append!(parts, __alg_print_modifiers(alg))

    scheme = __getproperty(alg, Val(:radius_update_scheme))
    scheme === nothing || push!(parts, "radius_update_scheme = $(scheme)")

    print(io, nameof(typeof(alg)), "(", join(parts, ", "), ")")
    return nothing
end
128
"""
    __alg_print_modifiers(alg) -> Vector{String}

Return extra `"key = value"` entries for `Base.show` to append when printing an algorithm.
This fallback adds none; algorithms can specialize it to display additional options.
"""
function __alg_print_modifiers(::Any)
    return String[]
end
130
# StatProfilerHTML annotations for source line 131 (`SciMLBase.__solve`):
#   284 (99 %)
#   568 (197 %) samples spent in __solve
#   284 (50 %) (incl.) when called from #solve_call#34 line 606
#   284 (50 %) (incl.) when called from __solve line 131
#   284 (100 %) samples spent calling #__solve#3
"""
    SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
        alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)

Solve `prob` with `alg` by building a cache with `init` (forwarding `args` and `kwargs`)
and running it to completion with `solve!`. Returns the solution that `solve!` builds.
"""
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
        alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
    # profile (source line 133): 119 (41 %); 119 (100 %) samples spent calling init
    # profile (source line 134): 165 (57 %); 165 (100 %) samples spent calling solve!
    return solve!(init(prob, alg, args...; kwargs...))
end
136
"""
    not_terminated(cache::AbstractNonlinearSolveCache) -> Bool

Return `true` while the solver may take another step: no stop has been forced
(`cache.force_stop` is `false`) and fewer than `cache.maxiters` steps have been counted.
"""
function not_terminated(cache::AbstractNonlinearSolveCache)
    cache.force_stop && return false
    return cache.stats.nsteps < cache.maxiters
end
140
# Accessors for a cache's current iterate `u` and residual `fu`. These defaults read and
# write the `u` / `fu` fields directly; cache types that store them differently can
# overload these methods. The setters return the assigned value.
function get_fu(cache::AbstractNonlinearSolveCache)
    return cache.fu
end

function set_fu!(cache::AbstractNonlinearSolveCache, fu)
    cache.fu = fu
    return fu
end

function get_u(cache::AbstractNonlinearSolveCache)
    return cache.u
end

function SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u)
    cache.u = u
    return u
end
145
146
# StatProfilerHTML annotations for source line 146 (`SciMLBase.solve!`):
#   165 (57 %) samples spent in solve!
#   165 (100 %) (incl.) when called from #__solve#3 line 134
"""
    SciMLBase.solve!(cache::AbstractNonlinearSolveCache)

Call `perform_step!` repeatedly while `not_terminated(cache)` holds, counting each
completed step in `cache.stats.nsteps`, then build and return the solution.

If the solver left `cache.retcode` at `ReturnCode.Default`, the retcode is set as follows:
- `ReturnCode.MaxIters` when the step counter reached (or exceeded) `cache.maxiters`;
- `ReturnCode.Success` otherwise.

A retcode chosen by the solver itself is kept. When the cache carries a trace, the trace
receives a final entry before the solution is built.
"""
function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
    while not_terminated(cache)
        # profile (source line 148): 165 (57 %); 165 (100 %) samples spent calling
        # perform_step!
        perform_step!(cache)
        cache.stats.nsteps += 1
    end

    # The solver might have set a different `retcode`
    if cache.retcode == ReturnCode.Default
        # `>=` rather than `==`: the counter can already be at or past the budget before
        # any step runs (e.g. `maxiters <= 0`, or a counter that was not reset to zero).
        # In that case no step was taken, and reporting `Success` would be wrong.
        if cache.stats.nsteps >= cache.maxiters
            cache.retcode = ReturnCode.MaxIters
        else
            cache.retcode = ReturnCode.Success
        end
    end

    trace = __getproperty(cache, Val{:trace}())
    if trace !== nothing
        update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
            nothing, nothing; last = Val(true))
    end

    return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
        cache.retcode, cache.stats, trace)
end
170
171 include("utils.jl")
172 include("trace.jl")
173 include("extension_algs.jl")
174 include("linesearch.jl")
175 include("raphson.jl")
176 include("trustRegion.jl")
177 include("levenberg.jl")
178 include("gaussnewton.jl")
179 include("dfsane.jl")
180 include("pseudotransient.jl")
181 include("broyden.jl")
182 include("klement.jl")
183 include("lbroyden.jl")
184 include("jacobian.jl")
185 include("ad.jl")
186 include("default.jl")
187
# Precompilation workload: solve a few small problems with the listed algorithms while the
# package image is being generated.
@setup_workload begin
    # u .^ 2 .- p = 0 in three forms: scalar out-of-place, vector out-of-place and vector
    # in-place. Each form is built for Float32 and for Float64 (Float32 first).
    square_roots = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
        (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
        (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
    probs_nls = NonlinearProblem[NonlinearProblem(fn, T.(u0), T(2))
                                 for T in (Float32, Float64)
                                 for (fn, u0) in square_roots]

    nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
        Broyden(), Klement(), DFSane(), nothing)

    # Least-squares problems with 2 unknowns: a 1-residual and a 4-residual system, each
    # out-of-place and in-place. They are built first in Float64 ...
    nlls_f64 = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
        (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
        (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
                resid_prototype = zeros(1)), [0.1, 0.0]),
        (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
                resid_prototype = zeros(4)), [0.1, 0.1]))
    # ... and then again in Float32.
    nlls_f32 = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
        (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
            Float32[0.1, 0.1]),
        (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
                resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
        (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
                resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))

    probs_nlls = NonlinearLeastSquaresProblem[]
    append!(probs_nlls, (NonlinearLeastSquaresProblem(fn, u0, 2.0) for (fn, u0) in nlls_f64))
    append!(probs_nlls,
        (NonlinearLeastSquaresProblem(fn, u0, 2.0f0) for (fn, u0) in nlls_f32))

    nlls_algs = (LevenbergMarquardt(), GaussNewton(),
        LevenbergMarquardt(; linsolve = LUFactorization()),
        GaussNewton(; linsolve = LUFactorization()))

    @compile_workload begin
        for prob in probs_nls, alg in nls_algs
            solve(prob, alg; abstol = 1e-2)
        end
        for prob in probs_nlls, alg in nlls_algs
            solve(prob, alg; abstol = 1e-2)
        end
    end
end
234
235 export RadiusUpdateSchemes
236
237 export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
238 Broyden, Klement, LimitedMemoryBroyden
239 export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL
240 export NonlinearSolvePolyAlgorithm,
241 RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
242
243 export LineSearch, LiFukushimaLineSearch
244
245 # Export the termination conditions from DiffEqBase
246 export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode,
247 NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode,
248 AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode,
249 RelSafeBestTerminationMode, AbsSafeBestTerminationMode
250
251 # Tracing Functionality
252 export TraceAll, TraceMinimal, TraceWithJacobianConditionNumber
253
254 end # module