Line | Exclusive | Inclusive | Code |
---|---|---|---|
1 | # This file is a part of Julia. License is MIT: https://julialang.org/license | ||
2 | |||
3 | # matmul.jl: Everything to do with dense matrix multiplication | ||
4 | |||
5 | # Matrix-matrix multiplication | ||
6 | |||
# Lazily adjointed or transposed strided matrices whose parent has element type `T`.
AdjOrTransStridedMat{T} = Union{
    Adjoint{<:Any, <:StridedMatrix{T}},
    Transpose{<:Any, <:StridedMatrix{T}}
}
# A strided matrix, either bare or behind a lazy adjoint/transpose wrapper.
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, AdjOrTransStridedMat{T}}
# Same idea, but the parent may also be a strided vector.
StridedMaybeAdjOrTransVecOrMat{T} = Union{
    StridedVecOrMat{T},
    AdjOrTrans{<:Any, <:StridedVecOrMat{T}}
}
10 | |||
# Sum of two identical products `x*y`; used with `promote_op` below to obtain the
# element type of a matrix product (a multiplication followed by an addition).
function matprod(x, y)
    return x * y + x * y
end
12 | |||
13 | # dot products | ||
14 | |||
# Real BLAS element types: forward to BLAS `dot`.
function dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasReal}
    return BLAS.dot(x, y)
end
# Complex BLAS element types: forward to BLAS `dotc`.
function dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasComplex}
    return BLAS.dotc(x, y)
end
17 | |||
# Dot product of `x[rx]` with `y[ry]` for real BLAS element types.
# Both ranges must have the same length and stay inside their vectors; the selected
# elements are passed to BLAS as (pointer, stride) pairs instead of being copied.
function dot(x::Vector{T}, rx::AbstractRange{TI}, y::Vector{T}, ry::AbstractRange{TI}) where {T<:BlasReal,TI<:Integer}
    length(rx) == length(ry) ||
        throw(DimensionMismatch(lazy"length of rx, $(length(rx)), does not equal length of ry, $(length(ry))"))
    (1 <= minimum(rx) && maximum(rx) <= length(x)) || throw(BoundsError(x, rx))
    (1 <= minimum(ry) && maximum(ry) <= length(y)) || throw(BoundsError(y, ry))
    # Keep both vectors rooted while raw pointers into them are in use.
    GC.@preserve x y begin
        xptr = pointer(x) + (first(rx) - 1) * sizeof(T)
        yptr = pointer(y) + (first(ry) - 1) * sizeof(T)
        BLAS.dot(length(rx), xptr, step(rx), yptr, step(ry))
    end
end
30 | |||
# Dot product of `x[rx]` with `y[ry]` for complex BLAS element types, computed with
# BLAS `dotc`. Both ranges must have the same length and stay inside their vectors;
# the selected elements are passed to BLAS as (pointer, stride) pairs.
function dot(x::Vector{T}, rx::AbstractRange{TI}, y::Vector{T}, ry::AbstractRange{TI}) where {T<:BlasComplex,TI<:Integer}
    length(rx) == length(ry) ||
        throw(DimensionMismatch(lazy"length of rx, $(length(rx)), does not equal length of ry, $(length(ry))"))
    (1 <= minimum(rx) && maximum(rx) <= length(x)) || throw(BoundsError(x, rx))
    (1 <= minimum(ry) && maximum(ry) <= length(y)) || throw(BoundsError(y, ry))
    # Keep both vectors rooted while raw pointers into them are in use.
    GC.@preserve x y begin
        xptr = pointer(x) + (first(rx) - 1) * sizeof(T)
        yptr = pointer(y) + (first(ry) - 1) * sizeof(T)
        BLAS.dotc(length(rx), xptr, step(rx), yptr, step(ry))
    end
end
43 | |||
# transpose(x) * y for complex BLAS vectors: the unconjugated dot product (`dotu`).
function *(transx::Transpose{<:Any,<:StridedVector{T}}, y::StridedVector{T}) where {T<:BlasComplex}
    return BLAS.dotu(parent(transx), y)
end
48 | |||
49 | # Matrix-vector multiplication | ||
# BLAS-float matrix times real strided vector. When the promoted element type is
# concrete, `x` is converted to it first, before calling `mul!`.
function (*)(A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{S}) where {T<:BlasFloat,S<:Real}
    Tout = promote_op(matprod, T, S)
    xconv = if isconcretetype(Tout)
        convert(AbstractVector{Tout}, x)
    else
        x
    end
    dest = similar(x, Tout, size(A, 1))
    return mul!(dest, A, xconv)
end
# Generic matrix-vector product: allocate the result along the first axis of `A`
# with the promoted element type and delegate to `mul!`.
function (*)(A::AbstractMatrix{T}, x::AbstractVector{S}) where {T,S}
    Tout = promote_op(matprod, T, S)
    dest = similar(x, Tout, axes(A, 1))
    return mul!(dest, A, x)
end
59 | |||
60 | # these will throw a DimensionMismatch unless B has 1 row (or 1 col for transposed case): | ||
# Vector times matrix: treat the vector as a one-column matrix and multiply.
function (*)(a::AbstractVector, tB::TransposeAbsMat)
    column = reshape(a, length(a), 1)
    return column * tB
end
function (*)(a::AbstractVector, adjB::AdjointAbsMat)
    column = reshape(a, length(a), 1)
    return column * adjB
end
function (*)(a::AbstractVector, B::AbstractMatrix)
    column = reshape(a, length(a), 1)
    return column * B
end
64 | |||
# Five-argument matrix-vector `mul!`: strip any wrapper from `A`, pass its character
# code alongside the unwrapped array, and bundle the scalars into a `MulAddMul`.
@inline function mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
                      alpha::Number, beta::Number)
    return generic_matvecmul!(y, wrapper_char(A), _unwrap(A), x, MulAddMul(alpha, beta))
end
68 | # BLAS cases | ||
69 | # equal eltypes | ||
# Matrix and vectors share one BLAS float element type.
@inline function generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
    return gemv!(y, tA, A, x, _add.alpha, _add.beta)
end
# Real (possibly transposed) matrix times complex vector.
# The matching `gemv!` method multiplies the real and imaginary parts separately.
@inline function generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasReal}
    return gemv!(y, tA, A, x, _add.alpha, _add.beta)
end
# Complex matrix times real vector.
# The matching `gemv!` method reinterprets the matrix as real; that only cooperates
# with BLAS when A is untransposed (tA == 'N'), which `gemv!` checks itself.
@inline function generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasReal}
    return gemv!(y, tA, A, x, _add.alpha, _add.beta)
end
85 | |||
86 | # Vector-Matrix multiplication | ||
# Row vector times matrix, expressed through the matrix-vector product:
# x' * A == (A' * x)' and likewise for transpose.
function (*)(x::AdjointAbsVec, A::AbstractMatrix)
    return adjoint(adjoint(A) * adjoint(x))
end
function (*)(x::TransposeAbsVec, A::AbstractMatrix)
    return transpose(transpose(A) * transpose(x))
end
89 | |||
90 | # Matrix-matrix multiplication | ||
"""
    *(A::AbstractMatrix, B::AbstractMatrix)

Matrix multiplication.

# Examples
```jldoctest
julia> [1 1; 0 1] * [1 0; 1 1]
2×2 Matrix{Int64}:
 2  1
 1  1
```
"""
function (*)(A::AbstractMatrix, B::AbstractMatrix)
    # Element type of the result, as inferred for a product-plus-sum of elements.
    Tout = promote_op(matprod, eltype(A), eltype(B))
    dest = similar(B, Tout, (size(A, 1), size(B, 2)))
    return mul!(dest, A, B)
end
108 | # optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64}) | ||
109 | # but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal}) | ||
110 | # which is better handled by reinterpreting rather than promotion | ||
# Real BLAS matrices with possibly different element types: convert both parents to
# the promoted type, re-apply their wrappers, and multiply in place.
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
    Tout = promote_type(eltype(A), eltype(B))
    dest = similar(B, Tout, (size(A, 1), size(B, 2)))
    Aconv = wrapperop(A)(convert(AbstractArray{Tout}, _unwrap(A)))
    Bconv = wrapperop(B)(convert(AbstractArray{Tout}, _unwrap(B)))
    return mul!(dest, Aconv, Bconv)
end
# Complex BLAS matrices with possibly different element types: convert both parents
# to the promoted type, re-apply their wrappers, and multiply in place.
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
    Tout = promote_type(eltype(A), eltype(B))
    dest = similar(B, Tout, (size(A, 1), size(B, 2)))
    Aconv = wrapperop(A)(convert(AbstractArray{Tout}, _unwrap(A)))
    Bconv = wrapperop(B)(convert(AbstractArray{Tout}, _unwrap(B)))
    return mul!(dest, Aconv, Bconv)
end
123 | |||
124 | # Complex Matrix times real matrix: We use that it is generally faster to reinterpret the | ||
125 | # first matrix as a real matrix and carry out real matrix matrix multiply | ||
# Unwrapped complex matrix times real matrix: `A` is converted to the promoted complex
# type, `B` to the corresponding real type, so `mul!` sees a complex-times-real product.
function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
    Tout = promote_type(eltype(A), eltype(B))
    dest = similar(B, Tout, (size(A, 1), size(B, 2)))
    Aconv = convert(AbstractArray{Tout}, A)
    Bconv = wrapperop(B)(convert(AbstractArray{real(Tout)}, _unwrap(B)))
    return mul!(dest, Aconv, Bconv)
end
# Adjoint/transposed complex matrix times real matrix: materialize `A` first so the
# AdjOrTrans wrapper is removed and the reinterpret trick can be used downstream.
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
    Tout = promote_type(eltype(A), eltype(B))
    dest = similar(B, Tout, (size(A, 1), size(B, 2)))
    Aplain = copymutable_oftype(A, Tout)
    Bconv = wrapperop(B)(convert(AbstractArray{real(Tout)}, _unwrap(B)))
    return mul!(dest, Aplain, Bconv)
end
138 | # the following case doesn't seem to benefit from the translation A*B = (B' * A')' | ||
# Real matrix times complex matrix: multiply the real and imaginary parts of `B`
# separately (reusing one scratch matrix) and recombine the two real products.
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
    scratch = real(B)
    realpart = A * scratch
    scratch .= imag.(B)
    imagpart = A * scratch
    return Complex.(realpart, imagpart)
end
# Wrapped real matrix times complex matrix: compute the transposed product
# transpose(B) * parent(A) and materialize its transpose.
function (*)(A::AdjOrTransStridedMat{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
    flipped = transpose(B) * parent(A)
    return copy(transpose(flipped))
end
# Real matrix times wrapped complex matrix: multiply parent(B) by transpose(A),
# then re-apply B's wrapper to the product and materialize it.
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::AdjOrTransStridedMat{<:BlasComplex})
    flipped = parent(B) * transpose(A)
    return copy(wrapperop(B)(flipped))
end
148 | |||
"""
    muladd(A, y, z)

Combined multiply-add, `A*y .+ z`, for matrix-matrix or matrix-vector multiplication.
The result is always the same size as `A*y`, but `z` may be smaller, or a scalar.

!!! compat "Julia 1.6"
     These methods require Julia 1.6 or later.

# Examples
```jldoctest
julia> A=[1.0 2.0; 3.0 4.0]; B=[1.0 1.0; 1.0 1.0]; z=[0, 100];

julia> muladd(A, B, z)
2×2 Matrix{Float64}:
   3.0    3.0
 107.0  107.0
```
"""
function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, AbstractArray})
    Ay = A * y
    nd = ndims(Ay)
    # `z` may not exceed the product in any dimension the product has.
    # Same error as Ay .+= z would give, to match StridedMatrix method:
    for d in 1:nd
        if size(z, d) > size(Ay, d)
            throw(DimensionMismatch("array could not be broadcast to match destination"))
        end
    end
    # Any extra dimensions of `z` must be singletons.
    # Similar error to what Ay + z would give, to match (Any,Any,Any) method:
    for d in (nd + 1):ndims(z)
        if size(z, d) > 1
            throw(DimensionMismatch(string("dimensions must match: z has dims ",
                axes(z), ", must have singleton at dim ", d)))
        end
    end
    return Ay .+ z
end
181 | |||
# Outer product plus `z`: `(u .* v) .+ z`, where `z` must fit inside the
# length(u) × length(v) result and be singleton in any dimension beyond the second.
function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, AbstractArray})
    # Same error as (u*v) .+= z:
    (size(z, 1) > length(u) || size(z, 2) > length(v)) &&
        throw(DimensionMismatch("array could not be broadcast to match destination"))
    # Similar error to (u*v) + z:
    for d in 3:ndims(z)
        if size(z, d) > 1
            throw(DimensionMismatch(string("dimensions must match: z has dims ",
                axes(z), ", must have singleton at dim ", d)))
        end
    end
    return (u .* v) .+ z
end
194 | |||
# Row vector times matrix plus `z`: flip every operand, reuse the matrix-vector
# `muladd`, and flip the result back.
function Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat})
    return adjoint(muladd(adjoint(A), adjoint(x), adjoint(z)))
end
function Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat})
    return transpose(muladd(transpose(A), transpose(x), transpose(z)))
end
199 | |||
# Strided matrix-vector multiply-add: seed the output with `z`, then accumulate
# `A*y` into it with a five-argument `mul!` (alpha = beta = true).
function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, y::AbstractVector{<:Number}, z::Union{Number, AbstractVector})
    Tout = promote_type(eltype(A), eltype(y), eltype(z))
    dest = similar(A, Tout, axes(A, 1))
    dest .= z
    return mul!(dest, A, y, true, true)
end
206 | |||
# Strided matrix-matrix multiply-add: seed the output with `z`, then accumulate
# `A*B` into it with a five-argument `mul!` (alpha = beta = true).
function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, B::StridedMaybeAdjOrTransMat{<:Number}, z::Union{Number, AbstractVecOrMat})
    Tout = promote_type(eltype(A), eltype(B), eltype(z))
    dest = similar(A, Tout, axes(A, 1), axes(B, 2))
    dest .= z
    return mul!(dest, A, B, true, true)
end
213 | |||
"""
    mul!(Y, A, B) -> Y

Calculates the matrix-matrix or matrix-vector product ``AB`` and stores the result in `Y`,
overwriting the existing value of `Y`. Note that `Y` must not be aliased with either `A` or
`B`.

# Examples
```jldoctest
julia> A=[1.0 2.0; 3.0 4.0]; B=[1.0 1.0; 1.0 1.0]; Y = similar(B); mul!(Y, A, B);

julia> Y
2×2 Matrix{Float64}:
 3.0  3.0
 7.0  7.0
```

# Implementation
For custom matrix and vector types, it is recommended to implement
5-argument `mul!` rather than implementing 3-argument `mul!` directly
if possible.
"""
@inline mul!(C, A, B) = mul!(C, A, B, true, false)  # alpha = true, beta = false
239 | |||
"""
    mul!(C, A, B, α, β) -> C

Combined inplace matrix-matrix or matrix-vector multiply-add ``A B α + C β``.
The result is stored in `C` by overwriting it.  Note that `C` must not be
aliased with either `A` or `B`.

!!! compat "Julia 1.3"
    Five-argument `mul!` requires at least Julia 1.3.

# Examples
```jldoctest
julia> A=[1.0 2.0; 3.0 4.0]; B=[1.0 1.0; 1.0 1.0]; C=[1.0 2.0; 3.0 4.0];

julia> mul!(C, A, B, 100.0, 10.0) === C
true

julia> C
2×2 Matrix{Float64}:
 310.0  320.0
 730.0  740.0
```
"""
@inline function mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number)
    # Pass wrapper character codes and unwrapped arrays, with the scalars bundled
    # into a `MulAddMul`.
    return generic_matmatmul!(C, wrapper_char(A), wrapper_char(B), _unwrap(A), _unwrap(B),
                              MulAddMul(α, β))
end
272 | |||
273 | """ | ||
274 | rmul!(A, B) | ||
275 | |||
276 | Calculate the matrix-matrix product ``AB``, overwriting `A`, and return the result. | ||
277 | Here, `B` must be of special matrix type, like, e.g., [`Diagonal`](@ref), | ||
278 | [`UpperTriangular`](@ref) or [`LowerTriangular`](@ref), or of some orthogonal type, | ||
279 | see [`QR`](@ref). | ||
280 | |||
281 | # Examples | ||
282 | ```jldoctest | ||
283 | julia> A = [0 1; 1 0]; | ||
284 | |||
285 | julia> B = UpperTriangular([1 2; 0 3]); | ||
286 | |||
287 | julia> rmul!(A, B); | ||
288 | |||
289 | julia> A | ||
290 | 2×2 Matrix{Int64}: | ||
291 | 0 3 | ||
292 | 1 2 | ||
293 | |||
294 | julia> A = [1.0 2.0; 3.0 4.0]; | ||
295 | |||
296 | julia> F = qr([0 1; -1 0]); | ||
297 | |||
298 | julia> rmul!(A, F.Q) | ||
299 | 2×2 Matrix{Float64}: | ||
300 | 2.0 1.0 | ||
301 | 4.0 3.0 | ||
302 | ``` | ||
303 | """ | ||
304 | rmul!(A, B) | ||
305 | |||
306 | """ | ||
307 | lmul!(A, B) | ||
308 | |||
309 | Calculate the matrix-matrix product ``AB``, overwriting `B`, and return the result. | ||
310 | Here, `A` must be of special matrix type, like, e.g., [`Diagonal`](@ref), | ||
311 | [`UpperTriangular`](@ref) or [`LowerTriangular`](@ref), or of some orthogonal type, | ||
312 | see [`QR`](@ref). | ||
313 | |||
314 | # Examples | ||
315 | ```jldoctest | ||
316 | julia> B = [0 1; 1 0]; | ||
317 | |||
318 | julia> A = UpperTriangular([1 2; 0 3]); | ||
319 | |||
320 | julia> lmul!(A, B); | ||
321 | |||
322 | julia> B | ||
323 | 2×2 Matrix{Int64}: | ||
324 | 2 1 | ||
325 | 3 0 | ||
326 | |||
327 | julia> B = [1.0 2.0; 3.0 4.0]; | ||
328 | |||
329 | julia> F = qr([0 1; -1 0]); | ||
330 | |||
331 | julia> lmul!(F.Q, B) | ||
332 | 2×2 Matrix{Float64}: | ||
333 | 3.0 4.0 | ||
334 | 1.0 2.0 | ||
335 | ``` | ||
336 | """ | ||
337 | lmul!(A, B) | ||
338 | |||
339 | # THE one big BLAS dispatch | ||
# THE one big BLAS dispatch
# Routes a same-eltype BLAS-float product to syrk/herk (when both operands are the
# identical array and exactly one is transposed/adjointed), gemm, symm or hemm, and
# otherwise falls back to the generic kernel on re-wrapped operands.
@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
    plain = ('N', 'T', 'C')
    if tA in plain && tB in plain
        if A === B
            if tA == 'T' && tB == 'N'
                return syrk_wrapper!(C, 'T', A, _add)
            elseif tA == 'N' && tB == 'T'
                return syrk_wrapper!(C, 'N', A, _add)
            elseif tA == 'C' && tB == 'N'
                return herk_wrapper!(C, 'C', A, _add)
            elseif tA == 'N' && tB == 'C'
                return herk_wrapper!(C, 'N', A, _add)
            end
        end
        return gemm_wrapper!(C, tA, tB, A, B, _add)
    end
    alpha, beta = promote(_add.alpha, _add.beta, zero(T))
    if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
        # Uppercase codes select the 'U' triangle, lowercase ones the 'L' triangle.
        if tB == 'N'
            if tA == 'S' || tA == 's'
                return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
            elseif tA == 'H' || tA == 'h'
                return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
            end
        elseif tA == 'N'
            if tB == 'S' || tB == 's'
                return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
            elseif tB == 'H' || tB == 'h'
                return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
            end
        end
    end
    return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
end
369 | |||
370 | # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. | ||
# Complex matrix times (transposed) real matrix. Plain 'N'/'T'/'C' operands go to
# `gemm_wrapper!`; anything else is re-wrapped and handled by the generic kernel.
@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasReal}
    plain = ('N', 'T', 'C')
    if tA in plain && tB in plain
        return gemm_wrapper!(C, tA, tB, A, B, _add)
    end
    return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
end
379 | |||
380 | |||
381 | # Supporting functions for matrix multiplication | ||
382 | |||
383 | # copy transposed(adjoint) of upper(lower) side-diagonals. Optionally include diagonal. | ||
# Mirror one triangle of the square matrix `A` onto the other, in place.
# `uplo == 'U'` copies the upper triangle into the lower one, `'L'` the reverse.
# Elements are transposed, or adjointed when `conjugate` is true; with `diag` set
# the diagonal entries are rewritten as well. Returns `A`.
@inline function copytri!(A::AbstractMatrix, uplo::AbstractChar, conjugate::Bool=false, diag::Bool=false)
    n = checksquare(A)
    if uplo != 'U' && uplo != 'L'
        throw(ArgumentError(lazy"uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
    end
    from_upper = uplo == 'U'
    skip = diag ? 0 : 1  # start one past the diagonal unless it is included
    for i in 1:n
        for j in (i + skip):n
            if from_upper
                A[j, i] = conjugate ? adjoint(A[i, j]) : transpose(A[i, j])
            else
                A[i, j] = conjugate ? adjoint(A[j, i]) : transpose(A[j, i])
            end
        end
    end
    return A
end
400 | |||
# Matrix-vector multiply-add `y = α*op(A)*x + β*y` for a single BLAS float type.
# `tA` is the wrapper code of `A`: 'N', 'T', 'C', or 'S'/'s'/'H'/'h' (handled here
# by symv!/hemv!, or by re-wrapping via `wrap`).
# Throws `DimensionMismatch` when `x` or `y` does not fit `op(A)`.
# BLAS is used when the scalars are representable as `Bool`/`T` and the strides are
# compatible; everything else goes to `_generic_matvecmul!`.
function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{T},
               α::Number=true, β::Number=false) where {T<:BlasFloat}
    mA, nA = lapack_size(tA, A)
    nA != length(x) &&
        throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
    mA != length(y) &&
        throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
    mA == 0 && return y
    nA == 0 && return _rmul_or_fill!(y, β)
    alpha, beta = promote(α, β, zero(T))
    # The dispatch on `tA` is the *body* of this `if`. It used to be chained onto the
    # condition with a trailing `&&`, which left the body empty and turned an
    # unmatched `tA` into a non-boolean (`nothing`) condition.
    if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
        stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
        !iszero(stride(x, 1)) # We only check input's stride here.
        if tA in ('N', 'T', 'C')
            return BLAS.gemv!(tA, alpha, A, x, beta, y)
        elseif tA in ('S', 's')
            return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y)
        elseif tA in ('H', 'h')
            return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y)
        end
    end
    if tA in ('S', 's', 'H', 'h')
        # re-wrap again and use plain ('N') matvec mul algorithm,
        # because _generic_matvecmul! can't handle the HermOrSym cases specifically
        return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
    else
        return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
    end
end
430 | |||
# Complex matrix times real vector, `y = α*op(A)*x + β*y`.
# For untransposed `A` with BLAS-compatible strides and scalars, `A` and `y` are
# reinterpreted as real arrays and a real BLAS gemv! is issued; otherwise the
# generic kernel is used (re-wrapping symmetric/Hermitian codes first).
function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
               α::Number = true, β::Number = false) where {T<:BlasReal}
    mA, nA = lapack_size(tA, A)
    nA == length(x) ||
        throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
    mA == length(y) ||
        throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
    mA == 0 && return y
    nA == 0 && return _rmul_or_fill!(y, β)
    alpha, beta = promote(α, β, zero(T))
    use_blas = alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
        stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
        stride(y, 1) == 1 && tA == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
        !iszero(stride(x, 1))
    if use_blas
        BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
        return y
    end
    if tA in ('S', 's', 'H', 'h')
        return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
    end
    return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
452 | |||
# Real matrix times complex vector, `y = α*op(A)*x + β*y`.
# On the BLAS path `x` and `y` are reinterpreted (with reshape) as real arrays and
# the real and imaginary rows are multiplied by two separate real gemv! calls;
# otherwise the generic kernel is used.
function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
               α::Number = true, β::Number = false) where {T<:BlasFloat}
    mA, nA = lapack_size(tA, A)
    nA == length(x) ||
        throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
    mA == length(y) ||
        throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
    mA == 0 && return y
    nA == 0 && return _rmul_or_fill!(y, β)
    alpha, beta = promote(α, β, zero(T))
    use_blas = alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
        stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
        !iszero(stride(x, 1)) && tA in ('N', 'T', 'C')
    if use_blas
        xfl = reinterpret(reshape, T, x) # Use reshape here.
        yfl = reinterpret(reshape, T, y)
        # Row 1 holds the real parts, row 2 the imaginary parts.
        BLAS.gemv!(tA, alpha, A, view(xfl, 1, :), beta, view(yfl, 1, :))
        BLAS.gemv!(tA, alpha, A, view(xfl, 2, :), beta, view(yfl, 2, :))
        return y
    end
    if tA in ('S', 's', 'H', 'h')
        # re-wrap again and use plain ('N') matvec mul algorithm,
        # because _generic_matvecmul! can't handle the HermOrSym cases specifically
        return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
    end
    return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
479 | |||
# Symmetric rank-k style product of `A` with its own transpose, written into `C`.
# `tA == 'T'` computes transpose(A)*A, any other `tA` computes A*transpose(A).
# 2×2 and 3×3 cases go to `matmul2x2!`/`matmul3x3!`; BLAS syrk! (followed by
# `copytri!` to fill the other triangle) is used when applicable, otherwise
# `gemm_wrapper!` does the work.
function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
                       _add = MulAddMul()) where {T<:BlasFloat}
    nC = checksquare(C)
    transposed = tA == 'T'
    mA, nA = transposed ? (size(A, 2), size(A, 1)) : (size(A, 1), size(A, 2))
    tAt = transposed ? 'N' : 'T'
    nC == mA ||
        throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)"))
    if mA == 0 || nA == 0 || iszero(_add.alpha)
        return _rmul_or_fill!(C, _add.beta)
    end
    mA == nA == 2 && return matmul2x2!(C, tA, tAt, A, A, _add)
    mA == nA == 3 && return matmul3x3!(C, tA, tAt, A, A, _add)

    # BLAS.syrk! only updates symmetric C
    # alternatively, make non-zero β a show-stopper for BLAS.syrk!
    if iszero(_add.beta) || issymmetric(C)
        alpha, beta = promote(_add.alpha, _add.beta, zero(T))
        blas_ok = alpha isa Union{Bool,T} &&
            beta isa Union{Bool,T} &&
            stride(A, 1) == stride(C, 1) == 1 &&
            stride(A, 2) >= size(A, 1) &&
            stride(C, 2) >= size(C, 1)
        if blas_ok
            return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
        end
    end
    return gemm_wrapper!(C, tA, tAt, A, A, _add)
end
517 | |||
# Hermitian rank-k style product of `A` with its own adjoint, written into `C`.
# `tA == 'C'` computes A'*A, any other `tA` computes A*A'.
# Throws `DimensionMismatch` if `C` does not match the product size.
# 2×2 and 3×3 cases go to `matmul2x2!`/`matmul3x3!`; BLAS herk! (followed by a
# conjugating `copytri!` to fill the other triangle) is used when applicable,
# otherwise `gemm_wrapper!` does the work.
function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
                       _add = MulAddMul()) where {T<:BlasReal}
    nC = checksquare(C)
    if tA == 'C'
        (nA, mA) = size(A,1), size(A,2)
        tAt = 'N'
    else
        (mA, nA) = size(A,1), size(A,2)
        tAt = 'C'
    end
    if nC != mA
        throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)"))
    end
    if mA == 0 || nA == 0 || iszero(_add.alpha)
        return _rmul_or_fill!(C, _add.beta)
    end
    if mA == 2 && nA == 2
        return matmul2x2!(C, tA, tAt, A, A, _add)
    end
    if mA == 3 && nA == 3
        return matmul3x3!(C, tA, tAt, A, A, _add)
    end

    # BLAS.herk! updates a single triangle and the conjugating copytri! rebuilds the
    # other one, so a non-zero β is only valid when the incoming C is Hermitian.
    # A merely symmetric complex C would have its lower triangle replaced by the
    # conjugate of the upper one, hence `ishermitian` rather than `issymmetric`.
    if iszero(_add.beta) || ishermitian(C)
        alpha, beta = promote(_add.alpha, _add.beta, zero(T))
        if (alpha isa Union{Bool,T} &&
            beta isa Union{Bool,T} &&
            stride(A, 1) == stride(C, 1) == 1 &&
            stride(A, 2) >= size(A, 1) &&
            stride(C, 2) >= size(C, 1))
            return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
        end
    end
    return gemm_wrapper!(C, tA, tAt, A, A, _add)
end
556 | |||
# Allocating matrix product `op(A) * op(B)` for a single BLAS float type.
# `tA`/`tB` are the wrapper codes of `A` and `B`. Plain 'N'/'T'/'C' operands go to
# `gemm_wrapper!`; any other code is re-wrapped via `wrap` and handled by the
# generic kernel. Returns the freshly allocated result.
function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
                      A::StridedVecOrMat{T},
                      B::StridedVecOrMat{T}) where {T<:BlasFloat}
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)
    C = similar(B, T, mA, nB)
    if all(in(('N', 'T', 'C')), (tA, tB))
        gemm_wrapper!(C, tA, tB, A, B)
    else
        # This method has no `_add` argument; the fallback previously referenced an
        # undefined `_add`. Use the default MulAddMul(), matching gemm_wrapper!'s default.
        _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), MulAddMul())
    end
end
569 | |||
# In-place `C = α*op(A)*op(B) + β*C` for a single BLAS float type.
# Validates shapes and aliasing, short-circuits empty products and zero α, sends
# 2×2/3×3 products to `matmul2x2!`/`matmul3x3!`, and calls BLAS gemm! when the
# scalars and strides allow it; otherwise falls back to `_generic_matmatmul!`.
function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
                       A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
                       _add = MulAddMul()) where {T<:BlasFloat}
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)

    nA == mB ||
        throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))

    (C === A || B === C) &&
        throw(ArgumentError("output matrix must not be aliased with input matrix"))

    if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
        size(C) == (mA, nB) ||
            throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
        return _rmul_or_fill!(C, _add.beta)
    end

    mA == nA == nB == 2 && return matmul2x2!(C, tA, tB, A, B, _add)
    mA == nA == nB == 3 && return matmul3x3!(C, tA, tB, A, B, _add)

    alpha, beta = promote(_add.alpha, _add.beta, zero(T))
    blas_ok = alpha isa Union{Bool,T} &&
        beta isa Union{Bool,T} &&
        stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
        stride(A, 2) >= size(A, 1) &&
        stride(B, 2) >= size(B, 1) &&
        stride(C, 2) >= size(C, 1)
    blas_ok && return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
    return _generic_matmatmul!(C, tA, tB, A, B, _add)
end
609 | |||
# In-place `C = α*op(A)*op(B) + β*C` for complex `A`/`C` and real `B`.
# Same validation and small-size shortcuts as the same-eltype method. On the BLAS
# path (untransposed `A` only) `A` and `C` are reinterpreted as real matrices and a
# real gemm! is issued; otherwise `_generic_matmatmul!` is used.
function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
                       A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
                       _add = MulAddMul()) where {T<:BlasReal}
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)

    nA == mB ||
        throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))

    (C === A || B === C) &&
        throw(ArgumentError("output matrix must not be aliased with input matrix"))

    if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
        size(C) == (mA, nB) ||
            throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
        return _rmul_or_fill!(C, _add.beta)
    end

    mA == nA == nB == 2 && return matmul2x2!(C, tA, tB, A, B, _add)
    mA == nA == nB == 3 && return matmul3x3!(C, tA, tB, A, B, _add)

    alpha, beta = promote(_add.alpha, _add.beta, zero(T))

    # Make-sure reinterpret-based optimization is BLAS-compatible.
    blas_ok = alpha isa Union{Bool,T} &&
        beta isa Union{Bool,T} &&
        stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
        stride(A, 2) >= size(A, 1) &&
        stride(B, 2) >= size(B, 1) &&
        stride(C, 2) >= size(C, 1) && tA == 'N'
    if blas_ok
        BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
        return C
    end
    return _generic_matmatmul!(C, tA, tB, A, B, _add)
end
652 | |||
653 | # blas.jl defines matmul for floats; other integer and mixed precision | ||
654 | # cases are handled here | ||
655 | |||
# Size of `op(M)` as a 2-tuple: `(size(M,1), size(M,2))` when `t == 'N'`, and the
# two sizes swapped for every other code.
function lapack_size(t::AbstractChar, M::AbstractVecOrMat)
    return t == 'N' ? (size(M, 1), size(M, 2)) : (size(M, 2), size(M, 1))
end
657 | |||
# Copy the block `op(M)[ir_src, jr_src]` into `B[ir_dest, jr_dest]`, where `op` is
# selected by `tM`: 'N' copies as-is, anything else copies the transpose, and 'C'
# additionally conjugates the destination block. Returns `B`.
function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
    if tM == 'N'
        copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src)
        return B
    end
    LinearAlgebra.copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src)
    if tM == 'C'
        conj!(view(B, ir_dest, jr_dest))
    end
    return B
end
667 | |||
# Copy the transpose of the block `op(M)[ir_src, jr_src]` into `B[ir_dest, jr_dest]`.
# For 'N' this is a transposing copy of `M`; for other codes the two transposes
# cancel and a plain copy is used, with 'C' additionally conjugating the
# destination block. Returns `B`.
function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
    if tM == 'N'
        LinearAlgebra.copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src)
        return B
    end
    copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src)
    if tM == 'C'
        conj!(view(B, ir_dest, jr_dest))
    end
    return B
end
677 | |||
678 | # TODO: It will be faster for large matrices to convert to float, | ||
679 | # call BLAS, and convert back to required type. | ||
680 | |||
681 | # NOTE: the generic version is also called as fallback for | ||
682 | # strides != 1 cases | ||
683 | |||
# Generic matrix-vector entry point: symmetric/Hermitian codes are re-wrapped via
# `wrap` and treated as plain ('N') matrices; everything else is passed through.
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                                    _add::MulAddMul = MulAddMul())
    if tA in ('S', 's', 'H', 'h')
        return _generic_matvecmul!(C, 'N', wrap(A, tA), B, _add)
    end
    return _generic_matvecmul!(C, tA, A, B, _add)
end
689 | |||
# Pure-Julia matrix-vector multiply-add, `C = α*op(A)*B + β*C` with the scalars
# carried by `_add`. `tA` must be 'N', 'T' or 'C'; all arrays must be one-based.
# `A` is addressed by linear index, one column of `size(A, 1)` entries at a time.
# Throws `DimensionMismatch` when `B` or `C` does not fit `op(A)`.
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                             _add::MulAddMul = MulAddMul())
    require_one_based_indexing(C, A, B)
    @assert tA in ('N', 'T', 'C')
    mB = length(B)
    mA, nA = lapack_size(tA, A)
    mB == nA ||
        throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
    mA == length(C) ||
        throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))

    colstride = size(A, 1)

    @inbounds if tA == 'T'  # fastest case
        if nA != 0
            for k in 1:mA
                base = (k - 1) * colstride
                # Zero of the accumulated element type, taken from a sample product.
                acc = zero(A[base + 1] * B[1] + A[base + 1] * B[1])
                for i in 1:nA
                    acc += transpose(A[base + i]) * B[i]
                end
                _modify!(_add, acc, C, k)
            end
        else
            for k in 1:mA
                _modify!(_add, false, C, k)
            end
        end
    elseif tA == 'C'
        if nA != 0
            for k in 1:mA
                base = (k - 1) * colstride
                acc = zero(A[base + 1] * B[1] + A[base + 1] * B[1])
                for i in 1:nA
                    acc += adjoint(A[base + i]) * B[i]
                end
                _modify!(_add, acc, C, k)
            end
        else
            for k in 1:mA
                _modify!(_add, false, C, k)
            end
        end
    else # tA == 'N'
        # First scale (or reset) C, then accumulate one column of A at a time.
        for i in 1:mA
            if !iszero(_add.beta)
                C[i] *= _add.beta
            elseif mB == 0
                C[i] = false
            else
                C[i] = zero(A[i] * B[1] + A[i] * B[1])
            end
        end
        for k in 1:mB
            base = (k - 1) * colstride
            scaled = _add(B[k])
            for i in 1:mA
                C[i] += A[base + i] * scaled
            end
        end
    end # @inbounds
    return C
end
757 | |||
# Allocating front end for the generic product of `A` and `B` under the operation
# flags `tA` and `tB`.
# The result is allocated with `similar(B, ...)`, using the element type inferred for
# `matprod` and the outer dimensions reported by `lapack_size`, and then filled by
# `generic_matmatmul!`.
function generic_matmatmul(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S}) where {T,S}
    mA, _ = lapack_size(tA, A)
    _, nB = lapack_size(tB, B)
    C = similar(B, promote_op(matprod, T, S), mA, nB)
    # Pass the multiply-add descriptor explicitly: the generic `generic_matmatmul!`
    # method takes `_add` as a required positional argument.
    return generic_matmatmul!(C, tA, tB, A, B, MulAddMul())
end
764 | |||
# Byte budget used by `_generic_matmatmul!` to size its square tiles; approximately 32k/3.
const tilebufsize = 10_800
766 | |||
# Generic (non-BLAS) entry point for the matrix-matrix multiply-add described by `_add`.
#
#   * a zero `alpha` reduces to scaling/filling `C` with `beta`;
#   * all-2x2 and all-3x3 problems go to the unrolled `matmul2x2!` / `matmul3x3!`;
#   * otherwise symmetric/Hermitian flags ('S', 's', 'H', 'h') are resolved by wrapping
#     the operand with `wrap` (the flag becomes 'N') and `_generic_matmatmul!` does the work.
#
# Returns `C`.
function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)
    # `size(C)` is a 1-tuple when `C` is a vector, so ask for each dimension explicitly
    # (this matches the checks performed in `_generic_matmatmul!`).
    mC, nC = size(C, 1), size(C, 2)

    if iszero(_add.alpha)
        return _rmul_or_fill!(C, _add.beta)
    end
    if mA == nA == mB == nB == mC == nC == 2
        return matmul2x2!(C, tA, tB, A, B, _add)
    end
    if mA == nA == mB == nB == mC == nC == 3
        return matmul3x3!(C, tA, tB, A, B, _add)
    end
    A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA)
    B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB)
    return _generic_matmatmul!(C, tA, tB, A, B, _add)
end
785 | |||
# Element accessors returning `op(M)[r, c]` for op = identity ('N'), transpose ('T') and
# adjoint ('C'). `@propagate_inbounds` lets the caller's `@inbounds` reach the `getindex`.
Base.@propagate_inbounds _matmul_elem_n(M, r, c) = M[r, c]
Base.@propagate_inbounds _matmul_elem_t(M, r, c) = transpose(M[c, r])
Base.@propagate_inbounds _matmul_elem_c(M, r, c) = adjoint(M[c, r])

# Select the accessor for an operation flag already known to be 'N', 'T' or 'C'.
_matmul_elem(t) = t == 'N' ? _matmul_elem_n : (t == 'T' ? _matmul_elem_t : _matmul_elem_c)

# Naive triple loop: accumulate each entry of `op(A)*op(B)` and hand it to `_modify!`.
# `elA`/`elB` are the accessors chosen by `_matmul_elem`; dimensions have been validated
# by the caller.
function _naive_matmatmul!(C::AbstractVecOrMat{R}, elA::FA, elB::FB, A, B, mA, nA, nB,
                           _add::MulAddMul) where {R,FA,FB}
    @inbounds for i = 1:mA, j = 1:nB
        firstterm = elA(A, i, 1) * elB(B, 1, j)
        z2 = zero(firstterm + firstterm)
        Ctmp = convert(promote_type(R, typeof(z2)), z2)
        for k = 1:nA
            Ctmp += elA(A, i, k) * elB(B, k, j)
        end
        _modify!(_add, Ctmp, C, (i,j))
    end
    return C
end

# Generic (non-BLAS) kernel for the matrix-matrix multiply-add described by `_add`.
#
# Arguments
#   C      -- destination, must be `mA x nB`
#   tA, tB -- operation flags, each one of 'N', 'T', 'C' (asserted)
#   A, B   -- operands; `lapack_size(tA, A) == (mA, nA)`, `lapack_size(tB, B) == (mB, nB)`
#   _add   -- `MulAddMul` carrying `alpha` and `beta`
#
# Plain-data (`isbitstype`) element types use a tiled algorithm, except for the
# combination tA != 'N' with tB == 'N'; everything else uses the naive triple loop.
# Throws `DimensionMismatch` on inconsistent shapes. Returns `C`.
function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
                             _add::MulAddMul) where {T,S,R}
    @assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')
    require_one_based_indexing(C, A, B)

    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)
    if mB != nA
        throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)"))
    end
    if size(C,1) != mA || size(C,2) != nB
        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)"))
    end

    if iszero(_add.alpha) || isempty(A) || isempty(B)
        return _rmul_or_fill!(C, _add.beta)
    end

    tile_size = 0
    if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N')
        tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1)))
    end
    if tile_size <= 0
        # Multiplication for non-plain-data uses the naive algorithm
        return _naive_matmatmul!(C, _matmul_elem(tA), _matmul_elem(tB), A, B, mA, nA, nB, _add)
    end

    @inbounds begin
        sz = (tile_size, tile_size)
        # Atile is filled through `copy_transpose!`, so that the inner `k` loop below
        # walks both tiles with unit stride.
        Atile = Array{T}(undef, sz)
        Btile = Array{S}(undef, sz)

        z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1])
        z = convert(promote_type(typeof(z1), R), z1)

        if mA < tile_size && nA < tile_size && nB < tile_size
            # Whole problem fits in a single tile.
            copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA)
            copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB)
            for j = 1:nB
                boff = (j-1)*tile_size
                for i = 1:mA
                    aoff = (i-1)*tile_size
                    s = z
                    for k = 1:nA
                        s += Atile[aoff+k] * Btile[boff+k]
                    end
                    _modify!(_add, s, C, (i,j))
                end
            end
        else
            # Blocked over (jb, ib); partial products over kb are accumulated in Ctile
            # and merged into C once per output tile.
            Ctile = Array{R}(undef, sz)
            for jb = 1:tile_size:nB
                jlim = min(jb+tile_size-1,nB)
                jlen = jlim-jb+1
                for ib = 1:tile_size:mA
                    ilim = min(ib+tile_size-1,mA)
                    ilen = ilim-ib+1
                    fill!(Ctile, z)
                    for kb = 1:tile_size:nA
                        klim = min(kb+tile_size-1,mB)
                        klen = klim-kb+1
                        copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim)
                        copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim)
                        for j=1:jlen
                            bcoff = (j-1)*tile_size
                            for i = 1:ilen
                                aoff = (i-1)*tile_size
                                s = z
                                for k = 1:klen
                                    s += Atile[aoff+k] * Btile[bcoff+k]
                                end
                                Ctile[bcoff+i] += s
                            end
                        end
                    end
                    if isone(_add.alpha) && iszero(_add.beta)
                        copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen)
                    else
                        C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim])
                    end
                end
            end
        end
    end # @inbounds
    return C
end
959 | |||
960 | |||
961 | # multiply 2x2 matrices | ||
# Allocating wrapper around `matmul2x2!`: the 2x2 result is created with `similar(B, ...)`
# using the element type inferred for `matprod`.
function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
    C = similar(B, promote_op(matprod, T, S), 2, 2)
    return matmul2x2!(C, tA, tB, A, B)
end
965 | |||
# Read the four entries of a 2x2 matrix `M` as seen through the operation flag `tM`:
#   'N' plain, 'T' transposed, 'C' adjoint,
#   'S'/'s' symmetric built from the upper/lower triangle,
#   'H'/'h' Hermitian built from the upper/lower triangle (any other flag is treated as 'h').
# Returns `(M11, M12, M21, M22)`. The caller is responsible for the size check.
Base.@propagate_inbounds function _matmul2x2_entries(tM, M::AbstractMatrix)
    if tM == 'N'
        M11 = M[1,1]; M12 = M[1,2]; M21 = M[2,1]; M22 = M[2,2]
    elseif tM == 'T'
        # TODO making these lazy could improve perf
        M11 = copy(transpose(M[1,1])); M12 = copy(transpose(M[2,1]))
        M21 = copy(transpose(M[1,2])); M22 = copy(transpose(M[2,2]))
    elseif tM == 'C'
        # TODO making these lazy could improve perf
        M11 = copy(M[1,1]'); M12 = copy(M[2,1]')
        M21 = copy(M[1,2]'); M22 = copy(M[2,2]')
    elseif tM == 'S'
        M11 = symmetric(M[1,1], :U); M12 = M[1,2]
        M21 = copy(transpose(M[1,2])); M22 = symmetric(M[2,2], :U)
    elseif tM == 's'
        M11 = symmetric(M[1,1], :L); M12 = copy(transpose(M[2,1]))
        M21 = M[2,1]; M22 = symmetric(M[2,2], :L)
    elseif tM == 'H'
        M11 = hermitian(M[1,1], :U); M12 = M[1,2]
        M21 = copy(adjoint(M[1,2])); M22 = hermitian(M[2,2], :U)
    else # if tM == 'h'
        M11 = hermitian(M[1,1], :L); M12 = copy(adjoint(M[2,1]))
        M21 = M[2,1]; M22 = hermitian(M[2,2], :L)
    end
    return M11, M12, M21, M22
end

# Unrolled 2x2 multiply-add: each entry of the product of `A` and `B` (as seen through
# `tA`/`tB`) is combined with `C` through `_modify!(_add, ...)`.
# All entries of `A` and `B` are read before `C` is written.
# Throws `DimensionMismatch` unless `A`, `B` and `C` are all 2x2. Returns `C`.
function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
                    _add::MulAddMul = MulAddMul())
    require_one_based_indexing(C, A, B)
    if !(size(A) == size(B) == size(C) == (2,2))
        throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
    end
    @inbounds begin
        A11, A12, A21, A22 = _matmul2x2_entries(tA, A)
        B11, B12, B21, B22 = _matmul2x2_entries(tB, B)
        _modify!(_add, A11*B11 + A12*B21, C, (1,1))
        _modify!(_add, A11*B12 + A12*B22, C, (1,2))
        _modify!(_add, A21*B11 + A22*B21, C, (2,1))
        _modify!(_add, A21*B12 + A22*B22, C, (2,2))
    end # inbounds
    return C
end
1027 | |||
1028 | # Multiply 3x3 matrices | ||
# Allocating wrapper around `matmul3x3!`: the 3x3 result is created with `similar(B, ...)`
# using the element type inferred for `matprod`.
function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
    C = similar(B, promote_op(matprod, T, S), 3, 3)
    return matmul3x3!(C, tA, tB, A, B)
end
1032 | |||
# Read the nine entries of a 3x3 matrix `M` as seen through the operation flag `tM`:
#   'N' plain, 'T' transposed, 'C' adjoint,
#   'S'/'s' symmetric built from the upper/lower triangle,
#   'H'/'h' Hermitian built from the upper/lower triangle (any other flag is treated as 'h').
# Returns the entries in row-major order. The caller is responsible for the size check.
Base.@propagate_inbounds function _matmul3x3_entries(tM, M::AbstractMatrix)
    if tM == 'N'
        M11 = M[1,1]; M12 = M[1,2]; M13 = M[1,3]
        M21 = M[2,1]; M22 = M[2,2]; M23 = M[2,3]
        M31 = M[3,1]; M32 = M[3,2]; M33 = M[3,3]
    elseif tM == 'T'
        # TODO making these lazy could improve perf
        M11 = copy(transpose(M[1,1])); M12 = copy(transpose(M[2,1])); M13 = copy(transpose(M[3,1]))
        M21 = copy(transpose(M[1,2])); M22 = copy(transpose(M[2,2])); M23 = copy(transpose(M[3,2]))
        M31 = copy(transpose(M[1,3])); M32 = copy(transpose(M[2,3])); M33 = copy(transpose(M[3,3]))
    elseif tM == 'C'
        # TODO making these lazy could improve perf
        M11 = copy(M[1,1]'); M12 = copy(M[2,1]'); M13 = copy(M[3,1]')
        M21 = copy(M[1,2]'); M22 = copy(M[2,2]'); M23 = copy(M[3,2]')
        M31 = copy(M[1,3]'); M32 = copy(M[2,3]'); M33 = copy(M[3,3]')
    elseif tM == 'S'
        M11 = symmetric(M[1,1], :U); M12 = M[1,2]; M13 = M[1,3]
        M21 = copy(transpose(M[1,2])); M22 = symmetric(M[2,2], :U); M23 = M[2,3]
        M31 = copy(transpose(M[1,3])); M32 = copy(transpose(M[2,3])); M33 = symmetric(M[3,3], :U)
    elseif tM == 's'
        M11 = symmetric(M[1,1], :L); M12 = copy(transpose(M[2,1])); M13 = copy(transpose(M[3,1]))
        M21 = M[2,1]; M22 = symmetric(M[2,2], :L); M23 = copy(transpose(M[3,2]))
        M31 = M[3,1]; M32 = M[3,2]; M33 = symmetric(M[3,3], :L)
    elseif tM == 'H'
        M11 = hermitian(M[1,1], :U); M12 = M[1,2]; M13 = M[1,3]
        M21 = copy(adjoint(M[1,2])); M22 = hermitian(M[2,2], :U); M23 = M[2,3]
        M31 = copy(adjoint(M[1,3])); M32 = copy(adjoint(M[2,3])); M33 = hermitian(M[3,3], :U)
    else # if tM == 'h'
        M11 = hermitian(M[1,1], :L); M12 = copy(adjoint(M[2,1])); M13 = copy(adjoint(M[3,1]))
        M21 = M[2,1]; M22 = hermitian(M[2,2], :L); M23 = copy(adjoint(M[3,2]))
        M31 = M[3,1]; M32 = M[3,2]; M33 = hermitian(M[3,3], :L)
    end
    return M11, M12, M13, M21, M22, M23, M31, M32, M33
end

# Unrolled 3x3 multiply-add: each entry of the product of `A` and `B` (as seen through
# `tA`/`tB`) is combined with `C` through `_modify!(_add, ...)`.
# All entries of `A` and `B` are read before `C` is written.
# Throws `DimensionMismatch` unless `A`, `B` and `C` are all 3x3. Returns `C`.
function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
                    _add::MulAddMul = MulAddMul())
    require_one_based_indexing(C, A, B)
    if !(size(A) == size(B) == size(C) == (3,3))
        throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
    end
    @inbounds begin
        A11, A12, A13, A21, A22, A23, A31, A32, A33 = _matmul3x3_entries(tA, A)
        B11, B12, B13, B21, B22, B23, B31, B32, B33 = _matmul3x3_entries(tB, B)

        _modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
        _modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
        _modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))

        _modify!(_add, A21*B11 + A22*B21 + A23*B31, C, (2,1))
        _modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
        _modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))

        _modify!(_add, A31*B11 + A32*B21 + A33*B31, C, (3,1))
        _modify!(_add, A31*B12 + A32*B22 + A33*B32, C, (3,2))
        _modify!(_add, A31*B13 + A32*B23 + A33*B33, C, (3,3))
    end # inbounds
    return C
end
1118 | |||
# Scalar types accepted by the scalar-first `*` methods that fuse the scalar into the product.
const RealOrComplex = Union{Real, Complex}
1120 | |||
1121 | # Three-argument * | ||
"""
    *(A, B::AbstractMatrix, C)
    A * B * C * D

Chained multiplication of 3 or 4 matrices is done in the most efficient sequence,
based on the sizes of the arrays. That is, the number of scalar multiplications needed
for `(A * B) * C` (with 3 dense matrices) is compared to that for `A * (B * C)`
to choose which of these to execute.

If the last factor is a vector, or the first a transposed vector, then it is efficient
to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
for an ordinary column-major `B::Matrix`. Unlike `dot(x, B, y)`, this
allocates an intermediate array.

If the first or last factor is a number, this will be fused with the matrix
multiplication, using 5-arg [`mul!`](@ref).

See also [`muladd`](@ref), [`dot`](@ref).

!!! compat "Julia 1.7"
    These optimisations require at least Julia 1.7.
"""
function *(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector)
    # Contract with the vector first, so only matrix-vector products are performed.
    Bx = B * x
    return A * Bx
end
1145 | |||
# Row vector first: contract it with `B` before touching `v` ...
function *(tu::AdjOrTransAbsVec, B::AbstractMatrix, v::AbstractVector)
    return (tu*B) * v
end
# ... unless `B` is itself an adjoint/transpose, in which case `B*v` is formed first.
function *(tu::AdjOrTransAbsVec, B::AdjOrTransAbsMat, v::AbstractVector)
    return tu * (B*v)
end

# A trailing or leading scalar is routed through `mat_vec_scalar` / `mat_mat_scalar`.
function *(A::AbstractMatrix, x::AbstractVector, γ::Number)
    return mat_vec_scalar(A, x, γ)
end
function *(A::AbstractMatrix, B::AbstractMatrix, γ::Number)
    return mat_mat_scalar(A, B, γ)
end
function *(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractVector{<:RealOrComplex})
    return mat_vec_scalar(B, C, α)
end
function *(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex})
    return mat_mat_scalar(B, C, α)
end

# Outer products: with a scalar everything is done in one broadcast; with a matrix the
# row vector is contracted with `C` first.
function *(α::Number, u::AbstractVector, tv::AdjOrTransAbsVec)
    return broadcast(*, α, u, tv)
end
function *(u::AbstractVector, tv::AdjOrTransAbsVec, γ::Number)
    return broadcast(*, u, tv, γ)
end
function *(u::AbstractVector, tv::AdjOrTransAbsVec, C::AbstractMatrix)
    return u * (tv*C)
end
1159 | |||
# Three matrices: `_tri_matmul` picks the cheaper association from the sizes.
function *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix)
    return _tri_matmul(A, B, C)
end
# Row vector first: always contract left to right.
function *(tv::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix)
    return (tv*B) * C
end
1162 | |||
# Multiply three matrices in the cheaper of the two possible orders, optionally fusing a
# scalar `δ` into one of the products via `mat_mat_scalar`.
# With `size(A) == (n, m)`, `size(B) == (m, k)` and `size(C) == (k, l)` the number of
# scalar multiplications is
#   (A*B)*C : n*m*k + n*k*l
#   A*(B*C) : m*k*l + n*m*l
# Ties go to `(A*B)*C`.
function _tri_matmul(A,B,C,δ=nothing)
    n, m = size(A)
    k, l = size(C)
    cost_left_first = n*m*k + n*k*l
    cost_right_first = m*k*l + n*m*l
    right_first = cost_right_first < cost_left_first
    if isnothing(δ)
        return right_first ? A * (B*C) : (A*B) * C
    elseif right_first
        return A * mat_mat_scalar(B, C, δ)
    else
        return mat_mat_scalar(A*B, C, δ)
    end
end
1175 | |||
1176 | # Fast path for two arrays * one scalar is opt-in, via mat_vec_scalar and mat_mat_scalar. | ||
1177 | |||
# Fallback: fold the scalar into the vector, then do an ordinary matrix-vector product.
function mat_vec_scalar(A, x, γ)
    return A * (x * γ)
end

# Strided (possibly adjoint/transposed) matrix with a strided vector: use the fused path.
function mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ)
    return _mat_vec_scalar(A, x, γ)
end

# Adjoint/transposed vector times vector: form the product, then scale it.
function mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ)
    return (A * x) * γ
end

# Fused `A * x * γ` through 5-arg `mul!`; the output element type is the promotion of
# both element types and the type of `γ`.
function _mat_vec_scalar(A, x, γ)
    T = promote_type(eltype(A), eltype(x), typeof(γ))
    out = similar(A, T, axes(A,1))
    return mul!(out, A, x, γ, false)
end
1187 | |||
# Fallback: multiply first, scale afterwards.
function mat_mat_scalar(A, B, γ)
    return (A*B) * γ
end

# Strided (possibly adjoint/transposed) operands: use the fused path.
function mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ)
    return _mat_mat_scalar(A, B, γ)
end

# Fused `A * B * γ` through 5-arg `mul!`; the output element type is the promotion of
# both element types and the type of `γ`.
function _mat_mat_scalar(A, B, γ)
    T = promote_type(eltype(A), eltype(B), typeof(γ))
    out = similar(A, T, axes(A,1), axes(B,2))
    return mul!(out, A, B, γ, false)
end
1197 | |||
# Adjoint row vector times matrix times scalar.
# Generic version: work on the adjoint and flip back; the adjoint reverses the order of
# the factors, which preserves the order of multiplication.
function mat_mat_scalar(A::AdjointAbsVec, B, γ)
    scaled = γ' * (A * B)'
    return scaled'
end
# Real/complex strided case: rewrite as a fused matrix-vector-scalar product.
function mat_mat_scalar(A::AdjointAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex)
    col = mat_vec_scalar(B', A', γ')
    return col'
end

# Transposed row vector: the same two strategies, with `transpose` instead of adjoint.
function mat_mat_scalar(A::TransposeAbsVec, B, γ)
    scaled = γ * transpose(A * B)
    return transpose(scaled)
end
function mat_mat_scalar(A::TransposeAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex)
    col = mat_vec_scalar(transpose(B), transpose(A), γ)
    return transpose(col)
end
1205 | |||
1206 | |||
1207 | # Four-argument *, by type | ||
# Leading scalars: merge two scalars, or push the single scalar into a three-argument product.
function *(α::Number, β::Number, C::AbstractMatrix, x::AbstractVector)
    return (α*β) * C * x
end
function *(α::Number, β::Number, C::AbstractMatrix, D::AbstractMatrix)
    return (α*β) * C * D
end
function *(α::Number, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector)
    return α * B * (C*x)
end
function *(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, x::AbstractVector)
    return α * (vt*C*x)
end
function *(α::RealOrComplex, vt::AdjOrTransAbsVec{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex})
    return (α*vt*C) * D # solves an ambiguity
end

# Trailing scalars, handled symmetrically.
function *(A::AbstractMatrix, x::AbstractVector, γ::Number, δ::Number)
    return A * x * (γ*δ)
end
function *(A::AbstractMatrix, B::AbstractMatrix, γ::Number, δ::Number)
    return A * B * (γ*δ)
end
function *(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector, δ::Number)
    return A * (B*x*δ)
end
function *(vt::AdjOrTransAbsVec, B::AbstractMatrix, x::AbstractVector, δ::Number)
    return (vt*B*x) * δ
end
function *(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, δ::Number)
    return (vt*B) * C * δ
end

# No scalars: contract with the vector end(s) first.
function *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector)
    return A * B * (C*x)
end
function *(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix)
    return (vt*B) * C * D
end
function *(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector)
    return vt * B * (C*x)
end
1224 | |||
1225 | # Four-argument *, by size | ||
# Three matrices plus a scalar: the association is chosen by `_tri_matmul`, which also
# fuses the scalar into one of the products.
function *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, δ::Number)
    return _tri_matmul(A, B, C, δ)
end
function *(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex})
    return _tri_matmul(B, C, D, α)
end
# Four matrices: the association is chosen by `_quad_matmul`.
function *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix)
    return _quad_matmul(A, B, C, D)
end
1231 | |||
# Multiply four matrices using the cheapest of the five possible parenthesisations, as
# estimated by `_mul_cost`. When several are tied, the first in the order listed wins.
function _quad_matmul(A,B,C,D)
    cost_pairs       = _mul_cost((A,B),(C,D))    # (A*B) * (C*D)
    cost_left_chain  = _mul_cost(((A,B),C),D)    # ((A*B) * C) * D
    cost_right_chain = _mul_cost(A,(B,(C,D)))    # A * (B * (C*D))
    cost_mid_left    = _mul_cost((A,(B,C)),D)    # (A * (B*C)) * D
    cost_mid_right   = _mul_cost(A,((B,C),D))    # A * ((B*C) * D)
    cheapest = min(cost_pairs, cost_left_chain, cost_right_chain, cost_mid_left, cost_mid_right)

    cost_pairs == cheapest && return (A*B) * (C*D)
    cost_left_chain == cheapest && return ((A*B) * C) * D
    cost_right_chain == cheapest && return A * (B * (C*D))
    cost_mid_left == cheapest && return (A * (B*C)) * D
    return A * ((B*C) * D)
end
# Cost model for chained products. A nested tuple such as `((A, B), C)` stands for the
# parenthesised product `(A*B)*C`.

# A bare matrix costs nothing.
@inline function _mul_cost(A::AbstractMatrix)
    return 0
end
# A pair costs whatever multiplying its two halves costs.
@inline function _mul_cost(pair::Tuple)
    left, right = pair
    return _mul_cost(left, right)
end
# Cost of `left * right`: the cost of forming each side, plus rows(left) * cols(left) *
# cols(right) scalar multiplications for the final product.
@inline function _mul_cost(A, B)
    rows, inner = _mul_sizes(A)
    cols = last(_mul_sizes(B))
    return _mul_cost(A) + _mul_cost(B) + *(rows, inner, cols)
end

# Size of the product represented by a matrix or a nested pair.
@inline function _mul_sizes(A::AbstractMatrix)
    return size(A)
end
@inline function _mul_sizes(pair::Tuple)
    left, right = pair
    return first(_mul_sizes(left)), last(_mul_sizes(right))
end