StatProfilerHTML.jl report
Generated on Thu, 21 Dec 2023 12:59:22
File source code
Line Exclusive Inclusive Code
# Local alias for the default norm constant provided by DiffEqBase.
1 const DEFAULT_NORM = DiffEqBase.NONLINEARSOLVE_DEFAULT_NORM
2
# Minimal mutable stand-in for a linear-solve cache: it stores the operator `A` and the
# right-hand side `b`, which the matching `dolinsolve` method solves directly via `A \ b`.
3 @concrete mutable struct FakeLinearSolveJLCache
4 A
5 b
6 end
7
# Result returned by the `FakeLinearSolveJLCache` method of `dolinsolve`: the cache that
# was solved with (`cache`) and the computed solution (`u`).
8 @concrete struct FakeLinearSolveJLResult
9 cache
10 u
11 end
12
"""
    __findmin(f, x)

Return `(value, index)` for the element of `x` that minimizes `f`, ignoring `NaN`s: any
`NaN` produced by `f` is treated as `+Inf`.

The `Inf` substitute is created with the same type as `f(xᵢ)`, so the type of the
returned value does not depend on whether a `NaN` was encountered (e.g. `Float32` inputs
yield `Inf32` rather than a `Float64` `Inf`).
"""
function __findmin(f, x)
    return findmin(x) do xᵢ
        fx = f(xᵢ)
        return isnan(fx) ? oftype(fx, Inf) : fx
    end
end
20
# Singleton marker type used to build the ForwardDiff tags handled in this file.
21 struct NonlinearSolveTag end
22
# Tag-check override: for any tag built from `NonlinearSolveTag` whose value type matches
# the array element type, report the tag as valid regardless of the function passed.
function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:NonlinearSolveTag, <:T}},
        ::F, ::AbstractArray{T}) where {T, F}
    return true
end
27
"""
    value_derivative(f, x)

Evaluate `f` once on a ForwardDiff dual number seeded at `x` and return the tuple
`(f(x), d/dx f(x))`.
"""
function value_derivative(f::F, x::R) where {F, R}
    TagT = typeof(ForwardDiff.Tag(f, R))
    seeded = ForwardDiff.Dual{TagT}(x, one(x))
    result = f(seeded)
    return ForwardDiff.value(result), ForwardDiff.extract_derivative(TagT, result)
end
38
# Primal-value extraction: dual numbers (and arrays of them) are reduced to their
# ForwardDiff values, anything else is returned untouched.
@inline function value(x)
    return x
end
@inline function value(d::Dual)
    return ForwardDiff.value(d)
end
@inline function value(ds::AbstractArray{<:Dual})
    return map(ForwardDiff.value, ds)
end
42
# `vec` that leaves numbers and vectors as they are (no reshape wrapper is created).
@inline function _vec(v)
    return vec(v)
end
@inline function _vec(s::Number)
    return s
end
@inline function _vec(flat::AbstractVector)
    return flat
end
46
# `restructure` wrapper; for a pair of numbers the second argument is returned directly.
@inline function _restructure(y, x)
    return restructure(y, x)
end
@inline function _restructure(::Number, x::Number)
    return x
end
49
# Default preconditioner factory: supplies neither a left nor a right preconditioner.
function DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata)
    return (nothing, nothing)
end
51
# Direct solve for the `FakeLinearSolveJLCache` stand-in.
#
# `A` and `b`, when given, replace the values stored in `linsolve`; the system is then
# solved with `linsolve.A \ linsolve.b` and wrapped in a `FakeLinearSolveJLResult`.
# All other keyword arguments are accepted for signature compatibility with the
# LinearSolve-backed method and are ignored here.
#
# Statistics: `nsolve` is always incremented. `nfactors` is incremented unless the
# operator that is actually solved with is a plain number. The check is made on the
# stored operator *after* the optional update, so a call with `A = nothing` on a cache
# holding a scalar no longer counts a factorization.
function dolinsolve(cache, precs::P, linsolve::FakeLinearSolveJLCache; A = nothing,
        linu = nothing, b = nothing, du = nothing, p = nothing, weight = nothing,
        cachedata = nothing, reltol = nothing, reuse_A_if_factorization = false) where {P}
    A !== nothing && (linsolve.A = A)
    b !== nothing && (linsolve.b = b)

    # Update Statistics
    cache.stats.nsolve += 1
    cache.stats.nfactors += !(linsolve.A isa Number)

    linres = linsolve.A \ linsolve.b
    return FakeLinearSolveJLResult(linsolve, linres)
end
64
65 69 (24 %)
69 (24 %) samples spent in dolinsolve
69 (100 %) (incl.) when called from perform_step! line 146
69 (100 %) samples spent calling #dolinsolve#7
# Solve the linear system held by `linsolve`, after optionally replacing its operator
# `A`, right-hand side `b`, initial guess `u` and preconditioners. Returns whatever
# `solve!` returns. `cache.stats.nsolve` / `cache.stats.nfactors` are updated on the way.
function dolinsolve(cache, precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
66 du = nothing, p = nothing, weight = nothing, cachedata = nothing, reltol = nothing,
67 reuse_A_if_factorization = false) where {P}
68 # Update Statistics
69 cache.stats.nsolve += 1
# Provisionally count one factorization; the branches below take it back whenever the
# stored operator is not replaced.
70 cache.stats.nfactors += 1
71
72 # Some Algorithms would reuse factorization but it causes the cache to not reset in
73 # certain cases
74 if A !== nothing
# `alg` is `nothing` when `linsolve` has no `alg` field (see `__getproperty`).
75 alg = __getproperty(linsolve, Val(:alg))
76 if alg !== nothing && ((alg isa LinearSolve.AbstractFactorization) ||
77 (alg isa LinearSolve.DefaultLinearSolver && !(alg ==
78 LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES))))
79 # Factorization Algorithm
80 if reuse_A_if_factorization
81 cache.stats.nfactors -= 1
82 else
83 linsolve.A = A
84 end
85 else
86 linsolve.A = A
87 end
88 else
89 cache.stats.nfactors -= 1
90 end
91 b !== nothing && (linsolve.b = b)
92 linu !== nothing && (linsolve.u = linu)
93
# When the stored preconditioner is a `ComposePreconditioner`, only its `outer` part is
# handed back to `precs` as the "previous" preconditioner.
94 Plprev = linsolve.Pl isa ComposePreconditioner ? linsolve.Pl.outer : linsolve.Pl
95 Prprev = linsolve.Pr isa ComposePreconditioner ? linsolve.Pr.outer : linsolve.Pr
96
# The fifth argument is passed as `nothing`, and `A !== nothing` tells `precs` whether a
# new operator was supplied in this call.
97 _Pl, _Pr = precs(linsolve.A, du, linu, p, nothing, A !== nothing, Plprev, Prprev,
98 cachedata)
99 if (_Pl !== nothing || _Pr !== nothing)
# Without an explicit `weight`, reuse the diagonal stored in the current right
# preconditioner (directly, or through its `inner` field).
100 _weight = weight === nothing ?
101 (linsolve.Pr isa Diagonal ? linsolve.Pr.diag : linsolve.Pr.inner.diag) :
102 weight
103 Pl, Pr = wrapprecs(_Pl, _Pr, _weight)
104 linsolve.Pl = Pl
105 linsolve.Pr = Pr
106 end
107
108 69 (24 %)
69 (24 %) samples spent in #dolinsolve#7
69 (100 %) (incl.) when called from dolinsolve line 65
69 (100 %) samples spent calling solve!
# `reltol` is only forwarded when the caller provided one.
linres = reltol === nothing ? solve!(linsolve) : solve!(linsolve; reltol)
109
110 return linres
111 end
112
# Combine user preconditioners with the diagonal weight scaling.
#
# The left side always contains `InvPreconditioner(Diagonal(weight))` and the right side
# `Diagonal(weight)`; when a user preconditioner is present it is composed with that
# scaling through `ComposePreconditioner`. Returns `(Pl, Pr)`.
function wrapprecs(_Pl, _Pr, weight)
    left_scaling = InvPreconditioner(Diagonal(_vec(weight)))
    Pl = _Pl === nothing ? left_scaling : ComposePreconditioner(left_scaling, _Pl)

    right_scaling = Diagonal(_vec(weight))
    Pr = _Pr === nothing ? right_scaling : ComposePreconditioner(right_scaling, _Pr)

    return Pl, Pr
end
128
# Read the `CJ` type parameter of a Newton-type algorithm; anything else yields `nothing`.
function concrete_jac(_)
    return nothing
end
function concrete_jac(::AbstractNewtonAlgorithm{CJ}) where {CJ}
    return CJ
end
131
# Zero-valued counterpart of `x` that can be written into.
# For an immutable `SArray` this is an `MArray` of the same shape filled with zeros.
# (Previously the `SArray` method returned a mutable *copy of `x`*, i.e. not zeros,
# which disagreed with the generic method.)
_mutable_zero(x) = zero(x)
_mutable_zero(x::SArray) = MArray(zero(x))
134
# Mutable version of `x`: an `SArray` is converted to an `MArray`, everything else is
# returned as is.
function _mutable(x)
    return x
end
function _mutable(sx::SArray)
    return MArray(sx)
end
137
# Make `x` mutable only for AD backends that need it.
# (Disabled: `__maybe_mutable(x, ::AbstractFiniteDifferencesMode) = _mutable(x)`)
# The shadow allocated for Enzyme needs to be mutable
function __maybe_mutable(x, ::AutoSparseEnzyme)
    return _mutable(x)
end
function __maybe_mutable(x, _)
    return x
end
142
# Helper function to get value of `f(u, p)`
#
# For out-of-place problems this simply returns `prob.f(u, prob.p)`. For in-place
# problems a residual buffer is created, filled by `prob.f(fu, u, prob.p)` and returned.
# The buffer is always a fresh array:
#   * `similar(u)` when no `resid_prototype` is available;
#   * otherwise a copy of `resid_prototype` whose element type is promoted with that of
#     `u`, so the user-supplied prototype is never written to (it used to be mutated and
#     returned, aliasing the results of successive calls) and wider element types of `u`
#     can be stored.
function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
        NonlinearLeastSquaresProblem{uType, iip}}, u) where {uType, iip}
    @unpack f, p = prob
    iip || return f(u, p)

    proto = f.resid_prototype
    fu = if proto === nothing
        similar(u)
    else
        promote_type(eltype(u), eltype(proto)).(proto)
    end
    f(fu, u, p)
    return fu
end
155
# Evaluate `f` at `(u, p)`.
# In-place (`Val(true)`): calls `f(fu, u, p)` and returns the supplied buffer `fu`.
# Out-of-place (`Val(false)`): returns `f(u, p)`; `fu` is not used.
function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip}
    iip || return f(u, p)
    f(fu, u, p)
    return fu
end
164
# Evaluate the problem function stored in `cache` at `(u, p)` and store the residual in
# the cache. Always increments `cache.stats.nf` and returns `nothing`.
#
# With the default `Val(nothing)` the residual goes to the cache's primary `fu`
# (`get_fu` / `set_fu!`); otherwise `FUSYM` names the cache field that receives it.
function evaluate_f(cache::AbstractNonlinearSolveCache, u, p,
        fu_sym::Val{FUSYM} = Val(nothing)) where {FUSYM}
    cache.stats.nf += 1
    if isinplace(cache)
        if FUSYM === nothing
            cache.prob.f(get_fu(cache), u, p)
        else
            cache.prob.f(__getproperty(cache, fu_sym), u, p)
        end
    elseif FUSYM === nothing
        set_fu!(cache, cache.prob.f(u, p))
    else
        setproperty!(cache, FUSYM, cache.prob.f(u, p))
    end
    return nothing
end
183
# Concretize Algorithms
#
# Algorithms without an `ad` field, or whose `ad` is already an `ADTypes.AbstractADType`,
# are returned unchanged. Everything else is forwarded to `__get_concrete_algorithm`,
# which picks the default AD and may be extended for specific algorithms.
function get_concrete_algorithm(alg, prob)
    if !hasfield(typeof(alg), :ad) || alg.ad isa ADTypes.AbstractADType
        return alg
    end
    return __get_concrete_algorithm(alg, prob)
end
194
# Choose a default AD backend for `alg` and attach it with `set_ad`.
#
# A sparse backend is chosen when the problem function carries a `sparsity` or a
# `jac_prototype`. ForwardDiff (tagged with `NonlinearSolveTag`) is used when the
# element type of `u0` supports dual numbers, finite differencing otherwise.
function __get_concrete_algorithm(alg, prob)
    @unpack sparsity, jac_prototype = prob.f
    wants_sparse = !(sparsity === nothing && jac_prototype === nothing)
    elT = eltype(prob.u0)
    if ForwardDiff.can_dual(elT)
        backend_ctor = wants_sparse ? AutoSparseForwardDiff : AutoForwardDiff
        backend = backend_ctor(; tag = ForwardDiff.Tag(NonlinearSolveTag(), elT))
    else
        # Use Finite Differencing
        backend = wants_sparse ? AutoSparseFiniteDiff() : AutoFiniteDiff()
    end
    return set_ad(alg, backend)
end
207
# Build the termination cache; returns `(abstol, reltol, tc_cache)` with the tolerances
# read back from the created cache. Passing `nothing` as the mode selects
# `AbsSafeBestTerminationMode()`.
function init_termination_cache(abstol, reltol, du, u, ::Nothing)
    default_mode = AbsSafeBestTerminationMode()
    return init_termination_cache(abstol, reltol, du, u, default_mode)
end
function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
    tc_cache = init(du, u, tc; abstol, reltol)
    resolved_abstol = DiffEqBase.get_abstol(tc_cache)
    resolved_reltol = DiffEqBase.get_reltol(tc_cache)
    return resolved_abstol, resolved_reltol, tc_cache
end
215
# Convenience entry points: fill in the termination cache stored on `cache`, then the
# termination mode stored on that termination cache, and dispatch on the mode.
function check_and_update!(cache, fu, u, uprev)
    tc_cache = cache.tc_cache
    return check_and_update!(tc_cache, cache, fu, u, uprev)
end
function check_and_update!(tc_cache, cache, fu, u, uprev)
    mode = DiffEqBase.get_termination_mode(tc_cache)
    return check_and_update!(tc_cache, cache, fu, u, uprev, mode)
end
# Plain termination modes: when the termination check fires, re-evaluate the residual at
# `u` and flag the cache to stop. Returns `true` in that case and `nothing` otherwise.
function check_and_update!(tc_cache, cache, fu, u, uprev,
        mode::AbstractNonlinearTerminationMode)
    tc_cache(fu, u, uprev) || return nothing

    # Just a sanity measure!
    if isinplace(cache)
        cache.prob.f(get_fu(cache), u, cache.prob.p)
    else
        set_fu!(cache, cache.prob.f(u, cache.prob.p))
    end
    cache.force_stop = true
    return true
end
# Safe termination modes: when the termination check fires, translate the termination
# cache's return code into `cache.retcode`, re-evaluate the residual at `u` and flag the
# cache to stop. Returns `true` in that case and `nothing` otherwise.
function check_and_update!(tc_cache, cache, fu, u, uprev,
        mode::AbstractSafeNonlinearTerminationMode)
    tc_cache(fu, u, uprev) || return nothing

    status = tc_cache.retcode
    if status == NonlinearSafeTerminationReturnCode.Success
        cache.retcode = ReturnCode.Success
    elseif status == NonlinearSafeTerminationReturnCode.PatienceTermination
        cache.retcode = ReturnCode.ConvergenceFailure
    elseif status == NonlinearSafeTerminationReturnCode.ProtectiveTermination
        cache.retcode = ReturnCode.Unstable
    end

    # Just a sanity measure!
    if isinplace(cache)
        cache.prob.f(get_fu(cache), u, cache.prob.p)
    else
        set_fu!(cache, cache.prob.f(u, cache.prob.p))
    end
    cache.force_stop = true
    return true
end
# Safe-best termination modes: same return-code translation as the safe modes, but the
# cache's `u` is replaced by the `u` stored on the termination cache and the residual is
# recomputed there before flagging the cache to stop. Returns `true` when the check
# fired and `nothing` otherwise.
function check_and_update!(tc_cache, cache, fu, u, uprev,
        mode::AbstractSafeBestNonlinearTerminationMode)
    tc_cache(fu, u, uprev) || return nothing

    status = tc_cache.retcode
    if status == NonlinearSafeTerminationReturnCode.Success
        cache.retcode = ReturnCode.Success
    elseif status == NonlinearSafeTerminationReturnCode.PatienceTermination
        cache.retcode = ReturnCode.ConvergenceFailure
    elseif status == NonlinearSafeTerminationReturnCode.ProtectiveTermination
        cache.retcode = ReturnCode.Unstable
    end

    if isinplace(cache)
        copyto!(get_u(cache), tc_cache.u)
        cache.prob.f(get_fu(cache), get_u(cache), cache.prob.p)
    else
        set_u!(cache, tc_cache.u)
        set_fu!(cache, cache.prob.f(get_u(cache), cache.prob.p))
    end
    cache.force_stop = true
    return true
end
278
# Initial "identity" Jacobian scaled by `α`.
#
# Scalar `u`: returns `α` converted to the type of `u`.
# Otherwise: allocates a `length(fu) × length(u)` matrix (element type promoted from
# `fu` and `u`), zeroes it and writes `α` on its main diagonal.
#
# The matrix can be rectangular, so the scalar-indexing loop only visits the
# `min(rows, cols)` diagonal entries that exist. Looping over every row, as before,
# wrote out of bounds under `@inbounds` whenever `length(fu) > length(u)`.
@inline __init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α)
@inline @views function __init_identity_jacobian(u, fu, α = true)
    J = similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u))
    fill!(J, zero(eltype(J)))
    if fast_scalar_indexing(J)
        @inbounds for i in 1:min(size(J, 1), size(J, 2))
            J[i, i] = α
        end
    else
        J[diagind(J)] .= α
    end
    return J
end
# Static-array versions: the result has `prod(Size(fu))` rows and `prod(Size(u))`
# columns and is built from `I * α`. Generic static arrays give an `MArray`; a pair of
# `SArray`s gives an `SArray`.
@inline function __init_identity_jacobian(u::StaticArray, fu::StaticArray, α = true)
    T = promote_type(eltype(fu), eltype(u))
    nrows, ncols = prod(Size(fu)), prod(Size(u))
    return MArray{Tuple{nrows, ncols}, T}(I * α)
end
@inline function __init_identity_jacobian(u::SArray, fu::SArray, α = true)
    T = promote_type(eltype(fu), eltype(u))
    nrows, ncols = prod(Size(fu)), prod(Size(u))
    return SArray{Tuple{nrows, ncols}, T}(I * α)
end
300
# Reset a scalar or diagonal-as-vector Jacobian to `α`: numbers are rebuilt with the type
# of `J`, vectors are filled in place and returned.
@inline function __reinit_identity_jacobian!!(J::Number, α = true)
    return oftype(J, α)
end
@inline function __reinit_identity_jacobian!!(J::AbstractVector, α = true)
    return fill!(J, α)
end
# Reset the matrix `J` in place to `α` times the identity pattern (zeros everywhere,
# `α` on the main diagonal) and return it.
#
# `J` may be rectangular, so the scalar-indexing loop only visits the `min(rows, cols)`
# diagonal entries that exist. Looping over every row, as before, wrote out of bounds
# under `@inbounds` for matrices with more rows than columns.
@inline @views function __reinit_identity_jacobian!!(J::AbstractMatrix, α = true)
    fill!(J, zero(eltype(J)))
    if fast_scalar_indexing(J)
        @inbounds for i in 1:min(size(J, 1), size(J, 2))
            J[i, i] = α
        end
    else
        J[diagind(J)] .= α
    end
    return J
end
# Immutable static arrays cannot be reset in place, so a new array of the same size and
# element type is returned: all `α` for an `SVector`, `α` times identity for an `SMatrix`.
@inline function __reinit_identity_jacobian!!(J::SVector, α = true)
    len = Size(J)[1]
    return ones(SArray{Tuple{len}, eltype(J)}) .* α
end
@inline function __reinit_identity_jacobian!!(J::SMatrix, α = true)
    dims = Size(J)
    return SArray{Tuple{dims[1], dims[2]}, eltype(J)}(I) .* α
end
321
# Uninitialized low-rank factors for static arrays. Returns `(U, Vᵀ)` as `MArray`s of
# sizes `prod(Size(fu)) × threshold` and `threshold × prod(Size(u))`, with the element
# type promoted from both inputs.
function __init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2},
        ::Val{threshold}) where {S1, S2, T1, T2, threshold}
    T = promote_type(T1, T2)
    n_fu, n_u = prod(Size(fu)), prod(Size(u))
    U = MArray{Tuple{n_fu, threshold}, T}(undef)
    Vᵀ = MArray{Tuple{threshold, n_u}, T}(undef)
    return U, Vᵀ
end
# Uninitialized low-rank factors for generic arrays. Returns `(U, Vᵀ)`, both allocated
# with `similar(u, ...)`: `U` is `length(fu) × threshold`, `Vᵀ` is `threshold × length(u)`.
function __init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
    n_u, n_fu = length(u), length(fu)
    U = similar(u, n_fu, threshold)
    Vᵀ = similar(u, threshold, n_u)
    return U, Vᵀ
end
335
# Ill-conditioning checks:
#   * numbers: exactly zero;
#   * matrices: condition number at least `1 / eps(T)^(1//2)` for the real element type;
#   * vectors: any zero entry;
#   * anything else: never considered ill-conditioned.
@inline function __is_ill_conditioned(x::Number)
    return iszero(x)
end
@inline function __is_ill_conditioned(x::AbstractMatrix)
    κ = cond(x)
    return κ ≥ inv(eps(real(eltype(x)))^(1 // 2))
end
@inline function __is_ill_conditioned(x::AbstractVector)
    return any(iszero, x)
end
@inline function __is_ill_conditioned(x)
    return false
end
341
# Safe getproperty: yields `s.X` when the type of `s` has a field `X`, and `nothing`
# otherwise. The field check happens while the method body is generated.
@generated function __getproperty(s::S, ::Val{X}) where {S, X}
    return hasfield(S, X) ? :(s.$X) : :(nothing)
end
347
# Non-square matrix
# Scalar problems always need a square `A`; otherwise the answer comes from LinearSolve
# for the algorithm's `linsolve`.
@inline function __needs_square_A(_, ::Number)
    return true
end
@inline function __needs_square_A(alg, _)
    return LinearSolve.needs_square_A(alg.linsolve)
end
351
# Define special concatenation for certain Array combinations
# (the generic fallback is plain `vcat`).
@inline function _vcat(x, y)
    return vcat(x, y)
end
354
# LazyArrays for tracing
# `__zero` zeroes arrays and passes every other value through unchanged.
function __zero(x::AbstractArray)
    return zero(x)
end
function __zero(x)
    return x
end
# A lazily applied `__zero` reports the same element type, rank, size and axes as its
# argument.
LazyArrays.applied_eltype(::typeof(__zero), arg) = eltype(arg)
LazyArrays.applied_ndims(::typeof(__zero), arg) = ndims(arg)
LazyArrays.applied_size(::typeof(__zero), arg) = size(arg)
LazyArrays.applied_axes(::typeof(__zero), arg) = axes(arg)
362
# Safe Inverse: Try to use `inv` but if lu fails use `pinv`
# Numbers and generic matrices go straight to `pinv`; a vector is inverted as the
# diagonal of a `Diagonal`; a lazy `ApplyArray` is materialized first.
@inline function __safe_inv(A::Number)
    return pinv(A)
end
@inline function __safe_inv(A::AbstractMatrix)
    return pinv(A)
end
@inline function __safe_inv(A::AbstractVector)
    return __safe_inv(Diagonal(A)).diag
end
@inline function __safe_inv(A::ApplyArray)
    return __safe_inv(A.f(A.args...))
end
# Inverse of a square strided matrix with a pseudo-inverse fallback.
#   * triangular input with a nonzero diagonal: triangular inverse;
#   * otherwise-general input: LU-based inverse when the factorization succeeds;
#   * every remaining case (singular triangular, failed LU): `pinv(A)`.
# Throws (via `checksquare`) when `A` is not square.
@inline function __safe_inv(A::StridedMatrix{T}) where {T}
    LinearAlgebra.checksquare(A)
    if istriu(A)
        upper = UpperTriangular(A)
        if !any(iszero, @view(upper[diagind(upper)]))
            return triu!(parent(inv(upper)))
        end
    elseif istril(A)
        lower = LowerTriangular(A)
        if !any(iszero, @view(lower[diagind(lower)]))
            return tril!(parent(inv(lower)))
        end
    else
        fact = lu(A; check = false)
        if issuccess(fact)
            inverse = LinearAlgebra.inv!(fact)
            return convert(typeof(parent(inverse)), inverse)
        end
    end
    return pinv(A)
end
# Sparse input is densified before inversion.
@inline function __safe_inv(A::SparseMatrixCSC)
    return __safe_inv(Matrix(A))
end

# A lazily applied `__safe_inv` reports the same element type, rank, size and axes as
# its argument.
LazyArrays.applied_eltype(::typeof(__safe_inv), arg) = eltype(arg)
LazyArrays.applied_ndims(::typeof(__safe_inv), arg) = ndims(arg)
LazyArrays.applied_size(::typeof(__safe_inv), arg) = size(arg)
LazyArrays.applied_axes(::typeof(__safe_inv), arg) = axes(arg)
393
# SparseAD --> NonSparseAD
# Map each sparse AD selector to its dense counterpart; other backends pass through.
@inline function __get_nonsparse_ad(::AutoSparseForwardDiff)
    return AutoForwardDiff()
end
@inline function __get_nonsparse_ad(::AutoSparseFiniteDiff)
    return AutoFiniteDiff()
end
@inline function __get_nonsparse_ad(::AutoSparseZygote)
    return AutoZygote()
end
@inline function __get_nonsparse_ad(ad)
    return ad
end
399
# Use Symmetric Matrices if known to be efficient
# The generic method wraps its argument in `Symmetric`; the types below are returned
# unwrapped.
@inline function __maybe_symmetric(x)
    return Symmetric(x)
end
@inline function __maybe_symmetric(x::Number)
    return x
end
## LinearSolve with `nothing` doesn't dispatch correctly here
@inline function __maybe_symmetric(x::StaticArray)
    return x
end
@inline function __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix)
    return x
end
@inline function __maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator)
    return x
end
407
# Unalias
# Numbers and `SArray`s are returned as they are. Other arrays are deep-copied unless
# aliasing was requested or the array type does not support `setindex!`.
@inline __maybe_unaliased(x::Union{Number, SArray}, ::Bool) = x
@inline function __maybe_unaliased(x::AbstractArray, alias::Bool)
    # Spend time copying iff we will mutate the array
    needs_copy = !alias && can_setindex(typeof(x))
    return needs_copy ? deepcopy(x) : x
end
415
# Init ones
# Generic arrays: allocate with `similar` and fill recursively with `true`.
# Static arrays: `ones` of the same type.
@inline function __init_ones(x)
    filled = similar(x)
    recursivefill!(filled, true)
    return filled
end
@inline function __init_ones(x::StaticArray)
    return ones(typeof(x))
end
423
# Diagonal of type `u`
# Scalar `u`: `v` converted to the type of `u`.
# `SArray` `u`: a `Diagonal` whose entries are `v` times a static vector of ones.
function __init_diagonal(u::Number, v)
    return oftype(u, v)
end
function __init_diagonal(u::SArray, v)
    flat = vec(u)
    return Diagonal(ones(typeof(flat)) * v)
end
# Generic arrays: a `Diagonal` over a new vector with one entry per element of `u`, with
# `v` broadcast into it.
function __init_diagonal(u, v)
    entries = broadcast!(identity, similar(vec(u)), v)
    return Diagonal(entries)
end
435
# Reduce sum
#
# Squared column norms of `J`, i.e. the diagonal of `JᵀJ`.
#
# When `y` supports `setindex!` the result is accumulated into `y` with
# `sum!(abs2, y, J')`, which reduces `J'` along its second dimension, and `y` is
# returned. Otherwise a new array is returned.
#
# The out-of-place branch used to reduce along `dims = 1`, which produced the squared
# *row* norms of `J` as a `1 × size(J, 1)` array and so disagreed with the in-place
# branch. It now reduces along `dims = 2` and flattens the result, giving one entry per
# column of `J` just like the in-place branch.
function __sum_JᵀJ!!(y, J)
    if setindex_trait(y) === CanSetindex()
        sum!(abs2, y, J')
        return y
    else
        return vec(sum(abs2, J'; dims = 2))
    end
end
445
# Alpha for Initial Jacobian Guess
# The values are somewhat different from SciPy, these were tuned to the 23 test problems
#
# 4-argument form: with a numeric `α`, returns `inv(α)` in the promoted element type of
# `u` and `fu`. With `nothing`, returns `max(norm(u), 1) / (2 norm(fu))` when
# `norm(fu) ≥ 1e-5`, and one otherwise.
# 5-argument form: keeps the supplied `inv_α` when `α` is a number and recomputes the
# automatic value when `α` is `nothing`.
@inline function __initial_inv_alpha(α::Number, u, fu, norm::F) where {F}
    T = promote_type(eltype(u), eltype(fu))
    return convert(T, inv(α))
end
@inline function __initial_inv_alpha(::Nothing, u, fu, norm::F) where {F}
    norm_fu = norm(fu)
    scaled = max(norm(u), true) / (2 * norm_fu)
    unit = convert(promote_type(eltype(u), eltype(fu)), true)
    return ifelse(norm_fu ≥ 1e-5, scaled, unit)
end
@inline function __initial_inv_alpha(inv_α, α::Number, u, fu, norm::F) where {F}
    return inv_α
end
@inline function __initial_inv_alpha(inv_α, α::Nothing, u, fu, norm::F) where {F}
    return __initial_inv_alpha(α, u, fu, norm)
end
460
# 4-argument form: with a numeric `α`, returns `α` in the promoted element type of `u`
# and `fu`. With `nothing`, returns `max(norm(u), 1) / (2 norm(fu))` when
# `1e-5 ≤ norm(fu) ≤ 1e5`, and one otherwise.
# 5-argument form: keeps the supplied `α_initial` when `α` is a number and recomputes the
# automatic value when `α` is `nothing`.
@inline function __initial_alpha(α::Number, u, fu, norm::F) where {F}
    T = promote_type(eltype(u), eltype(fu))
    return convert(T, α)
end
@inline function __initial_alpha(::Nothing, u, fu, norm::F) where {F}
    norm_fu = norm(fu)
    scaled = max(norm(u), true) / (2 * norm_fu)
    unit = convert(promote_type(eltype(u), eltype(fu)), true)
    return ifelse(1e-5 ≤ norm_fu ≤ 1e5, scaled, unit)
end
@inline function __initial_alpha(α_initial, α::Number, u, fu, norm::F) where {F}
    return α_initial
end
@inline function __initial_alpha(α_initial, α::Nothing, u, fu, norm::F) where {F}
    return __initial_alpha(α, u, fu, norm)
end
473
# Diagonal
# Extract the main diagonal of `J_full` into `J`.
#   * vector `J` that supports `setindex!`: filled in place and returned;
#   * vector `J` that does not: a new diagonal is returned via `__diag`;
#   * other array `J`: handled through its flattened form and restructured back;
#   * numbers: `J_full` itself.
@inline function __get_diagonal!!(J::AbstractVector, J_full::AbstractMatrix)
    can_setindex(J) || return __diag(J_full)
    if fast_scalar_indexing(J)
        @inbounds for idx in eachindex(J)
            J[idx] = J_full[idx, idx]
        end
    else
        J .= view(J_full, diagind(J_full))
    end
    return J
end
@inline function __get_diagonal!!(J::AbstractArray, J_full::AbstractMatrix)
    flat_diag = __get_diagonal!!(_vec(J), J_full)
    return _restructure(J, flat_diag)
end
@inline function __get_diagonal!!(J::Number, J_full::Number)
    return J_full
end
493
# `diag` that passes vectors and numbers through unchanged.
@inline function __diag(x::AbstractMatrix)
    return diag(x)
end
@inline function __diag(x::AbstractVector)
    return x
end
@inline function __diag(x::Number)
    return x
end