Line | Exclusive | Inclusive | Code |
---|---|---|---|
"""
```julia
MKLLUFactorization()
```

LU factorization backed by Intel's Math Kernel Library (MKL). MKL's LAPACK routines are
called directly rather than through libblastrampoline, and the pivot buffer from a
previous factorization is reused so that repeated factorizations avoid allocations.
"""
struct MKLLUFactorization <: AbstractFactorization end
10 | |||
"""
    getrf!(A::AbstractMatrix{<:ComplexF64}; ipiv, info = Ref{BlasInt}(), check = false)

Overwrite `A` with its pivoted LU factorization using MKL's `zgetrf`, called directly
from `MKL_jll.libmkl_rt`.

`ipiv` is reused when its length is `min(size(A)...)`. Otherwise (an empty placeholder,
or a buffer sized for a different matrix) a new `BlasInt` buffer of that length is
allocated. `info` receives LAPACK's status code. With `check = true`, `chkfinite(A)` is
run first. Returns `(A, ipiv, info[], info)`.

`chkargsok` throws on a negative status (invalid argument). A positive status is returned
rather than thrown, so it can be stored in the resulting `LU`.
"""
function getrf!(A::AbstractMatrix{<:ComplexF64};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # LAPACK writes exactly min(m, n) pivots. A shorter buffer would be overrun, and a
    # longer one would produce an LU whose pivot vector does not match `A`.
    if length(ipiv) != min(m, n)
        ipiv = similar(A, BlasInt, min(m, n))
    end
    ccall((@blasfunc(zgetrf_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[], info # Error code is stored in LU factorization type
end
30 | |||
"""
    getrf!(A::AbstractMatrix{<:ComplexF32}; ipiv, info = Ref{BlasInt}(), check = false)

Overwrite `A` with its pivoted LU factorization using MKL's `cgetrf`, called directly
from `MKL_jll.libmkl_rt`.

`ipiv` is reused when its length is `min(size(A)...)`. Otherwise (an empty placeholder,
or a buffer sized for a different matrix) a new `BlasInt` buffer of that length is
allocated. `info` receives LAPACK's status code. With `check = true`, `chkfinite(A)` is
run first. Returns `(A, ipiv, info[], info)`.

`chkargsok` throws on a negative status (invalid argument). A positive status is returned
rather than thrown, so it can be stored in the resulting `LU`.
"""
function getrf!(A::AbstractMatrix{<:ComplexF32};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # LAPACK writes exactly min(m, n) pivots. A shorter buffer would be overrun, and a
    # longer one would produce an LU whose pivot vector does not match `A`.
    if length(ipiv) != min(m, n)
        ipiv = similar(A, BlasInt, min(m, n))
    end
    ccall((@blasfunc(cgetrf_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[], info # Error code is stored in LU factorization type
end
50 | |||
"""
    getrf!(A::AbstractMatrix{<:Float64}; ipiv, info = Ref{BlasInt}(), check = false)

Overwrite `A` with its pivoted LU factorization using MKL's `dgetrf`, called directly
from `MKL_jll.libmkl_rt`.

`ipiv` is reused when its length is `min(size(A)...)`. Otherwise (an empty placeholder,
or a buffer sized for a different matrix) a new `BlasInt` buffer of that length is
allocated. `info` receives LAPACK's status code. With `check = true`, `chkfinite(A)` is
run first. Returns `(A, ipiv, info[], info)`.

`chkargsok` throws on a negative status (invalid argument). A positive status is returned
rather than thrown, so it can be stored in the resulting `LU`.
"""
function getrf!(A::AbstractMatrix{<:Float64};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # LAPACK writes exactly min(m, n) pivots. A shorter buffer would be overrun, and a
    # longer one would produce an LU whose pivot vector does not match `A`.
    if length(ipiv) != min(m, n)
        ipiv = similar(A, BlasInt, min(m, n))
    end
    ccall((@blasfunc(dgetrf_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[], info # Error code is stored in LU factorization type
end
70 | |||
"""
    getrf!(A::AbstractMatrix{<:Float32}; ipiv, info = Ref{BlasInt}(), check = false)

Overwrite `A` with its pivoted LU factorization using MKL's `sgetrf`, called directly
from `MKL_jll.libmkl_rt`.

`ipiv` is reused when its length is `min(size(A)...)`. Otherwise (an empty placeholder,
or a buffer sized for a different matrix) a new `BlasInt` buffer of that length is
allocated. `info` receives LAPACK's status code. With `check = true`, `chkfinite(A)` is
run first. Returns `(A, ipiv, info[], info)`.

`chkargsok` throws on a negative status (invalid argument). A positive status is returned
rather than thrown, so it can be stored in the resulting `LU`.
"""
function getrf!(A::AbstractMatrix{<:Float32};
        ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
        info = Ref{BlasInt}(),
        check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    # LAPACK writes exactly min(m, n) pivots. A shorter buffer would be overrun, and a
    # longer one would produce an LU whose pivot vector does not match `A`.
    if length(ipiv) != min(m, n)
        ipiv = similar(A, BlasInt, min(m, n))
    end
    ccall((@blasfunc(sgetrf_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[], info # Error code is stored in LU factorization type
end
90 | |||
"""
    getrs!(trans, A::AbstractMatrix{<:ComplexF64}, ipiv, B; info = Ref{BlasInt}())

Solve a linear system in place with MKL's `zgetrs`, using the LU factors `A` and
pivots `ipiv` produced by `getrf!`. `B` is overwritten with the solution and returned.
`trans` is validated by `chktrans` and forwarded to LAPACK unchanged.

Throws `DimensionMismatch` if `size(B, 1)` or `length(ipiv)` differs from the order of
the square matrix `A`. A nonzero LAPACK status is handled by `chklapackerror`.
"""
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF64},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:ComplexF64};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    # `@blasfunc` resolves the symbol name the same way `getrf!` does (it adds the
    # 64-bit-integer suffix when Julia's BLAS uses 64-bit `BlasInt`). This keeps the
    # integer width of `ipiv`/`info` consistent between factorization and solve.
    # The trailing `1` is the hidden Fortran length of the `trans` character argument.
    ccall((@blasfunc(zgetrs_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
115 | |||
"""
    getrs!(trans, A::AbstractMatrix{<:ComplexF32}, ipiv, B; info = Ref{BlasInt}())

Solve a linear system in place with MKL's `cgetrs`, using the LU factors `A` and
pivots `ipiv` produced by `getrf!`. `B` is overwritten with the solution and returned.
`trans` is validated by `chktrans` and forwarded to LAPACK unchanged.

Throws `DimensionMismatch` if `size(B, 1)` or `length(ipiv)` differs from the order of
the square matrix `A`. A nonzero LAPACK status is handled by `chklapackerror`.
"""
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:ComplexF32},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:ComplexF32};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    # `@blasfunc` resolves the symbol name the same way `getrf!` does (it adds the
    # 64-bit-integer suffix when Julia's BLAS uses 64-bit `BlasInt`). This keeps the
    # integer width of `ipiv`/`info` consistent between factorization and solve.
    # The trailing `1` is the hidden Fortran length of the `trans` character argument.
    ccall((@blasfunc(cgetrs_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
140 | |||
"""
    getrs!(trans, A::AbstractMatrix{<:Float64}, ipiv, B; info = Ref{BlasInt}())

Solve a linear system in place with MKL's `dgetrs`, using the LU factors `A` and
pivots `ipiv` produced by `getrf!`. `B` is overwritten with the solution and returned.
`trans` is validated by `chktrans` and forwarded to LAPACK unchanged.

Throws `DimensionMismatch` if `size(B, 1)` or `length(ipiv)` differs from the order of
the square matrix `A`. A nonzero LAPACK status is handled by `chklapackerror`.
"""
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:Float64},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:Float64};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    # `@blasfunc` resolves the symbol name the same way `getrf!` does (it adds the
    # 64-bit-integer suffix when Julia's BLAS uses 64-bit `BlasInt`). This keeps the
    # integer width of `ipiv`/`info` consistent between factorization and solve.
    # The trailing `1` is the hidden Fortran length of the `trans` character argument.
    ccall((@blasfunc(dgetrs_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
165 | |||
"""
    getrs!(trans, A::AbstractMatrix{<:Float32}, ipiv, B; info = Ref{BlasInt}())

Solve a linear system in place with MKL's `sgetrs`, using the LU factors `A` and
pivots `ipiv` produced by `getrf!`. `B` is overwritten with the solution and returned.
`trans` is validated by `chktrans` and forwarded to LAPACK unchanged.

Throws `DimensionMismatch` if `size(B, 1)` or `length(ipiv)` differs from the order of
the square matrix `A`. A nonzero LAPACK status is handled by `chklapackerror`.
"""
function getrs!(trans::AbstractChar,
        A::AbstractMatrix{<:Float32},
        ipiv::AbstractVector{BlasInt},
        B::AbstractVecOrMat{<:Float32};
        info = Ref{BlasInt}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    if n != size(B, 1)
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    end
    if n != length(ipiv)
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    end
    nrhs = size(B, 2)
    # `@blasfunc` resolves the symbol name the same way `getrf!` does (it adds the
    # 64-bit-integer suffix when Julia's BLAS uses 64-bit `BlasInt`). This keeps the
    # integer width of `ipiv`/`info` consistent between factorization and solve.
    # The trailing `1` is the hidden Fortran length of the `trans` character argument.
    ccall((@blasfunc(sgetrs_), MKL_jll.libmkl_rt), Cvoid,
        (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
            Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
        trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
        1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
190 | |||
# By default, `MKLLUFactorization` reports that neither `A` nor `b` may be aliased.
# `getrf!` overwrites the matrix it factorizes in place.
# NOTE(review): presumably this makes the cache hold its own copy of `A`; confirm against
# LinearSolve's aliasing logic.
function default_alias_A(::MKLLUFactorization, ::Any, ::Any)
    return false
end

function default_alias_b(::MKLLUFactorization, ::Any, ::Any)
    return false
end
193 | |||
# Shared placeholder cache for `MKLLUFactorization`: a 0×0 `Float64` LU instance paired
# with an `info` slot of type `Ref{BlasInt}`. A `let` block keeps the temporary matrix
# local. The previous top-level `begin` block defined stray module globals `A` and
# `luinst`.
const PREALLOCATED_MKL_LU = let
    placeholder = rand(0, 0)
    (ArrayInterface.lu_instance(placeholder), Ref{BlasInt}())
end
198 | |||
# Fallback cache for `MKLLUFactorization` (any `A` not matched by a more specific method):
# return the module-level placeholder instead of allocating a new one.
# NOTE(review): every cache initialized here receives the *same* tuple. As a result, the
# `Ref{BlasInt}` that `solve!` passes to `getrf!` as `info` is shared across caches.
# Confirm this is acceptable when several caches are solved concurrently.
LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
    maxiters::Int, abstol, reltol, verbose::Bool,
    assumptions::OperatorAssumptions) = PREALLOCATED_MKL_LU
204 | |||
# Cache for `Float32`, `ComplexF32` and `ComplexF64` matrices. It builds a 0×0 LU
# instance with the matching element type, because the shared `PREALLOCATED_MKL_LU`
# placeholder is `Float64`. It also creates a fresh `info` slot for `getrf!`.
function LinearSolve.init_cacheval(alg::MKLLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    T = eltype(A)
    template = rand(T, 0, 0)
    return (ArrayInterface.lu_instance(template), Ref{BlasInt}())
end
211 | |||
"""
    SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; kwargs...)

Solve the cached linear system using an MKL LU factorization and write the result into
`cache.u`.

When `cache.isfresh` is set, `cache.A` is factorized in place by `getrf!`, reusing the
pivot buffer and `info` reference from the current `cache.cacheval`. The resulting
`(LU, info)` pair then replaces `cache.cacheval`. The stored factorization is applied to
`cache.b` with `ldiv!`. Extra keyword arguments are accepted and ignored.
"""
function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
        kwargs...)
    A = convert(AbstractMatrix, cache.A)
    if cache.isfresh
        cacheval = @get_cacheval(cache, :MKLLUFactorization)
        # Reuse the previous factorization's pivot vector and `info` slot so that
        # refactorizing a same-sized matrix does not allocate.
        factors, ipiv, status, inforef = getrf!(A; ipiv = cacheval[1].ipiv,
            info = cacheval[2])
        # A positive `status` (singular factor) is not checked here; it is stored in the
        # LU object. NOTE(review): `ldiv!` below still runs in that case. Consider
        # reporting a failure retcode when `!issuccess` of the new factorization.
        cache.cacheval = (LU(factors, ipiv, status), inforef)
        cache.isfresh = false
    end
    y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization)[1], cache.b)
    return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end