Line | Exclusive | Inclusive | Code |
---|---|---|---|
1 | needs_concrete_A(alg::DefaultLinearSolver) = true | ||
2 | mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, | ||
3 | T13, T14, T15, T16, T17, T18, T19} | ||
4 | LUFactorization::T1 | ||
5 | QRFactorization::T2 | ||
6 | DiagonalFactorization::T3 | ||
7 | DirectLdiv!::T4 | ||
8 | SparspakFactorization::T5 | ||
9 | KLUFactorization::T6 | ||
10 | UMFPACKFactorization::T7 | ||
11 | KrylovJL_GMRES::T8 | ||
12 | GenericLUFactorization::T9 | ||
13 | RFLUFactorization::T10 | ||
14 | LDLtFactorization::T11 | ||
15 | BunchKaufmanFactorization::T12 | ||
16 | CHOLMODFactorization::T13 | ||
17 | SVDFactorization::T14 | ||
18 | CholeskyFactorization::T15 | ||
19 | NormalCholeskyFactorization::T16 | ||
20 | AppleAccelerateLUFactorization::T17 | ||
21 | MKLLUFactorization::T18 | ||
22 | QRFactorizationPivoted::T19 | ||
23 | end | ||
24 | |||
25 | # Legacy fallback | ||
26 | # For SciML algorithms already using `defaultalg`, all assume square matrix. | ||
27 | defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true)) | ||
28 | |||
29 | function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b, | ||
30 | assump::OperatorAssumptions{Bool}) | ||
31 | defaultalg(A.A, b, assump) | ||
32 | end | ||
33 | |||
34 | function defaultalg(A, b, assump::OperatorAssumptions{Nothing}) | ||
35 | issq = issquare(A) | ||
36 | defaultalg(A, b, OperatorAssumptions(issq, assump.condition)) | ||
37 | end | ||
38 | |||
39 | function defaultalg(A::SMatrix{S1, S2}, b, assump::OperatorAssumptions{Bool}) where {S1, S2} | ||
40 | if S1 == S2 | ||
41 | return LUFactorization() | ||
42 | else | ||
43 | return SVDFactorization() # QR(...) \ b is not defined currently | ||
44 | end | ||
45 | end | ||
46 | |||
47 | function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool}) | ||
48 | if assump.issq | ||
49 | DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization) | ||
50 | else | ||
51 | DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) | ||
52 | end | ||
53 | end | ||
54 | |||
55 | function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{Bool}) | ||
56 | DefaultLinearSolver(DefaultAlgorithmChoice.LDLtFactorization) | ||
57 | end | ||
58 | function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{Bool}) | ||
59 | DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) | ||
60 | end | ||
61 | function defaultalg(A::Factorization, b, ::OperatorAssumptions{Bool}) | ||
62 | DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) | ||
63 | end | ||
64 | function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Bool}) | ||
65 | DefaultLinearSolver(DefaultAlgorithmChoice.DiagonalFactorization) | ||
66 | end | ||
67 | |||
68 | function defaultalg(A::Hermitian, b, ::OperatorAssumptions{Bool}) | ||
69 | DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization) | ||
70 | end | ||
71 | |||
72 | function defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions{Bool}) | ||
73 | DefaultLinearSolver(DefaultAlgorithmChoice.BunchKaufmanFactorization) | ||
74 | end | ||
75 | |||
76 | function defaultalg(A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool}) | ||
77 | DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization) | ||
78 | end | ||
79 | |||
80 | function defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b, | ||
81 | assump::OperatorAssumptions{Bool}) where {Tv, Ti} | ||
82 | if assump.issq | ||
83 | DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization) | ||
84 | else | ||
85 | error("Generic number sparse factorization for non-square is not currently handled") | ||
86 | end | ||
87 | end | ||
88 | |||
89 | @static if INCLUDE_SPARSE | ||
90 | function defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b, | ||
91 | assump::OperatorAssumptions{Bool}) where {Ti} | ||
92 | if assump.issq | ||
93 | if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4 | ||
94 | DefaultLinearSolver(DefaultAlgorithmChoice.KLUFactorization) | ||
95 | else | ||
96 | DefaultLinearSolver(DefaultAlgorithmChoice.UMFPACKFactorization) | ||
97 | end | ||
98 | else | ||
99 | DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) | ||
100 | end | ||
101 | end | ||
102 | end | ||
103 | |||
104 | function defaultalg(A::GPUArraysCore.AnyGPUArray, b, assump::OperatorAssumptions{Bool}) | ||
105 | if assump.condition === OperatorCondition.IllConditioned || !assump.issq | ||
106 | DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) | ||
107 | else | ||
108 | DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization) | ||
109 | end | ||
110 | end | ||
111 | |||
112 | # A === nothing case | ||
113 | function defaultalg(A::Nothing, b::GPUArraysCore.AnyGPUArray, assump::OperatorAssumptions{Bool}) | ||
114 | if assump.condition === OperatorCondition.IllConditioned || !assump.issq | ||
115 | DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) | ||
116 | else | ||
117 | DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization) | ||
118 | end | ||
119 | end | ||
120 | |||
121 | # Ambiguity handling | ||
122 | function defaultalg(A::GPUArraysCore.AnyGPUArray, b::GPUArraysCore.AbstractGPUArray, | ||
123 | assump::OperatorAssumptions{Bool}) | ||
124 | if assump.condition === OperatorCondition.IllConditioned || !assump.issq | ||
125 | DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) | ||
126 | else | ||
127 | DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization) | ||
128 | end | ||
129 | end | ||
130 | |||
131 | function defaultalg(A::SciMLBase.AbstractSciMLOperator, b, | ||
132 | assump::OperatorAssumptions{Bool}) | ||
133 | if has_ldiv!(A) | ||
134 | return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) | ||
135 | elseif !assump.issq | ||
136 | m, n = size(A) | ||
137 | if m < n | ||
138 | DefaultLinearSolver(DefaultAlgorithmChoice.KrylovJL_CRAIGMR) | ||
139 | else | ||
140 | DefaultLinearSolver(DefaultAlgorithmChoice.KrylovJL_LSMR) | ||
141 | end | ||
142 | else | ||
143 | DefaultLinearSolver(DefaultAlgorithmChoice.KrylovJL_GMRES) | ||
144 | end | ||
145 | end | ||
146 | |||
147 | # Allows A === nothing as a stand-in for dense matrix | ||
148 | function defaultalg(A, b, assump::OperatorAssumptions{Bool}) | ||
149 | alg = if assump.issq | ||
150 | # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when | ||
151 | # it makes sense according to the benchmarks, which is dependent on | ||
152 | # whether MKL or OpenBLAS is being used | ||
153 | if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix | ||
154 | if (A === nothing || | ||
155 | eltype(A) <: BLASELTYPES) && | ||
156 | ArrayInterface.can_setindex(b) && | ||
157 | (__conditioning(assump) === OperatorCondition.IllConditioned || | ||
158 | __conditioning(assump) === OperatorCondition.WellConditioned) | ||
159 | if length(b) <= 10 | ||
160 | DefaultAlgorithmChoice.GenericLUFactorization | ||
161 | elseif appleaccelerate_isavailable() | ||
162 | DefaultAlgorithmChoice.AppleAccelerateLUFactorization | ||
163 | elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) || | ||
164 | (usemkl && length(b) <= 200)) && | ||
165 | (A === nothing ? eltype(b) <: Union{Float32, Float64} : | ||
166 | eltype(A) <: Union{Float32, Float64}) | ||
167 | DefaultAlgorithmChoice.RFLUFactorization | ||
168 | #elseif A === nothing || A isa Matrix | ||
169 | # alg = FastLUFactorization() | ||
170 | elseif usemkl | ||
171 | DefaultAlgorithmChoice.MKLLUFactorization | ||
172 | else | ||
173 | DefaultAlgorithmChoice.LUFactorization | ||
174 | end | ||
175 | elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned | ||
176 | DefaultAlgorithmChoice.QRFactorization | ||
177 | elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned | ||
178 | DefaultAlgorithmChoice.SVDFactorization | ||
179 | elseif usemkl && (A === nothing ? eltype(b) <: BLASELTYPES : | ||
180 | eltype(A) <: BLASELTYPES) | ||
181 | DefaultAlgorithmChoice.MKLLUFactorization | ||
182 | else | ||
183 | DefaultAlgorithmChoice.LUFactorization | ||
184 | end | ||
185 | |||
186 | # This catches the cases where a factorization overload could exist | ||
187 | # For example, BlockBandedMatrix | ||
188 | elseif A !== nothing && ArrayInterface.isstructured(A) | ||
189 | error("Special factorization not handled in current default algorithm") | ||
190 | |||
191 | # Not factorizable operator, default to only using A*x | ||
192 | else | ||
193 | DefaultAlgorithmChoice.KrylovJL_GMRES | ||
194 | end | ||
195 | elseif assump.condition === OperatorCondition.WellConditioned | ||
196 | DefaultAlgorithmChoice.NormalCholeskyFactorization | ||
197 | elseif assump.condition === OperatorCondition.IllConditioned | ||
198 | if is_underdetermined(A) | ||
199 | # Underdetermined | ||
200 | DefaultAlgorithmChoice.QRFactorizationPivoted | ||
201 | else | ||
202 | DefaultAlgorithmChoice.QRFactorization | ||
203 | end | ||
204 | elseif assump.condition === OperatorCondition.VeryIllConditioned | ||
205 | if is_underdetermined(A) | ||
206 | # Underdetermined | ||
207 | DefaultAlgorithmChoice.QRFactorizationPivoted | ||
208 | else | ||
209 | DefaultAlgorithmChoice.QRFactorization | ||
210 | end | ||
211 | elseif assump.condition === OperatorCondition.SuperIllConditioned | ||
212 | DefaultAlgorithmChoice.SVDFactorization | ||
213 | else | ||
214 | error("Special factorization not handled in current default algorithm") | ||
215 | end | ||
216 | DefaultLinearSolver(alg) | ||
217 | end | ||
218 | |||
219 | function algchoice_to_alg(alg::Symbol) | ||
220 | if alg === :SVDFactorization | ||
221 | SVDFactorization(false, LinearAlgebra.QRIteration()) | ||
222 | elseif alg === :LDLtFactorization | ||
223 | LDLtFactorization() | ||
224 | elseif alg === :LUFactorization | ||
225 | LUFactorization() | ||
226 | elseif alg === :MKLLUFactorization | ||
227 | MKLLUFactorization() | ||
228 | elseif alg === :QRFactorization | ||
229 | QRFactorization() | ||
230 | elseif alg === :DiagonalFactorization | ||
231 | DiagonalFactorization() | ||
232 | elseif alg === :DirectLdiv! | ||
233 | DirectLdiv!() | ||
234 | elseif alg === :SparspakFactorization | ||
235 | SparspakFactorization() | ||
236 | elseif alg === :KLUFactorization | ||
237 | KLUFactorization() | ||
238 | elseif alg === :UMFPACKFactorization | ||
239 | UMFPACKFactorization() | ||
240 | elseif alg === :KrylovJL_GMRES | ||
241 | KrylovJL_GMRES() | ||
242 | elseif alg === :GenericLUFactorization | ||
243 | GenericLUFactorization() | ||
244 | elseif alg === :RFLUFactorization | ||
245 | RFLUFactorization() | ||
246 | elseif alg === :BunchKaufmanFactorization | ||
247 | BunchKaufmanFactorization() | ||
248 | elseif alg === :CHOLMODFactorization | ||
249 | CHOLMODFactorization() | ||
250 | elseif alg === :CholeskyFactorization | ||
251 | CholeskyFactorization() | ||
252 | elseif alg === :NormalCholeskyFactorization | ||
253 | NormalCholeskyFactorization() | ||
254 | elseif alg === :AppleAccelerateLUFactorization | ||
255 | AppleAccelerateLUFactorization() | ||
256 | elseif alg === :QRFactorizationPivoted | ||
257 | @static if VERSION ≥ v"1.7beta" | ||
258 | QRFactorization(ColumnNorm()) | ||
259 | else | ||
260 | QRFactorization(Val(true)) | ||
261 | end | ||
262 | else | ||
263 | error("Algorithm choice symbol $alg not allowed in the default") | ||
264 | end | ||
265 | end | ||
266 | |||
267 | ## Catch high level interface | ||
268 | |||
269 | 71 (26 %) |
142 (52 %)
samples spent in init
71 (50 %) (incl.) when called from init line 269 71 (50 %) (incl.) when called from #linsolve_caches#92 line 181
71 (100 %)
samples spent calling
#init#82
function SciMLBase.init(prob::LinearProblem, alg::Nothing,
|
|
270 | args...; | ||
271 | assumptions = OperatorAssumptions(issquare(prob.A)), | ||
272 | kwargs...) | ||
273 | 71 (26 %) |
71 (100 %)
samples spent calling
init
SciMLBase.init(prob, defaultalg(prob.A, prob.b, assumptions), args...; assumptions, kwargs...)
|
|
274 | end | ||
275 | |||
276 | function SciMLBase.solve!(cache::LinearCache, alg::Nothing, | ||
277 | args...; assump::OperatorAssumptions = OperatorAssumptions(), | ||
278 | kwargs...) | ||
279 | @unpack A, b = cache | ||
280 | SciMLBase.solve!(cache, defaultalg(A, b, assump), args...; kwargs...) | ||
281 | end | ||
282 | |||
283 | function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, | ||
284 | verbose::Bool, assump::OperatorAssumptions) | ||
285 | init_cacheval(defaultalg(A, b, assump), A, b, u, Pl, Pr, maxiters, abstol, reltol, | ||
286 | verbose, | ||
287 | assump) | ||
288 | end | ||
289 | |||
290 | """ | ||
291 | cache.cacheval = NamedTuple(LUFactorization = cache of LUFactorization, ...) | ||
292 | """ | ||
293 | 71 (26 %) |
71 (100 %)
samples spent calling
macro expansion
@generated function init_cacheval(alg::DefaultLinearSolver, A, b, u, Pl, Pr, maxiters::Int,
|
|
294 | abstol, reltol, | ||
295 | verbose::Bool, assump::OperatorAssumptions) | ||
296 | caches = map(first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))) do alg | ||
297 | if alg === :KrylovJL_GMRES | ||
298 | quote | ||
299 | if A isa Matrix || A isa SparseMatrixCSC | ||
300 | nothing | ||
301 | else | ||
302 | init_cacheval($(algchoice_to_alg(alg)), A, b, u, Pl, Pr, maxiters, | ||
303 | abstol, reltol, | ||
304 | verbose, | ||
305 | assump) | ||
306 | end | ||
307 | end | ||
308 | else | ||
309 | quote | ||
310 | 71 (26 %) |
71 (26 %)
samples spent in macro expansion
71 (100 %) (incl.) when called from init_cacheval line 293
71 (100 %)
samples spent calling
init_cacheval
init_cacheval($(algchoice_to_alg(alg)), A, b, u, Pl, Pr, maxiters, abstol,
|
|
311 | reltol, | ||
312 | verbose, | ||
313 | assump) | ||
314 | end | ||
315 | end | ||
316 | end | ||
317 | Expr(:call, :DefaultLinearSolverInit, caches...) | ||
318 | end | ||
319 | |||
320 | function defaultalg_symbol(::Type{T}) where {T} | ||
321 | Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end]) | ||
322 | end | ||
323 | defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization | ||
324 | |||
325 | defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted | ||
326 | |||
327 | """ | ||
328 | if alg.alg === DefaultAlgorithmChoice.LUFactorization | ||
329 | SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...)) | ||
330 | else | ||
331 | ... | ||
332 | end | ||
333 | """ | ||
334 | 162 (60 %) |
162 (60 %)
samples spent in solve!
@generated function SciMLBase.solve!(cache::LinearCache, alg::DefaultLinearSolver,
81 (50 %) (incl.) when called from solve! line 334 81 (50 %) (incl.) when called from #solve!#6 line 189 |
|
335 | args...; | ||
336 | assump::OperatorAssumptions = OperatorAssumptions(), | ||
337 | kwargs...) | ||
338 | ex = :() | ||
339 | for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) | ||
340 | newex = quote | ||
341 | 81 (30 %) |
81 (100 %)
samples spent calling
solve!
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
|
|
342 | SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; | ||
343 | retcode = sol.retcode, | ||
344 | iters = sol.iters, stats = sol.stats) | ||
345 | end | ||
346 | ex = if ex == :() | ||
347 | Expr(:elseif, :(Symbol(alg.alg) === $(Meta.quot(alg))), newex, | ||
348 | :(error("Algorithm Choice not Allowed"))) | ||
349 | else | ||
350 | Expr(:elseif, :(Symbol(alg.alg) === $(Meta.quot(alg))), newex, ex) | ||
351 | end | ||
352 | end | ||
353 | ex = Expr(:if, ex.args...) | ||
354 | end | ||
355 | |||
356 | """ | ||
357 | ``` | ||
358 | elseif DefaultAlgorithmChoice.LUFactorization === cache.alg | ||
359 | (cache.cacheval.LUFactorization)' \\ dy | ||
360 | else | ||
361 | ... | ||
362 | end | ||
363 | ``` | ||
364 | """ | ||
365 | @generated function defaultalg_adjoint_eval(cache::LinearCache, dy) | ||
366 | ex = :() | ||
367 | for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) | ||
368 | newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization, | ||
369 | DefaultAlgorithmChoice.AppleAccelerateLUFactorization, | ||
370 | DefaultAlgorithmChoice.RFLUFactorization)) | ||
371 | quote | ||
372 | getproperty(cache.cacheval,$(Meta.quot(alg)))[1]' \ dy | ||
373 | end | ||
374 | elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization, | ||
375 | DefaultAlgorithmChoice.QRFactorization, | ||
376 | DefaultAlgorithmChoice.KLUFactorization, | ||
377 | DefaultAlgorithmChoice.UMFPACKFactorization, | ||
378 | DefaultAlgorithmChoice.LDLtFactorization, | ||
379 | DefaultAlgorithmChoice.SparspakFactorization, | ||
380 | DefaultAlgorithmChoice.BunchKaufmanFactorization, | ||
381 | DefaultAlgorithmChoice.CHOLMODFactorization, | ||
382 | DefaultAlgorithmChoice.SVDFactorization, | ||
383 | DefaultAlgorithmChoice.CholeskyFactorization, | ||
384 | DefaultAlgorithmChoice.NormalCholeskyFactorization, | ||
385 | DefaultAlgorithmChoice.QRFactorizationPivoted, | ||
386 | DefaultAlgorithmChoice.GenericLUFactorization)) | ||
387 | quote | ||
388 | getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy | ||
389 | end | ||
390 | elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,)) | ||
391 | quote | ||
392 | invprob = LinearSolve.LinearProblem(transpose(cache.A), dy) | ||
393 | solve(invprob, cache.alg; | ||
394 | abstol = cache.val.abstol, | ||
395 | reltol = cache.val.reltol, | ||
396 | verbose = cache.val.verbose) | ||
397 | end | ||
398 | else | ||
399 | quote | ||
400 | error("Default linear solver with algorithm $(alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") | ||
401 | end | ||
402 | end | ||
403 | |||
404 | ex = if ex == :() | ||
405 | Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex, | ||
406 | :(error("Algorithm Choice not Allowed"))) | ||
407 | else | ||
408 | Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex, ex) | ||
409 | end | ||
410 | end | ||
411 | ex = Expr(:if, ex.args...) | ||
412 | end |