StatProfilerHTML.jl report
Generated on Thu, 21 Dec 2023 12:59:22
File source code
Line Exclusive Inclusive Code
# `DefaultLinearSolver` always reports that it needs a concrete `A`.
function needs_concrete_A(::DefaultLinearSolver)
    return true
end
"""
    DefaultLinearSolverInit

Holds one cache per algorithm that `DefaultLinearSolver` may dispatch to; every
field has its own type parameter. The generated
`init_cacheval(::DefaultLinearSolver, ...)` constructs it positionally, so the
field order below must match the order in which those caches are emitted.
"""
mutable struct DefaultLinearSolverInit{
        T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
        T11, T12, T13, T14, T15, T16, T17, T18, T19}
    LUFactorization::T1
    QRFactorization::T2
    DiagonalFactorization::T3
    DirectLdiv!::T4
    SparspakFactorization::T5
    KLUFactorization::T6
    UMFPACKFactorization::T7
    KrylovJL_GMRES::T8
    GenericLUFactorization::T9
    RFLUFactorization::T10
    LDLtFactorization::T11
    BunchKaufmanFactorization::T12
    CHOLMODFactorization::T13
    SVDFactorization::T14
    CholeskyFactorization::T15
    NormalCholeskyFactorization::T16
    AppleAccelerateLUFactorization::T17
    MKLLUFactorization::T18
    QRFactorizationPivoted::T19
end
24
25 # Legacy fallback
26 # For SciML algorithms already using `defaultalg`, all assume square matrix.
function defaultalg(A, b)
    # Legacy two-argument form: callers of it assume a square `A`.
    square = OperatorAssumptions(true)
    return defaultalg(A, b, square)
end
28
# Array-backed operators: decide based on the wrapped array `A.A`.
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
        assump::OperatorAssumptions{Bool})
    wrapped = A.A
    return defaultalg(wrapped, b, assump)
end
33
function defaultalg(A, b, assump::OperatorAssumptions{Nothing})
    # Squareness not specified: infer it from `A`, keeping the condition hint.
    return defaultalg(A, b, OperatorAssumptions(issquare(A), assump.condition))
end
38
function defaultalg(A::SMatrix{S1, S2}, b, assump::OperatorAssumptions{Bool}) where {S1, S2}
    # Square static matrices use LU; non-square ones use SVD because
    # `QR(...) \ b` is not currently defined for them.
    return S1 == S2 ? LUFactorization() : SVDFactorization()
end
46
function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
    # Tridiagonal: LU when square, QR otherwise.
    choice = assump.issq ? DefaultAlgorithmChoice.LUFactorization :
             DefaultAlgorithmChoice.QRFactorization
    return DefaultLinearSolver(choice)
end
54
# Symmetric tridiagonal matrices use an LDLt factorization.
defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{Bool}) =
    DefaultLinearSolver(DefaultAlgorithmChoice.LDLtFactorization)
# Bidiagonal matrices are solved directly with `ldiv!`.
defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{Bool}) =
    DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
# An already-factorized `A` is applied directly with `ldiv!`.
defaultalg(A::Factorization, b, ::OperatorAssumptions{Bool}) =
    DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
# Diagonal matrices use the dedicated diagonal solver.
defaultalg(A::Diagonal, b, ::OperatorAssumptions{Bool}) =
    DefaultLinearSolver(DefaultAlgorithmChoice.DiagonalFactorization)
67
# Hermitian-wrapped matrices default to a Cholesky factorization.
# NOTE(review): Cholesky presumes positive definiteness — confirm this is intended
# for every `Hermitian` input.
defaultalg(A::Hermitian, b, ::OperatorAssumptions{Bool}) =
    DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization)
71
# Symmetric wrappers around dense `Array`s use Bunch-Kaufman.
defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions{Bool}) =
    DefaultLinearSolver(DefaultAlgorithmChoice.BunchKaufmanFactorization)
75
# Symmetric wrappers around sparse CSC matrices use CHOLMOD.
defaultalg(A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool}) =
    DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization)
79
# Sparse CSC matrices of generic element type: Sparspak when square; non-square
# systems are rejected.
function defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
        assump::OperatorAssumptions{Bool}) where {Tv, Ti}
    assump.issq ||
        error("Generic number sparse factorization for non-square is not currently handled")
    return DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
end
88
@static if INCLUDE_SPARSE
    # Float64 / ComplexF64 sparse CSC matrices (only when sparse support is included).
    function defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
            assump::OperatorAssumptions{Bool}) where {Ti}
        # Non-square systems go to QRFactorization.
        assump.issq || return DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
        # KLU only when the system is small (length(b) <= 10_000) and the fraction of
        # stored entries is below 2e-4; UMFPACK in every other square case.
        use_klu = length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
        choice = use_klu ? DefaultAlgorithmChoice.KLUFactorization :
                 DefaultAlgorithmChoice.UMFPACKFactorization
        return DefaultLinearSolver(choice)
    end
end
103
function defaultalg(A::GPUArraysCore.AnyGPUArray, b, assump::OperatorAssumptions{Bool})
    # GPU matrices: QR when non-square or flagged ill-conditioned, LU otherwise.
    wants_qr = !assump.issq || assump.condition === OperatorCondition.IllConditioned
    choice = wants_qr ? DefaultAlgorithmChoice.QRFactorization :
             DefaultAlgorithmChoice.LUFactorization
    return DefaultLinearSolver(choice)
end
111
112 # A === nothing case
function defaultalg(A::Nothing, b::GPUArraysCore.AnyGPUArray, assump::OperatorAssumptions{Bool})
    # No matrix and a GPU right-hand side: QR when non-square or flagged
    # ill-conditioned, LU otherwise.
    wants_qr = !assump.issq || assump.condition === OperatorCondition.IllConditioned
    choice = wants_qr ? DefaultAlgorithmChoice.QRFactorization :
             DefaultAlgorithmChoice.LUFactorization
    return DefaultLinearSolver(choice)
end
120
121 # Ambiguity handling
# GPU `A` together with a GPU `b`; presumably defined to resolve a dispatch
# ambiguity with another `defaultalg` method — TODO confirm which.
function defaultalg(A::GPUArraysCore.AnyGPUArray, b::GPUArraysCore.AbstractGPUArray,
        assump::OperatorAssumptions{Bool})
    # QR when non-square or flagged ill-conditioned, LU otherwise.
    wants_qr = !assump.issq || assump.condition === OperatorCondition.IllConditioned
    choice = wants_qr ? DefaultAlgorithmChoice.QRFactorization :
             DefaultAlgorithmChoice.LUFactorization
    return DefaultLinearSolver(choice)
end
130
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
        assump::OperatorAssumptions{Bool})
    # Operators that implement `ldiv!` are applied directly.
    has_ldiv!(A) && return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
    # Square operators without `ldiv!` fall back to matrix-free GMRES.
    assump.issq && return DefaultLinearSolver(DefaultAlgorithmChoice.KrylovJL_GMRES)
    # Non-square: CRAIGMR when there are fewer rows than columns, LSMR otherwise.
    # NOTE(review): `DefaultLinearSolverInit` and `algchoice_to_alg` in this file have
    # no entries for KrylovJL_CRAIGMR / KrylovJL_LSMR — confirm these members exist in
    # `DefaultAlgorithmChoice` and are wired through the cache/solve machinery.
    m, n = size(A)
    choice = m < n ? DefaultAlgorithmChoice.KrylovJL_CRAIGMR :
             DefaultAlgorithmChoice.KrylovJL_LSMR
    return DefaultLinearSolver(choice)
end
146
147 # Allows A === nothing as a stand-in for dense matrix
function defaultalg(A, b, assump::OperatorAssumptions{Bool})
    if !assump.issq
        # Non-square systems: choose purely from the stated conditioning.
        condition = assump.condition
        choice = if condition === OperatorCondition.WellConditioned
            DefaultAlgorithmChoice.NormalCholeskyFactorization
        elseif condition === OperatorCondition.IllConditioned ||
               condition === OperatorCondition.VeryIllConditioned
            # Underdetermined problems get the column-pivoted QR variant.
            if is_underdetermined(A)
                DefaultAlgorithmChoice.QRFactorizationPivoted
            else
                DefaultAlgorithmChoice.QRFactorization
            end
        elseif condition === OperatorCondition.SuperIllConditioned
            DefaultAlgorithmChoice.SVDFactorization
        else
            error("Special factorization not handled in current default algorithm")
        end
        return DefaultLinearSolver(choice)
    end

    # Square systems. A dense `Matrix`, or `A === nothing` with a non-GPU `b`, gets
    # the dense heuristics below.
    dense = (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
    choice = if dense
        condition = __conditioning(assump)
        # Element type used for the BLAS / RecursiveFactorization eligibility tests.
        elty = A === nothing ? eltype(b) : eltype(A)
        if (A === nothing || eltype(A) <: BLASELTYPES) &&
           ArrayInterface.can_setindex(b) &&
           (condition === OperatorCondition.IllConditioned ||
            condition === OperatorCondition.WellConditioned)
            nb = length(b)
            # Size cut-offs come from benchmarks and depend on whether MKL or
            # OpenBLAS is in use; RecursiveFactorization.jl avoids BLAS where it wins.
            if nb <= 10
                DefaultAlgorithmChoice.GenericLUFactorization
            elseif appleaccelerate_isavailable()
                DefaultAlgorithmChoice.AppleAccelerateLUFactorization
            elseif (nb <= 100 || (isopenblas() && nb <= 500) || (usemkl && nb <= 200)) &&
                   elty <: Union{Float32, Float64}
                DefaultAlgorithmChoice.RFLUFactorization
            elseif usemkl
                DefaultAlgorithmChoice.MKLLUFactorization
            else
                DefaultAlgorithmChoice.LUFactorization
            end
        elseif condition === OperatorCondition.VeryIllConditioned
            DefaultAlgorithmChoice.QRFactorization
        elseif condition === OperatorCondition.SuperIllConditioned
            DefaultAlgorithmChoice.SVDFactorization
        elseif usemkl && elty <: BLASELTYPES
            DefaultAlgorithmChoice.MKLLUFactorization
        else
            DefaultAlgorithmChoice.LUFactorization
        end
    elseif A !== nothing && ArrayInterface.isstructured(A)
        # A specialized factorization could exist (e.g. BlockBandedMatrix) but is not
        # selected by this default.
        error("Special factorization not handled in current default algorithm")
    else
        # Not a factorizable operator: rely only on `A * x`.
        DefaultAlgorithmChoice.KrylovJL_GMRES
    end
    return DefaultLinearSolver(choice)
end
218
"""
    algchoice_to_alg(alg::Symbol)

Construct the algorithm object corresponding to the `DefaultAlgorithmChoice`
member named `alg`. Any symbol not handled here raises an error.
"""
function algchoice_to_alg(alg::Symbol)
    # LU variants
    alg === :LUFactorization && return LUFactorization()
    alg === :GenericLUFactorization && return GenericLUFactorization()
    alg === :RFLUFactorization && return RFLUFactorization()
    alg === :MKLLUFactorization && return MKLLUFactorization()
    alg === :AppleAccelerateLUFactorization && return AppleAccelerateLUFactorization()
    # QR and SVD
    alg === :QRFactorization && return QRFactorization()
    if alg === :QRFactorizationPivoted
        # Column pivoting is requested with `ColumnNorm()` from Julia 1.7, `Val(true)` before.
        @static if VERSION ≥ v"1.7beta"
            return QRFactorization(ColumnNorm())
        else
            return QRFactorization(Val(true))
        end
    end
    alg === :SVDFactorization && return SVDFactorization(false, LinearAlgebra.QRIteration())
    # LDLt, Bunch-Kaufman and Cholesky variants
    alg === :LDLtFactorization && return LDLtFactorization()
    alg === :BunchKaufmanFactorization && return BunchKaufmanFactorization()
    alg === :CholeskyFactorization && return CholeskyFactorization()
    alg === :NormalCholeskyFactorization && return NormalCholeskyFactorization()
    alg === :CHOLMODFactorization && return CHOLMODFactorization()
    # Sparse direct solvers
    alg === :SparspakFactorization && return SparspakFactorization()
    alg === :KLUFactorization && return KLUFactorization()
    alg === :UMFPACKFactorization && return UMFPACKFactorization()
    # Remaining choices
    alg === :DiagonalFactorization && return DiagonalFactorization()
    (alg === :DirectLdiv!) && return DirectLdiv!()
    alg === :KrylovJL_GMRES && return KrylovJL_GMRES()
    error("Algorithm choice symbol $alg not allowed in the default")
end
266
267 ## Catch high level interface
268
269 119 (41 %)
238 (83 %) samples spent in init
119 (50 %) (incl.) when called from init line 269
119 (50 %) (incl.) when called from #linsolve_caches#92 line 181
119 (100 %) samples spent calling #init#82
function SciMLBase.init(prob::LinearProblem, alg::Nothing,
        args...;
        assumptions = OperatorAssumptions(issquare(prob.A)),
        kwargs...)
    # No algorithm given: pick one with `defaultalg` and re-dispatch, forwarding the
    # same assumptions to the chosen solver.
    chosen = defaultalg(prob.A, prob.b, assumptions)
    # profile (report line 273): 119 samples (41 %), all spent calling `init`
    return SciMLBase.init(prob, chosen, args...; assumptions = assumptions, kwargs...)
end
275
function SciMLBase.solve!(cache::LinearCache, alg::Nothing,
        args...; assump::OperatorAssumptions = OperatorAssumptions(),
        kwargs...)
    # Pick the default algorithm for the cached `A` and `b`, then solve with it.
    # `assump` is consumed here and not forwarded.
    chosen = defaultalg(cache.A, cache.b, assump)
    return SciMLBase.solve!(cache, chosen, args...; kwargs...)
end
282
function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
        verbose::Bool, assump::OperatorAssumptions)
    # Build the cache for whichever algorithm `defaultalg` selects.
    chosen = defaultalg(A, b, assump)
    return init_cacheval(chosen, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose,
        assump)
end
289
290 """
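291 Build `cache.cacheval` as a `DefaultLinearSolverInit` holding one cache per `DefaultAlgorithmChoice` member (e.g. field `LUFactorization` = cache of `LUFactorization`), passed positionally in `EnumX.symbol_map` order; the `KrylovJL_GMRES` slot is `nothing` when `A` is a `Matrix` or `SparseMatrixCSC`.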
292 """
293 119 (41 %)
119 (41 %) samples spent in init_cacheval
119 (100 %) (incl.) when called from #init#3 line 167
119 (100 %) samples spent calling macro expansion
@generated function init_cacheval(alg::DefaultLinearSolver, A, b, u, Pl, Pr, maxiters::Int,
        abstol, reltol,
        verbose::Bool, assump::OperatorAssumptions)
    # Generator: emit one `init_cacheval` call per DefaultAlgorithmChoice member, in
    # `EnumX.symbol_map` order, and pass the results positionally to
    # `DefaultLinearSolverInit`.
    caches = Any[]
    for sym in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
        solver = algchoice_to_alg(sym)
        # profile (report line 310): 119 samples (41 %) in this inner init_cacheval
        cachecall = :(init_cacheval($solver, A, b, u, Pl, Pr, maxiters, abstol, reltol,
            verbose, assump))
        if sym === :KrylovJL_GMRES
            # `defaultalg` in this file never selects GMRES for a `Matrix` or a
            # `SparseMatrixCSC`, so no Krylov cache is built for them.
            cachecall = :(A isa Matrix || A isa SparseMatrixCSC ? nothing : $cachecall)
        end
        push!(caches, cachecall)
    end
    return Expr(:call, :DefaultLinearSolverInit, caches...)
end
319
# Name of algorithm type `T` with module qualification and type parameters removed.
function defaultalg_symbol(::Type{T}) where {T}
    qualified = string(SciMLBase.parameterless_type(T))
    return Symbol(last(split(qualified, ".")))
end
# Generic factorizations built on `ldlt!` are reported as LDLtFactorization.
function defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}})
    return :LDLtFactorization
end
324
# Map the pivoted-QR default back to its DefaultAlgorithmChoice name. That default is
# constructed as `QRFactorization(ColumnNorm())` on Julia 1.7+ and as
# `QRFactorization(Val(true))` before (see `algchoice_to_alg`). Match the type that is
# actually built on each version; `ColumnNorm` does not exist before Julia 1.7.
@static if VERSION ≥ v"1.7beta"
    defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted
else
    defaultalg_symbol(::Type{<:QRFactorization{Val{true}}}) = :QRFactorizationPivoted
end
326
327 """
328 if alg.alg === DefaultAlgorithmChoice.LUFactorization
329 SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...)
330 else
331 ...
332 end
333 """
334 138 (48 %)
138 (48 %) samples spent in solve!
69 (50 %) (incl.) when called from solve! line 334
69 (50 %) (incl.) when called from #solve!#6 line 189
69 (50 %) samples spent calling macro expansion
69 (50 %) samples spent calling #solve!#87
@generated function SciMLBase.solve!(cache::LinearCache, alg::DefaultLinearSolver,
        args...;
        assump::OperatorAssumptions = OperatorAssumptions(),
        kwargs...)
    # Generator: build an if/elseif chain with one branch per DefaultAlgorithmChoice
    # member. Each branch solves with the concrete algorithm and rewraps the result
    # with `alg` (the DefaultLinearSolver). `assump` is accepted but not forwarded.
    ex = :(error("Algorithm Choice not Allowed"))
    for sym in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
        branch = quote
            # profile (report line 341): 69 samples (24 %) in this inner solve!
            sol = SciMLBase.solve!(cache, $(algchoice_to_alg(sym)), args...; kwargs...)
            SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
                retcode = sol.retcode,
                iters = sol.iters, stats = sol.stats)
        end
        # Compare enum values directly, as `defaultalg_adjoint_eval` does, rather than
        # converting `alg.alg` to a Symbol at run time in every branch test.
        test = :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(sym))) === alg.alg)
        ex = Expr(:elseif, test, branch, ex)
    end
    return Expr(:if, ex.args...)
end
356 """
357 ```
358 elseif DefaultAlgorithmChoice.LUFactorization === cache.alg
359 (cache.cacheval.LUFactorization)' \\ dy
360 else
361 ...
362 end
363 ```
364 """
@generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
    # Generator: build an if/elseif chain with one branch per DefaultAlgorithmChoice
    # member; each branch applies the adjoint solve for that member to `dy`.

    # Caches indexed with `[1]` before the adjoint solve.
    tuple_cached = Symbol.((DefaultAlgorithmChoice.MKLLUFactorization,
        DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
        DefaultAlgorithmChoice.RFLUFactorization))
    # Caches used directly as the object to adjoint-solve with.
    direct_cached = Symbol.((DefaultAlgorithmChoice.LUFactorization,
        DefaultAlgorithmChoice.QRFactorization,
        DefaultAlgorithmChoice.KLUFactorization,
        DefaultAlgorithmChoice.UMFPACKFactorization,
        DefaultAlgorithmChoice.LDLtFactorization,
        DefaultAlgorithmChoice.SparspakFactorization,
        DefaultAlgorithmChoice.BunchKaufmanFactorization,
        DefaultAlgorithmChoice.CHOLMODFactorization,
        DefaultAlgorithmChoice.SVDFactorization,
        DefaultAlgorithmChoice.CholeskyFactorization,
        DefaultAlgorithmChoice.NormalCholeskyFactorization,
        DefaultAlgorithmChoice.QRFactorizationPivoted,
        DefaultAlgorithmChoice.GenericLUFactorization))

    ex = :(error("Algorithm Choice not Allowed"))
    for sym in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
        field = Meta.quot(sym)
        branch = if sym in tuple_cached
            :(getproperty(cache.cacheval, $field)[1]' \ dy)
        elseif sym in direct_cached
            :(getproperty(cache.cacheval, $field)' \ dy)
        elseif sym === :KrylovJL_GMRES
            quote
                # NOTE(review): this uses `transpose(A)` while the factorization
                # branches use the adjoint `'`; they agree only for real `A` — confirm.
                invprob = LinearSolve.LinearProblem(transpose(cache.A), dy)
                # `cache` is the LinearCache itself, so its tolerances and verbosity
                # are read from it directly.
                sol = solve(invprob, cache.alg;
                    abstol = cache.abstol,
                    reltol = cache.reltol,
                    verbose = cache.verbose)
                # Return the solution vector, matching the factorization branches.
                sol.u
            end
        else
            # Build the message at generation time: the generated code has no `alg`
            # binding, so interpolating it there would fail instead of reporting.
            msg = "Default linear solver with algorithm $(sym) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling"
            :(error($msg))
        end
        test = :(getproperty(DefaultAlgorithmChoice, $field) === cache.alg.alg)
        ex = Expr(:elseif, test, branch, ex)
    end
    return Expr(:if, ex.args...)
end