StatProfilerHTML.jl report
Generated on Thu, 21 Dec 2023 12:59:22
File source code
Line Exclusive Inclusive Code
1 # This file is a part of Julia. License is MIT: https://julialang.org/license
2
3 # matmul.jl: Everything to do with dense matrix multiplication
4
5 # Matrix-matrix multiplication
6
# Strided matrices with element type `T` wrapped in a lazy `Adjoint` or `Transpose`.
AdjOrTransStridedMat{T} = Union{
    Adjoint{<:Any, <:StridedMatrix{T}},
    Transpose{<:Any, <:StridedMatrix{T}}
}
# Strided matrices with element type `T`, either plain or wrapped as above.
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, AdjOrTransStridedMat{T}}
# Same idea for strided vectors or matrices.
StridedMaybeAdjOrTransVecOrMat{T} = Union{
    StridedVecOrMat{T},
    AdjOrTrans{<:Any, <:StridedVecOrMat{T}}
}
10
# Sum of two identical products; handed to `promote_op` in this file to pick the
# element type of a matrix product.
function matprod(x, y)
    return x * y + x * y
end
12
13 # dot products
14
# Dot products of vectors with matching BLAS element types are forwarded to BLAS:
# `BLAS.dot` for real element types and `BLAS.dotc` for complex ones.
function dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasReal}
    return BLAS.dot(x, y)
end
function dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasComplex}
    return BLAS.dotc(x, y)
end
17
"""
    dot(x::Vector{T}, rx::AbstractRange, y::Vector{T}, ry::AbstractRange) where {T<:BlasReal}

Compute the dot product of the elements of `x` selected by `rx` with the elements of
`y` selected by `ry`, calling `BLAS.dot` directly on the vectors' memory.

Throws a `DimensionMismatch` if `rx` and `ry` differ in length and a `BoundsError` if a
non-empty range reaches outside its vector. Empty ranges give `zero(T)`.
"""
function dot(x::Vector{T}, rx::AbstractRange{TI}, y::Vector{T}, ry::AbstractRange{TI}) where {T<:BlasReal,TI<:Integer}
    if length(rx) != length(ry)
        throw(DimensionMismatch(lazy"length of rx, $(length(rx)), does not equal length of ry, $(length(ry))"))
    end
    # `minimum`/`maximum` throw on empty ranges; both ranges are empty here because
    # their lengths agree, and an empty selection sums to zero.
    isempty(rx) && return zero(T)
    xlo, xhi = extrema(rx)
    if xlo < 1 || xhi > length(x)
        throw(BoundsError(x, rx))
    end
    ylo, yhi = extrema(ry)
    if ylo < 1 || yhi > length(y)
        throw(BoundsError(y, ry))
    end
    # BLAS expects the pointer to the lowest-address element, also for negative
    # increments (it then walks the data backwards from the far end), so anchor at the
    # smallest selected index instead of `first(r)`.
    return GC.@preserve x y BLAS.dot(length(rx), pointer(x, xlo), step(rx), pointer(y, ylo), step(ry))
end
30
"""
    dot(x::Vector{T}, rx::AbstractRange, y::Vector{T}, ry::AbstractRange) where {T<:BlasComplex}

Compute the conjugated dot product (`BLAS.dotc`) of the elements of `x` selected by `rx`
with the elements of `y` selected by `ry`, calling BLAS directly on the vectors' memory.

Throws a `DimensionMismatch` if `rx` and `ry` differ in length and a `BoundsError` if a
non-empty range reaches outside its vector. Empty ranges give `zero(T)`.
"""
function dot(x::Vector{T}, rx::AbstractRange{TI}, y::Vector{T}, ry::AbstractRange{TI}) where {T<:BlasComplex,TI<:Integer}
    if length(rx) != length(ry)
        throw(DimensionMismatch(lazy"length of rx, $(length(rx)), does not equal length of ry, $(length(ry))"))
    end
    # `minimum`/`maximum` throw on empty ranges; both ranges are empty here because
    # their lengths agree, and an empty selection sums to zero.
    isempty(rx) && return zero(T)
    xlo, xhi = extrema(rx)
    if xlo < 1 || xhi > length(x)
        throw(BoundsError(x, rx))
    end
    ylo, yhi = extrema(ry)
    if ylo < 1 || yhi > length(y)
        throw(BoundsError(y, ry))
    end
    # BLAS expects the pointer to the lowest-address element, also for negative
    # increments (it then walks the data backwards from the far end), so anchor at the
    # smallest selected index instead of `first(r)`.
    return GC.@preserve x y BLAS.dotc(length(rx), pointer(x, xlo), step(rx), pointer(y, ylo), step(ry))
end
43
# transpose(x) * y for complex BLAS vectors: the unconjugated dot product `BLAS.dotu`
# applied to the parent of the transposed vector.
*(transx::Transpose{<:Any,<:StridedVector{T}}, y::StridedVector{T}) where {T<:BlasComplex} =
    BLAS.dotu(transx.parent, y)
48
49 # Matrix-vector multiplication
# BLAS-float (possibly adjoint/transposed) strided matrix times a real strided vector.
function (*)(A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{S}) where {T<:BlasFloat,S<:Real}
    elty = promote_op(matprod, T, S)
    # Convert `x` up front only when the promoted element type is concrete.
    rhs = if isconcretetype(elty)
        convert(AbstractVector{elty}, x)
    else
        x
    end
    dest = similar(x, elty, size(A, 1))
    return mul!(dest, A, rhs)
end
# Generic matrix-vector product: allocate the result along the first axis of `A` with
# the promoted element type and fill it with `mul!`.
function (*)(A::AbstractMatrix{T}, x::AbstractVector{S}) where {T,S}
    elty = promote_op(matprod, T, S)
    dest = similar(x, elty, axes(A, 1))
    return mul!(dest, A, x)
end
59
# Vector times matrix: treat the vector as a one-column matrix. These throw a
# DimensionMismatch unless B has 1 row (or 1 column in the transposed/adjoint cases).
function (*)(a::AbstractVector, tB::TransposeAbsMat)
    column = reshape(a, length(a), 1)
    return column * tB
end
function (*)(a::AbstractVector, adjB::AdjointAbsMat)
    column = reshape(a, length(a), 1)
    return column * adjB
end
function (*)(a::AbstractVector, B::AbstractMatrix)
    column = reshape(a, length(a), 1)
    return column * B
end
64
# Five-argument matrix-vector `mul!`: forwards `wrapper_char(A)` and `_unwrap(A)` to
# `generic_matvecmul!`, with the scalars packed into a `MulAddMul`.
@inline function mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
                      alpha::Number, beta::Number)
    flag = wrapper_char(A)
    unwrapped = _unwrap(A)
    return generic_matvecmul!(y, flag, unwrapped, x, MulAddMul(alpha, beta))
end
# BLAS cases: each method below forwards to the matching `gemv!` method with the
# `alpha`/`beta` stored in `_add`.

# Matrix and vectors share one BLAS element type.
@inline function generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
    return gemv!(y, tA, A, x, _add.alpha, _add.beta)
end

# Real (possibly transposed) matrix times complex vector.
@inline function generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasReal}
    return gemv!(y, tA, A, x, _add.alpha, _add.beta)
end

# Complex matrix times real vector. The reinterpret-based BLAS path only applies to an
# untransposed `A` (tA == 'N'); that check lives in `gemv!`.
@inline function generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasReal}
    return gemv!(y, tA, A, x, _add.alpha, _add.beta)
end
85
# Vector-Matrix multiplication: computed as the adjoint (resp. transpose) of the
# matrix-vector product of the adjointed (resp. transposed) operands.
function (*)(x::AdjointAbsVec, A::AbstractMatrix)
    return adjoint(adjoint(A) * adjoint(x))
end
function (*)(x::TransposeAbsVec, A::AbstractMatrix)
    product = transpose(A) * transpose(x)
    return transpose(product)
end
89
90 # Matrix-matrix multiplication
"""
    *(A::AbstractMatrix, B::AbstractMatrix)

Return the matrix product of `A` and `B`.

# Examples
```jldoctest
julia> [1 1; 0 1] * [1 0; 1 1]
2×2 Matrix{Int64}:
 2  1
 1  1
```
"""
function (*)(A::AbstractMatrix, B::AbstractMatrix)
    elty = promote_op(matprod, eltype(A), eltype(B))
    dest = similar(B, elty, (size(A, 1), size(B, 2)))
    return mul!(dest, A, B)
end
108 # optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64})
109 # but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal})
110 # which is better handled by reinterpreting rather than promotion
111
31 (11 %) samples spent in *
31 (100 %) (incl.) when called from init_cacheval line 1013
# Real BLAS element types (possibly different, e.g. Float32 and Float64): promote both
# operands to a common element type so that `mul!` can dispatch to BLAS.
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
    elty = promote_type(eltype(A), eltype(B))
    dest = similar(B, elty, (size(A, 1), size(B, 2)))
    # Convert the unwrapped arrays, then re-apply `wrapperop` of each operand.
    lhs = wrapperop(A)(convert(AbstractArray{elty}, _unwrap(A)))
    rhs = wrapperop(B)(convert(AbstractArray{elty}, _unwrap(B)))
    return mul!(dest, lhs, rhs)
end
# Complex BLAS element types (possibly different): promote both operands to a common
# element type before calling `mul!`.
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
    elty = promote_type(eltype(A), eltype(B))
    dest = similar(B, elty, (size(A, 1), size(B, 2)))
    # Convert the unwrapped arrays, then re-apply `wrapperop` of each operand.
    lhs = wrapperop(A)(convert(AbstractArray{elty}, _unwrap(A)))
    rhs = wrapperop(B)(convert(AbstractArray{elty}, _unwrap(B)))
    return mul!(dest, lhs, rhs)
end
123
# Complex matrix times real matrix: it is generally faster to reinterpret the first
# matrix as a real matrix and carry out a real matrix-matrix multiply, so `B` is kept
# real (converted to `real(elty)`) instead of being promoted to complex.
function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
    elty = promote_type(eltype(A), eltype(B))
    dest = similar(B, elty, (size(A, 1), size(B, 2)))
    lhs = convert(AbstractArray{elty}, A)
    rhs = wrapperop(B)(convert(AbstractArray{real(elty)}, _unwrap(B)))
    return mul!(dest, lhs, rhs)
end
# Adjoint/transposed complex matrix times real matrix: materialize `A` first
# (removing the AdjOrTrans wrapper) so the reinterpret trick can be used downstream.
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
    elty = promote_type(eltype(A), eltype(B))
    dest = similar(B, elty, (size(A, 1), size(B, 2)))
    lhs = copymutable_oftype(A, elty)
    rhs = wrapperop(B)(convert(AbstractArray{real(elty)}, _unwrap(B)))
    return mul!(dest, lhs, rhs)
end
# Real matrix times complex matrix: multiply the real and imaginary parts of `B`
# separately and recombine. (This case doesn't seem to benefit from the translation
# A*B = (B' * A')'.)
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
    part = real(B)
    realprod = A * part
    # Reuse the same buffer for the imaginary part of `B`.
    part .= imag.(B)
    imagprod = A * part
    return Complex.(realprod, imagprod)
end
# Real adjoint/transposed matrix times complex matrix: evaluate the transposed product
# with the complex factor on the left, then transpose back into a fresh array.
function (*)(A::AdjOrTransStridedMat{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
    flipped = transpose(B) * parent(A)
    return copy(transpose(flipped))
end
# Real matrix times adjoint/transposed complex matrix: multiply the parent of `B` by
# `transpose(A)`, then re-apply the wrapper operation of `B` and materialize.
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::AdjOrTransStridedMat{<:BlasComplex})
    flipped = parent(B) * transpose(A)
    return copy(wrapperop(B)(flipped))
end
148
"""
    muladd(A, y, z)

Combined multiply-add, `A*y .+ z`, for matrix-matrix or matrix-vector multiplication.
The result always has the size of `A*y`; `z` may be smaller than that, or a scalar.

!!! compat "Julia 1.6"
     These methods require Julia 1.6 or later.

# Examples
```jldoctest
julia> A=[1.0 2.0; 3.0 4.0]; B=[1.0 1.0; 1.0 1.0]; z=[0, 100];

julia> muladd(A, B, z)
2×2 Matrix{Float64}:
   3.0    3.0
 107.0  107.0
```
"""
function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, AbstractArray})
    product = A * y
    # Within the dimensions of the product, `z` must not be larger than the product
    # (same error as `Ay .+= z` would give, to match the StridedMatrix method).
    for d in 1:ndims(product)
        if size(z, d) > size(product, d)
            throw(DimensionMismatch("array could not be broadcast to match destination"))
        end
    end
    # Any further dimensions of `z` must be singletons (similar error to what
    # `Ay + z` would give, to match the (Any,Any,Any) method).
    for d in ndims(product)+1:ndims(z)
        if size(z, d) > 1
            throw(DimensionMismatch(string("dimensions must match: z has dims ",
                axes(z), ", must have singleton at dim ", d)))
        end
    end
    return product .+ z
end
181
# Outer-product multiply-add: `u * v .+ z` for a column vector `u` and an
# adjoint/transposed vector `v`, fused into a single broadcast.
function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, AbstractArray})
    fits = size(z, 1) <= length(u) && size(z, 2) <= length(v)
    if !fits
        # Same error as (u*v) .+= z:
        throw(DimensionMismatch("array could not be broadcast to match destination"))
    end
    for d in 3:ndims(z)
        # Similar error to (u*v) + z:
        if size(z, d) > 1
            throw(DimensionMismatch(string("dimensions must match: z has dims ",
                axes(z), ", must have singleton at dim ", d)))
        end
    end
    return u .* v .+ z
end
194
# Row vector times matrix plus `z`: evaluate the adjointed (resp. transposed) problem
# as a matrix-vector `muladd`, then adjoint (resp. transpose) the result back.
function Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat})
    return adjoint(muladd(adjoint(A), adjoint(x), adjoint(z)))
end
function Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat})
    flipped = muladd(transpose(A), transpose(x), transpose(z))
    return transpose(flipped)
end
199
# Strided matrix-vector multiply-add: seed the result with `z`, then accumulate `A*y`
# into it with the five-argument `mul!` (α = β = true).
function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, y::AbstractVector{<:Number}, z::Union{Number, AbstractVector})
    elty = promote_type(eltype(A), eltype(y), eltype(z))
    acc = similar(A, elty, axes(A, 1))
    acc .= z
    return mul!(acc, A, y, true, true)
end
206
# Strided matrix-matrix multiply-add: seed the result with `z`, then accumulate `A*B`
# into it with the five-argument `mul!` (α = β = true).
function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, B::StridedMaybeAdjOrTransMat{<:Number}, z::Union{Number, AbstractVecOrMat})
    elty = promote_type(eltype(A), eltype(B), eltype(z))
    acc = similar(A, elty, axes(A, 1), axes(B, 2))
    acc .= z
    return mul!(acc, A, B, true, true)
end
213
"""
    mul!(Y, A, B) -> Y

Compute the matrix-matrix or matrix-vector product ``AB`` and store it in `Y`,
overwriting the previous contents of `Y`. `Y` must not be aliased with either `A` or
`B`.

# Examples
```jldoctest
julia> A=[1.0 2.0; 3.0 4.0]; B=[1.0 1.0; 1.0 1.0]; Y = similar(B); mul!(Y, A, B);

julia> Y
2×2 Matrix{Float64}:
 3.0  3.0
 7.0  7.0
```

# Implementation
For custom matrix and vector types, it is recommended to implement
5-argument `mul!` rather than implementing 3-argument `mul!` directly
if possible.
"""
@inline mul!(C, A, B) = mul!(C, A, B, true, false)
239
"""
    mul!(C, A, B, α, β) -> C

Combined in-place matrix-matrix or matrix-vector multiply-add ``A B α + C β``.
The result overwrites `C`, which must not be aliased with either `A` or `B`.

!!! compat "Julia 1.3"
    Five-argument `mul!` requires at least Julia 1.3.

# Examples
```jldoctest
julia> A=[1.0 2.0; 3.0 4.0]; B=[1.0 1.0; 1.0 1.0]; C=[1.0 2.0; 3.0 4.0];

julia> mul!(C, A, B, 100.0, 10.0) === C
true

julia> C
2×2 Matrix{Float64}:
 310.0  320.0
 730.0  740.0
```
"""
@inline function mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number)
    flagA = wrapper_char(A)
    flagB = wrapper_char(B)
    unwrappedA = _unwrap(A)
    unwrappedB = _unwrap(B)
    return generic_matmatmul!(C, flagA, flagB, unwrappedA, unwrappedB, MulAddMul(α, β))
end
272
273 """
274 rmul!(A, B)
275
276 Calculate the matrix-matrix product ``AB``, overwriting `A`, and return the result.
277 Here, `B` must be of a special matrix type, such as [`Diagonal`](@ref),
278 [`UpperTriangular`](@ref) or [`LowerTriangular`](@ref), or of some orthogonal type,
279 see [`QR`](@ref).
280
281 # Examples
282 ```jldoctest
283 julia> A = [0 1; 1 0];
284
285 julia> B = UpperTriangular([1 2; 0 3]);
286
287 julia> rmul!(A, B);
288
289 julia> A
290 2×2 Matrix{Int64}:
291 0 3
292 1 2
293
294 julia> A = [1.0 2.0; 3.0 4.0];
295
296 julia> F = qr([0 1; -1 0]);
297
298 julia> rmul!(A, F.Q)
299 2×2 Matrix{Float64}:
300 2.0 1.0
301 4.0 3.0
302 ```
303 """
304 rmul!(A, B)
305
306 """
307 lmul!(A, B)
308
309 Calculate the matrix-matrix product ``AB``, overwriting `B`, and return the result.
310 Here, `A` must be of a special matrix type, such as [`Diagonal`](@ref),
311 [`UpperTriangular`](@ref) or [`LowerTriangular`](@ref), or of some orthogonal type,
312 see [`QR`](@ref).
313
314 # Examples
315 ```jldoctest
316 julia> B = [0 1; 1 0];
317
318 julia> A = UpperTriangular([1 2; 0 3]);
319
320 julia> lmul!(A, B);
321
322 julia> B
323 2×2 Matrix{Int64}:
324 2 1
325 3 0
326
327 julia> B = [1.0 2.0; 3.0 4.0];
328
329 julia> F = qr([0 1; -1 0]);
330
331 julia> lmul!(F.Q, B)
332 2×2 Matrix{Float64}:
333 3.0 4.0
334 1.0 2.0
335 ```
336 """
337 lmul!(A, B)
338
# THE one big BLAS dispatch: choose between the syrk/herk, gemm, symm and hemm code
# paths from the wrapper flags `tA`/`tB`, falling back to `_generic_matmatmul!`.
@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
    plain = ('N', 'T', 'C')
    if tA in plain && tB in plain
        # The same array on both sides with exactly one side transposed/adjointed
        # goes to the syrk/herk wrappers.
        if A === B
            if tA == 'T' && tB == 'N'
                return syrk_wrapper!(C, 'T', A, _add)
            elseif tA == 'N' && tB == 'T'
                return syrk_wrapper!(C, 'N', A, _add)
            elseif tA == 'C' && tB == 'N'
                return herk_wrapper!(C, 'C', A, _add)
            elseif tA == 'N' && tB == 'C'
                return herk_wrapper!(C, 'N', A, _add)
            end
        end
        return gemm_wrapper!(C, tA, tB, A, B, _add)
    end
    alpha, beta = promote(_add.alpha, _add.beta, zero(T))
    if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
        # One side flagged 'S'/'s' or 'H'/'h' and the other side plain ('N'):
        # uppercase flags select 'U' and lowercase flags select 'L' for BLAS.
        if tB == 'N' && tA in ('S', 's')
            return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
        elseif tA == 'N' && tB in ('S', 's')
            return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
        elseif tB == 'N' && tA in ('H', 'h')
            return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
        elseif tA == 'N' && tB in ('H', 'h')
            return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
        end
    end
    # No BLAS routine applies: re-wrap the operands and use the generic kernel.
    return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
end
369
# Complex matrix times (transposed) real matrix. Plain 'N'/'T'/'C' flags go to
# `gemm_wrapper!`, which reinterprets the first matrix as real for efficiency; any
# other flag falls back to the generic kernel on re-wrapped operands.
@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
                                    _add::MulAddMul=MulAddMul()) where {T<:BlasReal}
    if all(in(('N', 'T', 'C')), (tA, tB))
        return gemm_wrapper!(C, tA, tB, A, B, _add)
    end
    return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
end
379
380
381 # Supporting functions for matrix multiplication
382
# Mirror one triangle of the square matrix `A` onto the other, in place.
# With `uplo == 'U'` the upper triangle is the source, with `'L'` the lower one.
# Elements are passed through `adjoint` when `conjugate` is true and through
# `transpose` otherwise; the diagonal is only processed when `diag` is true.
@inline function copytri!(A::AbstractMatrix, uplo::AbstractChar, conjugate::Bool=false, diag::Bool=false)
    n = checksquare(A)
    if uplo != 'U' && uplo != 'L'
        throw(ArgumentError(lazy"uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
    end
    start = diag ? 0 : 1
    if uplo == 'U'
        for row in 1:n
            for col in (row + start):n
                A[col, row] = conjugate ? adjoint(A[row, col]) : transpose(A[row, col])
            end
        end
    else
        for row in 1:n
            for col in (row + start):n
                A[row, col] = conjugate ? adjoint(A[col, row]) : transpose(A[col, row])
            end
        end
    end
    return A
end
400
# Matrix-vector multiply-add `y = α*op(A)*x + β*y` for equal BLAS element types.
# `tA` is the wrapper flag of `A`: 'N', 'T', 'C', or 'S'/'s'/'H'/'h' (forwarded to
# `BLAS.symv!`/`BLAS.hemv!`, with uppercase selecting 'U' and lowercase 'L').
# Throws a `DimensionMismatch` when the sizes of `A`, `x` and `y` do not agree.
# BLAS is used when the scalars fit the element type and the strides are
# BLAS-compatible; otherwise the generic kernel is used.
function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{T},
               α::Number=true, β::Number=false) where {T<:BlasFloat}
    mA, nA = lapack_size(tA, A)
    nA != length(x) &&
        throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
    mA != length(y) &&
        throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
    mA == 0 && return y
    nA == 0 && return _rmul_or_fill!(y, β)
    alpha, beta = promote(α, β, zero(T))
    # The flag dispatch is the *body* of this `if`. It must not be chained onto the
    # condition with `&&`, otherwise an unmatched flag would put `nothing` into a
    # boolean context.
    if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
        stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
        !iszero(stride(x, 1)) # We only check input's stride here.
        if tA in ('N', 'T', 'C')
            return BLAS.gemv!(tA, alpha, A, x, beta, y)
        elseif tA in ('S', 's')
            return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y)
        elseif tA in ('H', 'h')
            return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y)
        end
    end
    if tA in ('S', 's', 'H', 'h')
        # re-wrap again and use plain ('N') matvec mul algorithm,
        # because _generic_matvecmul! can't handle the HermOrSym cases specifically
        return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
    else
        return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
    end
end
430
# Complex matrix times real vector, `y = α*op(A)*x + β*y`.
# When `A` is untransposed and everything is BLAS-compatible, `A` and `y` are
# reinterpreted as real arrays and a single real `BLAS.gemv!` is issued; otherwise
# the generic kernel is used.
function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
               α::Number = true, β::Number = false) where {T<:BlasReal}
    mA, nA = lapack_size(tA, A)
    if nA != length(x)
        throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
    end
    if mA != length(y)
        throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
    end
    mA == 0 && return y
    nA == 0 && return _rmul_or_fill!(y, β)
    alpha, beta = promote(α, β, zero(T))
    # The reinterpret-based optimization is valid only for contiguous `y` and tA == 'N'.
    use_blas = alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
        stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
        stride(y, 1) == 1 && tA == 'N' &&
        !iszero(stride(x, 1))
    if use_blas
        BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
        return y
    end
    if tA in ('S', 's', 'H', 'h')
        # Symmetric/Hermitian flags: re-wrap `A` and treat it as a plain matrix.
        return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
    end
    return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
452
# Real matrix times complex vector, `y = α*op(A)*x + β*y`.
# On the BLAS path, `x` and `y` are reinterpreted as 2×n real arrays and the matrix is
# applied to their two rows (real and imaginary parts) with two `BLAS.gemv!` calls.
function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
               α::Number = true, β::Number = false) where {T<:BlasFloat}
    mA, nA = lapack_size(tA, A)
    if nA != length(x)
        throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
    end
    if mA != length(y)
        throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
    end
    mA == 0 && return y
    nA == 0 && return _rmul_or_fill!(y, β)
    alpha, beta = promote(α, β, zero(T))
    use_blas = alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
        stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
        !iszero(stride(x, 1)) && tA in ('N', 'T', 'C')
    if use_blas
        xparts = reinterpret(reshape, T, x) # Use reshape here.
        yparts = reinterpret(reshape, T, y)
        BLAS.gemv!(tA, alpha, A, view(xparts, 1, :), beta, view(yparts, 1, :))
        BLAS.gemv!(tA, alpha, A, view(xparts, 2, :), beta, view(yparts, 2, :))
        return y
    end
    if tA in ('S', 's', 'H', 'h')
        # re-wrap again and use plain ('N') matvec mul algorithm,
        # because _generic_matvecmul! can't handle the HermOrSym cases specifically
        return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
    end
    return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
479
480
31 (11 %) samples spent in syrk_wrapper!
31 (100 %) (incl.) when called from generic_matmatmul! line 344
# Product of a matrix with its own transpose, written into the square matrix `C`:
# tA == 'T' pairs `A` (flag 'T') with itself (flag 'N'), anything else pairs `A`
# (flag `tA`) with itself (flag 'T'). Uses `BLAS.syrk!` followed by `copytri!` when
# possible, otherwise `gemm_wrapper!`.
function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
                       _add = MulAddMul()) where {T<:BlasFloat}
    nC = checksquare(C)
    # mA × nA is the size of the left factor; tAt is the flag of the right factor.
    if tA == 'T'
        mA, nA, tAt = size(A, 2), size(A, 1), 'N'
    else
        mA, nA, tAt = size(A, 1), size(A, 2), 'T'
    end
    nC == mA ||
        throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)"))
    if mA == 0 || nA == 0 || iszero(_add.alpha)
        return _rmul_or_fill!(C, _add.beta)
    end
    # Small fixed-size products have dedicated kernels.
    mA == 2 && nA == 2 && return matmul2x2!(C, tA, tAt, A, A, _add)
    mA == 3 && nA == 3 && return matmul3x3!(C, tA, tAt, A, A, _add)

    # BLAS.syrk! only updates symmetric C
    # alternatively, make non-zero β a show-stopper for BLAS.syrk!
    if iszero(_add.beta) || issymmetric(C)
        alpha, beta = promote(_add.alpha, _add.beta, zero(T))
        blas_ok = alpha isa Union{Bool,T} &&
            beta isa Union{Bool,T} &&
            stride(A, 1) == stride(C, 1) == 1 &&
            stride(A, 2) >= size(A, 1) &&
            stride(C, 2) >= size(C, 1)
        if blas_ok
            # syrk! is asked for the 'U' triangle; copytri! mirrors it onto the other.
            return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
        end
    end
    return gemm_wrapper!(C, tA, tAt, A, A, _add)
end
517
# Product of a matrix with its own adjoint, written into the square matrix `C`:
# tA == 'C' pairs `A` (flag 'C') with itself (flag 'N'), anything else pairs `A`
# (flag `tA`) with itself (flag 'C'). Uses `BLAS.herk!` followed by a conjugating
# `copytri!` when possible, otherwise `gemm_wrapper!`.
# Throws a `DimensionMismatch` if `C` does not have the size of the product.
function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
                       _add = MulAddMul()) where {T<:BlasReal}
    nC = checksquare(C)
    # mA × nA is the size of the left factor; tAt is the flag of the right factor.
    if tA == 'C'
        (nA, mA) = size(A,1), size(A,2)
        tAt = 'N'
    else
        (mA, nA) = size(A,1), size(A,2)
        tAt = 'C'
    end
    if nC != mA
        throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)"))
    end
    if mA == 0 || nA == 0 || iszero(_add.alpha)
        return _rmul_or_fill!(C, _add.beta)
    end
    if mA == 2 && nA == 2
        return matmul2x2!(C, tA, tAt, A, A, _add)
    end
    if mA == 3 && nA == 3
        return matmul3x3!(C, tA, tAt, A, A, _add)
    end

    # BLAS.herk! only updates one triangle and the conjugating copytri! mirrors it,
    # which reproduces α*op(A)*op(A)' + β*C only when the β*C term is itself Hermitian.
    # With β == 0 the previous contents of C do not matter; otherwise C must be
    # Hermitian (being merely symmetric is not enough for complex C).
    if iszero(_add.beta) || ishermitian(C)
        alpha, beta = promote(_add.alpha, _add.beta, zero(T))
        if (alpha isa Union{Bool,T} &&
            beta isa Union{Bool,T} &&
            stride(A, 1) == stride(C, 1) == 1 &&
            stride(A, 2) >= size(A, 1) &&
            stride(C, 2) >= size(C, 1))
            return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
        end
    end
    return gemm_wrapper!(C, tA, tAt, A, A, _add)
end
556
# Allocating matrix product for equal BLAS element types: create the result and fill
# it with `gemm_wrapper!` for plain 'N'/'T'/'C' flags, or with the generic kernel on
# re-wrapped operands for any other flag.
function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
                      A::StridedVecOrMat{T},
                      B::StridedVecOrMat{T}) where {T<:BlasFloat}
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)
    C = similar(B, T, mA, nB)
    if all(in(('N', 'T', 'C')), (tA, tB))
        gemm_wrapper!(C, tA, tB, A, B)
    else
        # There is no `_add` argument in this function; `C` is freshly allocated, so
        # use the default `MulAddMul()` to overwrite it with the plain product.
        _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), MulAddMul())
    end
end
569
# In-place product for equal BLAS element types, combined with `C` as described by
# `_add`. Validates the sizes, rejects an output aliased with an input, handles
# empty and 2×2/3×3 products separately, and calls `BLAS.gemm!` when the scalars and
# strides allow it; otherwise falls back to `_generic_matmatmul!`.
function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
                       A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
                       _add = MulAddMul()) where {T<:BlasFloat}
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)

    nA == mB ||
        throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))

    if C === A || B === C
        throw(ArgumentError("output matrix must not be aliased with input matrix"))
    end

    # Nothing to multiply: only the scaling of C by beta remains.
    if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
        size(C) == (mA, nB) ||
            throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
        return _rmul_or_fill!(C, _add.beta)
    end

    # Small fixed-size products have dedicated kernels.
    mA == 2 && nA == 2 && nB == 2 && return matmul2x2!(C, tA, tB, A, B, _add)
    mA == 3 && nA == 3 && nB == 3 && return matmul3x3!(C, tA, tB, A, B, _add)

    alpha, beta = promote(_add.alpha, _add.beta, zero(T))
    blas_ok = alpha isa Union{Bool,T} &&
        beta isa Union{Bool,T} &&
        stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
        stride(A, 2) >= size(A, 1) &&
        stride(B, 2) >= size(B, 1) &&
        stride(C, 2) >= size(C, 1)
    if blas_ok
        return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
    end
    return _generic_matmatmul!(C, tA, tB, A, B, _add)
end
609
# In-place product of a complex matrix with a real matrix, combined with `C` as
# described by `_add`. When `A` is untransposed and everything is BLAS-compatible,
# `A` and `C` are reinterpreted as real arrays and a real `BLAS.gemm!` is issued;
# otherwise falls back to `_generic_matmatmul!`.
function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
                       A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
                       _add = MulAddMul()) where {T<:BlasReal}
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)

    nA == mB ||
        throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))

    if C === A || B === C
        throw(ArgumentError("output matrix must not be aliased with input matrix"))
    end

    # Nothing to multiply: only the scaling of C by beta remains.
    if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
        size(C) == (mA, nB) ||
            throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
        return _rmul_or_fill!(C, _add.beta)
    end

    # Small fixed-size products have dedicated kernels.
    mA == 2 && nA == 2 && nB == 2 && return matmul2x2!(C, tA, tB, A, B, _add)
    mA == 3 && nA == 3 && nB == 3 && return matmul3x3!(C, tA, tB, A, B, _add)

    alpha, beta = promote(_add.alpha, _add.beta, zero(T))

    # Make sure the reinterpret-based optimization is BLAS-compatible.
    blas_ok = alpha isa Union{Bool,T} &&
        beta isa Union{Bool,T} &&
        stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
        stride(A, 2) >= size(A, 1) &&
        stride(B, 2) >= size(B, 1) &&
        stride(C, 2) >= size(C, 1) && tA == 'N'
    if blas_ok
        BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
        return C
    end
    return _generic_matmatmul!(C, tA, tB, A, B, _add)
end
652
653 # blas.jl defines matmul for floats; other integer and mixed precision
654 # cases are handled here
655
# Size of `M` as seen through the flag `t`: `(size(M, 1), size(M, 2))` for 'N', and the
# two dimensions swapped for any other flag.
function lapack_size(t::AbstractChar, M::AbstractVecOrMat)
    rowdim, coldim = t == 'N' ? (1, 2) : (2, 1)
    return (size(M, rowdim), size(M, coldim))
end
657
# Copy the block `op(M)[ir_src, jr_src]` into `B[ir_dest, jr_dest]`, where `op` is the
# identity for tM == 'N', a transposing copy otherwise, with an additional in-place
# conjugation of the destination block for tM == 'C'. Returns `B`.
function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
    if tM == 'N'
        copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src)
        return B
    end
    # Source ranges are swapped because they refer to the transposed view of `M`.
    LinearAlgebra.copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src)
    if tM == 'C'
        conj!(@view B[ir_dest, jr_dest])
    end
    return B
end
667
# Copy the transpose of the block `op(M)[ir_src, jr_src]` into `B[ir_dest, jr_dest]`.
# For tM == 'N' this is a transposing copy of `M`; for any other flag the two
# transpositions cancel and a plain copy is made, with an additional in-place
# conjugation of the destination block for tM == 'C'. Returns `B`.
function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
    if tM == 'N'
        LinearAlgebra.copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src)
        return B
    end
    # Source ranges are swapped because they refer to the transposed view of `M`.
    copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src)
    if tM == 'C'
        conj!(@view B[ir_dest, jr_dest])
    end
    return B
end
677
678 # TODO: It will be faster for large matrices to convert to float,
679 # call BLAS, and convert back to required type.
680
681 # NOTE: the generic version is also called as fallback for
682 # strides != 1 cases
683
# Generic matrix-vector entry point. Symmetric/Hermitian flags are handled by
# re-wrapping `A` and treating it as a plain ('N') matrix; other flags pass through.
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                                    _add::MulAddMul = MulAddMul())
    if tA in ('S', 's', 'H', 'h')
        return _generic_matvecmul!(C, 'N', wrap(A, tA), B, _add)
    end
    return _generic_matvecmul!(C, tA, A, B, _add)
end
689
# Generic (non-BLAS) matrix-vector multiply-add writing into `C`, for flags 'N', 'T'
# and 'C'. `A` is addressed with linear indices, using `size(A, 1)` as the distance
# between consecutive columns. Throws a `DimensionMismatch` when the sizes of `A`, `B`
# and `C` do not agree.
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                             _add::MulAddMul = MulAddMul())
    require_one_based_indexing(C, A, B)
    @assert tA in ('N', 'T', 'C')
    mB = length(B)
    mA, nA = lapack_size(tA, A)
    if mB != nA
        throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
    end
    if mA != length(C)
        throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
    end

    colstride = size(A, 1)

    @inbounds begin
    if tA == 'T'  # fastest case
        if nA == 0
            for row in 1:mA
                _modify!(_add, false, C, row)
            end
        else
            # Each output entry is the dot product of one stored column of A with B.
            for row in 1:mA
                base = (row - 1) * colstride
                acc = zero(A[base + 1]*B[1] + A[base + 1]*B[1])
                for idx in 1:nA
                    acc = acc + transpose(A[base + idx]) * B[idx]
                end
                _modify!(_add, acc, C, row)
            end
        end
    elseif tA == 'C'
        if nA == 0
            for row in 1:mA
                _modify!(_add, false, C, row)
            end
        else
            for row in 1:mA
                base = (row - 1) * colstride
                acc = zero(A[base + 1]*B[1] + A[base + 1]*B[1])
                for idx in 1:nA
                    acc = acc + adjoint(A[base + idx]) * B[idx]
                end
                _modify!(_add, acc, C, row)
            end
        end
    else # tA == 'N'
        # First scale (or reset) C, then accumulate one column of A at a time.
        for row in 1:mA
            if !iszero(_add.beta)
                C[row] = C[row] * _add.beta
            elseif mB == 0
                C[row] = false
            else
                C[row] = zero(A[row]*B[1] + A[row]*B[1])
            end
        end
        for col in 1:mB
            base = (col - 1) * colstride
            scaled = _add(B[col])
            for row in 1:mA
                C[row] = C[row] + A[base + row] * scaled
            end
        end
    end
    end # @inbounds
    return C
end
757
# Allocating generic matrix product: size the result from the flagged operand sizes,
# pick its element type with `promote_op`, and fill it with `generic_matmatmul!`.
function generic_matmatmul(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S}) where {T,S}
    nrows, _ = lapack_size(tA, A)
    _, ncols = lapack_size(tB, B)
    dest = similar(B, promote_op(matprod, T, S), nrows, ncols)
    return generic_matmatmul!(dest, tA, tB, A, B)
end
764
# Byte budget per tile buffer in `_generic_matmatmul!`, which derives the edge length
# of its square tiles as floor(sqrt(tilebufsize / largest element size in bytes)).
const tilebufsize = 10800 # Approximately 32k/3
766
# Generic matrix-matrix multiply-add entry point. Short-circuits a zero `alpha` and
# the all-2×2 / all-3×3 cases, re-wraps symmetric/Hermitian-flagged operands as plain
# ('N') matrices, and forwards everything else to `_generic_matmatmul!`.
function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)
    mC, nC = size(C)

    iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
    if mA == nA == mB == nB == mC == nC == 2
        return matmul2x2!(C, tA, tB, A, B, _add)
    elseif mA == nA == mB == nB == mC == nC == 3
        return matmul3x3!(C, tA, tB, A, B, _add)
    end
    special = ('H', 'h', 'S', 's')
    lhs, lflag = tA in special ? (wrap(A, tA), 'N') : (A, tA)
    rhs, rflag = tB in special ? (wrap(B, tB), 'N') : (B, tB)
    return _generic_matmatmul!(C, lflag, rflag, lhs, rhs, _add)
end
785
786 function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
787 _add::MulAddMul) where {T,S,R}
788 @assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')
789 require_one_based_indexing(C, A, B)
790
791 mA, nA = lapack_size(tA, A)
792 mB, nB = lapack_size(tB, B)
793 if mB != nA
794 throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)"))
795 end
796 if size(C,1) != mA || size(C,2) != nB
797 throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)"))
798 end
799
800 if iszero(_add.alpha) || isempty(A) || isempty(B)
801 return _rmul_or_fill!(C, _add.beta)
802 end
803
804 tile_size = 0
805 if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N')
806 tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1)))
807 end
808 @inbounds begin
809 if tile_size > 0
810 sz = (tile_size, tile_size)
811 Atile = Array{T}(undef, sz)
812 Btile = Array{S}(undef, sz)
813
814 z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1])
815 z = convert(promote_type(typeof(z1), R), z1)
816
817 if mA < tile_size && nA < tile_size && nB < tile_size
818 copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA)
819 copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB)
820 for j = 1:nB
821 boff = (j-1)*tile_size
822 for i = 1:mA
823 aoff = (i-1)*tile_size
824 s = z
825 for k = 1:nA
826 s += Atile[aoff+k] * Btile[boff+k]
827 end
828 _modify!(_add, s, C, (i,j))
829 end
830 end
831 else
832 Ctile = Array{R}(undef, sz)
833 for jb = 1:tile_size:nB
834 jlim = min(jb+tile_size-1,nB)
835 jlen = jlim-jb+1
836 for ib = 1:tile_size:mA
837 ilim = min(ib+tile_size-1,mA)
838 ilen = ilim-ib+1
839 fill!(Ctile, z)
840 for kb = 1:tile_size:nA
841 klim = min(kb+tile_size-1,mB)
842 klen = klim-kb+1
843 copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim)
844 copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim)
845 for j=1:jlen
846 bcoff = (j-1)*tile_size
847 for i = 1:ilen
848 aoff = (i-1)*tile_size
849 s = z
850 for k = 1:klen
851 s += Atile[aoff+k] * Btile[bcoff+k]
852 end
853 Ctile[bcoff+i] += s
854 end
855 end
856 end
857 if isone(_add.alpha) && iszero(_add.beta)
858 copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen)
859 else
860 C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim])
861 end
862 end
863 end
864 end
865 else
866 # Multiplication for non-plain-data uses the naive algorithm
867 if tA == 'N'
868 if tB == 'N'
869 for i = 1:mA, j = 1:nB
870 z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
871 Ctmp = convert(promote_type(R, typeof(z2)), z2)
872 for k = 1:nA
873 Ctmp += A[i, k]*B[k, j]
874 end
875 _modify!(_add, Ctmp, C, (i,j))
876 end
877 elseif tB == 'T'
878 for i = 1:mA, j = 1:nB
879 z2 = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1]))
880 Ctmp = convert(promote_type(R, typeof(z2)), z2)
881 for k = 1:nA
882 Ctmp += A[i, k] * transpose(B[j, k])
883 end
884 _modify!(_add, Ctmp, C, (i,j))
885 end
886 else
887 for i = 1:mA, j = 1:nB
888 z2 = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]')
889 Ctmp = convert(promote_type(R, typeof(z2)), z2)
890 for k = 1:nA
891 Ctmp += A[i, k]*B[j, k]'
892 end
893 _modify!(_add, Ctmp, C, (i,j))
894 end
895 end
896 elseif tA == 'T'
897 if tB == 'N'
898 for i = 1:mA, j = 1:nB
899 z2 = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j])
900 Ctmp = convert(promote_type(R, typeof(z2)), z2)
901 for k = 1:nA
902 Ctmp += transpose(A[k, i]) * B[k, j]
903 end
904 _modify!(_add, Ctmp, C, (i,j))
905 end
906 elseif tB == 'T'
907 for i = 1:mA, j = 1:nB
908 z2 = zero(transpose(A[1, i])*transpose(B[j, 1]) + transpose(A[1, i])*transpose(B[j, 1]))
909 Ctmp = convert(promote_type(R, typeof(z2)), z2)
910 for k = 1:nA
911 Ctmp += transpose(A[k, i]) * transpose(B[j, k])
912 end
913 _modify!(_add, Ctmp, C, (i,j))
914 end
915 else
916 for i = 1:mA, j = 1:nB
917 z2 = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]')
918 Ctmp = convert(promote_type(R, typeof(z2)), z2)
919 for k = 1:nA
920 Ctmp += transpose(A[k, i]) * adjoint(B[j, k])
921 end
922 _modify!(_add, Ctmp, C, (i,j))
923 end
924 end
925 else
926 if tB == 'N'
927 for i = 1:mA, j = 1:nB
928 z2 = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j])
929 Ctmp = convert(promote_type(R, typeof(z2)), z2)
930 for k = 1:nA
931 Ctmp += A[k, i]'B[k, j]
932 end
933 _modify!(_add, Ctmp, C, (i,j))
934 end
935 elseif tB == 'T'
936 for i = 1:mA, j = 1:nB
937 z2 = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1]))
938 Ctmp = convert(promote_type(R, typeof(z2)), z2)
939 for k = 1:nA
940 Ctmp += adjoint(A[k, i]) * transpose(B[j, k])
941 end
942 _modify!(_add, Ctmp, C, (i,j))
943 end
944 else
945 for i = 1:mA, j = 1:nB
946 z2 = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]')
947 Ctmp = convert(promote_type(R, typeof(z2)), z2)
948 for k = 1:nA
949 Ctmp += A[k, i]'B[j, k]'
950 end
951 _modify!(_add, Ctmp, C, (i,j))
952 end
953 end
954 end
955 end
956 end # @inbounds
957 C
958 end
959
960
# Multiply 2x2 matrices.
# Allocates a 2x2 destination via `similar(B, ...)` whose element type is
# `promote_op(matprod, T, S)` and hands it to the in-place kernel `matmul2x2!`
# together with the wrapper flags `tA` and `tB`.
function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
    TS = promote_op(matprod, T, S)
    dest = similar(B, TS, 2, 2)
    return matmul2x2!(dest, tA, tB, A, B)
end
965
# Read the four entries of the 2x2 matrix `M` as seen through the wrapper flag `t`:
#   'N'       entries are taken as stored
#   'T' / 'C' entries of the transpose / adjoint (each entry is itself transposed / adjointed)
#   'S' / 's' symmetric matrix built from the upper / lower triangle of `M`
#   'H' / 'h' hermitian matrix built from the upper / lower triangle of `M`
# Any flag that matches none of the first six is handled like 'h'.
# Indexing is not bounds-checked here: the caller must already have verified
# that `size(M) == (2,2)`.
function _matmul2x2_entries(t, M::AbstractMatrix)
    @inbounds begin
        if t == 'N'
            M11 = M[1,1]; M12 = M[1,2]
            M21 = M[2,1]; M22 = M[2,2]
        elseif t == 'T'
            # TODO making these lazy could improve perf
            M11 = copy(transpose(M[1,1])); M12 = copy(transpose(M[2,1]))
            M21 = copy(transpose(M[1,2])); M22 = copy(transpose(M[2,2]))
        elseif t == 'C'
            # TODO making these lazy could improve perf
            M11 = copy(M[1,1]'); M12 = copy(M[2,1]')
            M21 = copy(M[1,2]'); M22 = copy(M[2,2]')
        elseif t == 'S'
            M11 = symmetric(M[1,1], :U); M12 = M[1,2]
            M21 = copy(transpose(M[1,2])); M22 = symmetric(M[2,2], :U)
        elseif t == 's'
            M11 = symmetric(M[1,1], :L); M12 = copy(transpose(M[2,1]))
            M21 = M[2,1]; M22 = symmetric(M[2,2], :L)
        elseif t == 'H'
            M11 = hermitian(M[1,1], :U); M12 = M[1,2]
            M21 = copy(adjoint(M[1,2])); M22 = hermitian(M[2,2], :U)
        else # if t == 'h'
            M11 = hermitian(M[1,1], :L); M12 = copy(adjoint(M[2,1]))
            M21 = M[2,1]; M22 = hermitian(M[2,2], :L)
        end
    end # inbounds
    return M11, M12, M21, M22
end

# In-place product of two 2x2 matrices.
#
# `tA` / `tB` are the wrapper flags of `A` / `B` (see `_matmul2x2_entries` for the
# accepted characters). Each of the four result entries is computed explicitly and
# handed to `_modify!(_add, value, C, (i,j))`, which is responsible for combining it
# with the existing content of `C` (presumably `alpha*value + beta*C[i,j]` -- see the
# definition of `MulAddMul`).
#
# Throws `DimensionMismatch` unless `A`, `B` and `C` are all 2x2. Returns `C`.
function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
                    _add::MulAddMul = MulAddMul())
    require_one_based_indexing(C, A, B)
    if !(size(A) == size(B) == size(C) == (2,2))
        throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
    end
    # Sizes are verified above, so the unchecked loads in the helper are safe.
    A11, A12, A21, A22 = _matmul2x2_entries(tA, A)
    B11, B12, B21, B22 = _matmul2x2_entries(tB, B)
    @inbounds begin
        _modify!(_add, A11*B11 + A12*B21, C, (1,1))
        _modify!(_add, A11*B12 + A12*B22, C, (1,2))
        _modify!(_add, A21*B11 + A22*B21, C, (2,1))
        _modify!(_add, A21*B12 + A22*B22, C, (2,2))
    end # inbounds
    C
end
1027
# Multiply 3x3 matrices.
# Allocates a 3x3 destination via `similar(B, ...)` whose element type is
# `promote_op(matprod, T, S)` and hands it to the in-place kernel `matmul3x3!`
# together with the wrapper flags `tA` and `tB`.
function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
    TS = promote_op(matprod, T, S)
    dest = similar(B, TS, 3, 3)
    return matmul3x3!(dest, tA, tB, A, B)
end
1032
# Read the nine entries of the 3x3 matrix `M` (row by row) as seen through the
# wrapper flag `t`:
#   'N'       entries are taken as stored
#   'T' / 'C' entries of the transpose / adjoint (each entry is itself transposed / adjointed)
#   'S' / 's' symmetric matrix built from the upper / lower triangle of `M`
#   'H' / 'h' hermitian matrix built from the upper / lower triangle of `M`
# Any flag that matches none of the first six is handled like 'h'.
# Indexing is not bounds-checked here: the caller must already have verified
# that `size(M) == (3,3)`.
function _matmul3x3_entries(t, M::AbstractMatrix)
    @inbounds begin
        if t == 'N'
            M11 = M[1,1]; M12 = M[1,2]; M13 = M[1,3]
            M21 = M[2,1]; M22 = M[2,2]; M23 = M[2,3]
            M31 = M[3,1]; M32 = M[3,2]; M33 = M[3,3]
        elseif t == 'T'
            # TODO making these lazy could improve perf
            M11 = copy(transpose(M[1,1])); M12 = copy(transpose(M[2,1])); M13 = copy(transpose(M[3,1]))
            M21 = copy(transpose(M[1,2])); M22 = copy(transpose(M[2,2])); M23 = copy(transpose(M[3,2]))
            M31 = copy(transpose(M[1,3])); M32 = copy(transpose(M[2,3])); M33 = copy(transpose(M[3,3]))
        elseif t == 'C'
            # TODO making these lazy could improve perf
            M11 = copy(M[1,1]'); M12 = copy(M[2,1]'); M13 = copy(M[3,1]')
            M21 = copy(M[1,2]'); M22 = copy(M[2,2]'); M23 = copy(M[3,2]')
            M31 = copy(M[1,3]'); M32 = copy(M[2,3]'); M33 = copy(M[3,3]')
        elseif t == 'S'
            M11 = symmetric(M[1,1], :U); M12 = M[1,2]; M13 = M[1,3]
            M21 = copy(transpose(M[1,2])); M22 = symmetric(M[2,2], :U); M23 = M[2,3]
            M31 = copy(transpose(M[1,3])); M32 = copy(transpose(M[2,3])); M33 = symmetric(M[3,3], :U)
        elseif t == 's'
            M11 = symmetric(M[1,1], :L); M12 = copy(transpose(M[2,1])); M13 = copy(transpose(M[3,1]))
            M21 = M[2,1]; M22 = symmetric(M[2,2], :L); M23 = copy(transpose(M[3,2]))
            M31 = M[3,1]; M32 = M[3,2]; M33 = symmetric(M[3,3], :L)
        elseif t == 'H'
            M11 = hermitian(M[1,1], :U); M12 = M[1,2]; M13 = M[1,3]
            M21 = copy(adjoint(M[1,2])); M22 = hermitian(M[2,2], :U); M23 = M[2,3]
            M31 = copy(adjoint(M[1,3])); M32 = copy(adjoint(M[2,3])); M33 = hermitian(M[3,3], :U)
        else # if t == 'h'
            M11 = hermitian(M[1,1], :L); M12 = copy(adjoint(M[2,1])); M13 = copy(adjoint(M[3,1]))
            M21 = M[2,1]; M22 = hermitian(M[2,2], :L); M23 = copy(adjoint(M[3,2]))
            M31 = M[3,1]; M32 = M[3,2]; M33 = hermitian(M[3,3], :L)
        end
    end # inbounds
    return M11, M12, M13, M21, M22, M23, M31, M32, M33
end

# In-place product of two 3x3 matrices.
#
# `tA` / `tB` are the wrapper flags of `A` / `B` (see `_matmul3x3_entries` for the
# accepted characters). Each of the nine result entries is computed explicitly and
# handed to `_modify!(_add, value, C, (i,j))`, which is responsible for combining it
# with the existing content of `C` (presumably `alpha*value + beta*C[i,j]` -- see the
# definition of `MulAddMul`).
#
# Throws `DimensionMismatch` unless `A`, `B` and `C` are all 3x3. Returns `C`.
function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
                    _add::MulAddMul = MulAddMul())
    require_one_based_indexing(C, A, B)
    if !(size(A) == size(B) == size(C) == (3,3))
        throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
    end
    # Sizes are verified above, so the unchecked loads in the helper are safe.
    A11, A12, A13, A21, A22, A23, A31, A32, A33 = _matmul3x3_entries(tA, A)
    B11, B12, B13, B21, B22, B23, B31, B32, B33 = _matmul3x3_entries(tB, B)
    @inbounds begin
        _modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
        _modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
        _modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))

        _modify!(_add, A21*B11 + A22*B21 + A23*B31, C, (2,1))
        _modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
        _modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))

        _modify!(_add, A31*B11 + A32*B21 + A33*B31, C, (3,1))
        _modify!(_add, A31*B12 + A32*B22 + A33*B32, C, (3,2))
        _modify!(_add, A31*B13 + A32*B23 + A33*B33, C, (3,3))
    end # inbounds
    C
end
1118
const RealOrComplex = Union{Real, Complex}

# Three-argument *
"""
    *(A, B::AbstractMatrix, C)
    A * B * C * D

Chained multiplication of 3 or 4 matrices is done in the most efficient sequence,
based on the sizes of the arrays. That is, the number of scalar multiplications needed
for `(A * B) * C` (with 3 dense matrices) is compared to that for `A * (B * C)`
to choose which of these to execute.

If the last factor is a vector, or the first a transposed vector, then it is efficient
to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
for an ordinary column-major `B::Matrix`. Unlike `dot(x, B, y)`, this
allocates an intermediate array.

If the first or last factor is a number, this will be fused with the matrix
multiplication, using 5-arg [`mul!`](@ref).

See also [`muladd`](@ref), [`dot`](@ref).

!!! compat "Julia 1.7"
    These optimisations require at least Julia 1.7.
"""
*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector) = A * (B * x)

# Row vector first: contract it with the matrix, unless the matrix is itself
# a lazy adjoint/transpose, in which case the matrix-vector product goes first.
*(tu::AdjOrTransAbsVec, B::AbstractMatrix, v::AbstractVector) = (tu * B) * v
*(tu::AdjOrTransAbsVec, B::AdjOrTransAbsMat, v::AbstractVector) = tu * (B * v)

# A scalar as the first or last factor is routed to the opt-in fused helpers.
*(A::AbstractMatrix, x::AbstractVector, γ::Number) = mat_vec_scalar(A, x, γ)
*(A::AbstractMatrix, B::AbstractMatrix, γ::Number) = mat_mat_scalar(A, B, γ)
function *(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex},
           C::AbstractVector{<:RealOrComplex})
    return mat_vec_scalar(B, C, α)
end
function *(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex},
           C::AbstractMatrix{<:RealOrComplex})
    return mat_mat_scalar(B, C, α)
end

# Column vector times row vector (outer product), optionally with a scalar or matrix.
*(α::Number, u::AbstractVector, tv::AdjOrTransAbsVec) = broadcast(*, α, u, tv)
*(u::AbstractVector, tv::AdjOrTransAbsVec, γ::Number) = broadcast(*, u, tv, γ)
*(u::AbstractVector, tv::AdjOrTransAbsVec, C::AbstractMatrix) = u * (tv * C)

# Three matrices: order chosen by size. Row vector first: always left to right.
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix) = _tri_matmul(A, B, C)
*(tv::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix) = (tv * B) * C
1162
# Multiply three matrices `A * B * C`, picking the association with fewer scalar
# multiplications based on `size(A)` and `size(C)`. On a tie the left-to-right
# order `(A*B) * C` is used. When `δ` is not `nothing` it is passed to
# `mat_mat_scalar` together with the last product that is formed.
function _tri_matmul(A,B,C,δ=nothing)
    rowsA, colsA = size(A)
    # colsA, rowsC == size(B)
    rowsC, colsC = size(C)
    cost_left  = rowsA*colsA*rowsC + rowsA*rowsC*colsC  # (A*B) * C
    cost_right = colsA*rowsC*colsC + rowsA*colsA*colsC  # A * (B*C)
    right_first = cost_right < cost_left
    if δ === nothing
        return right_first ? A * (B*C) : (A*B) * C
    else
        return right_first ? A * mat_mat_scalar(B,C,δ) : mat_mat_scalar(A*B, C, δ)
    end
end
1175
1176 # Fast path for two arrays * one scalar is opt-in, via mat_vec_scalar and mat_mat_scalar.
1177
# Fallback: scale the vector first, then apply the matrix.
function mat_vec_scalar(A, x, γ)
    return A * (x * γ)
end

# Strided (possibly adjoint/transposed) matrix with a strided vector: use the fused kernel.
function mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ)
    return _mat_vec_scalar(A, x, γ)
end

# Adjoint/transposed vector on the left: form `A * x` first, then scale the result.
function mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ)
    return (A * x) * γ
end
1181
# Compute `A * x * γ` with a single 5-argument `mul!` call. The output vector is
# allocated with `similar(A, ...)` along the first axis of `A`, using the promoted
# element type of `A`, `x` and `γ`; `beta = false` means its initial content is ignored.
function _mat_vec_scalar(A, x, γ)
    Tout = promote_type(eltype(A), eltype(x), typeof(γ))
    out = similar(A, Tout, axes(A,1))
    return mul!(out, A, x, γ, false)
end
1187
# Fallback: form the product, then scale it.
function mat_mat_scalar(A, B, γ)
    return (A*B) * γ
end

# Two strided (possibly adjoint/transposed) matrices: use the fused kernel.
function mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ)
    return _mat_mat_scalar(A, B, γ)
end
1191
# Compute `A * B * γ` with a single 5-argument `mul!` call. The output matrix is
# allocated with `similar(A, ...)` over `axes(A,1)` x `axes(B,2)`, using the promoted
# element type of `A`, `B` and `γ`; `beta = false` means its initial content is ignored.
function _mat_mat_scalar(A, B, γ)
    Tout = promote_type(eltype(A), eltype(B), typeof(γ))
    out = similar(A, Tout, axes(A,1), axes(B,2))
    return mul!(out, A, B, γ, false)
end
1197
# Adjoint vector as first factor. Taking the adjoint reverses the order of a
# product, so the scalar is applied on the left of the adjointed result and the
# whole thing is adjointed back, which preserves the original factor order.
function mat_mat_scalar(A::AdjointAbsVec, B, γ)
    return adjoint(adjoint(γ) * adjoint(A * B))
end
function mat_mat_scalar(A::AdjointAbsVec{<:RealOrComplex},
                        B::StridedMaybeAdjOrTransMat{<:RealOrComplex},
                        γ::RealOrComplex)
    return adjoint(mat_vec_scalar(adjoint(B), adjoint(A), adjoint(γ)))
end

# Transposed vector as first factor: same idea, using transpose instead of adjoint.
function mat_mat_scalar(A::TransposeAbsVec, B, γ)
    return transpose(γ * transpose(A * B))
end
function mat_mat_scalar(A::TransposeAbsVec{<:RealOrComplex},
                        B::StridedMaybeAdjOrTransMat{<:RealOrComplex},
                        γ::RealOrComplex)
    return transpose(mat_vec_scalar(transpose(B), transpose(A), γ))
end
1205
1206
1207 # Four-argument *, by type
# Leading scalar(s): combine two scalars first, and let a trailing vector
# shrink the problem before the remaining factors are applied.
*(α::Number, β::Number, C::AbstractMatrix, x::AbstractVector) = (α*β) * C * x
*(α::Number, β::Number, C::AbstractMatrix, D::AbstractMatrix) = (α*β) * C * D
*(α::Number, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = α * B * (C*x)
*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, x::AbstractVector) = α * (vt*C*x)
function *(α::RealOrComplex, vt::AdjOrTransAbsVec{<:RealOrComplex},
           C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex})
    return (α*vt*C) * D # solves an ambiguity
end

# Trailing scalar(s).
*(A::AbstractMatrix, x::AbstractVector, γ::Number, δ::Number) = A * x * (γ*δ)
*(A::AbstractMatrix, B::AbstractMatrix, γ::Number, δ::Number) = A * B * (γ*δ)
*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector, δ::Number) = A * (B*x*δ)
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, x::AbstractVector, δ::Number) = (vt*B*x) * δ
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = (vt*B) * C * δ

# No scalars: a vector at either end is contracted with its neighbour first.
function *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector)
    return A * B * (C*x)
end
function *(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix)
    return (vt*B) * C * D
end
function *(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector)
    return vt * B * (C*x)
end
1224
1225 # Four-argument *, by size
# Three matrices and one scalar: `_tri_matmul` picks the order and receives the scalar.
function *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, δ::Number)
    return _tri_matmul(A, B, C, δ)
end
function *(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex},
           C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex})
    return _tri_matmul(B, C, D, α)
end

# Four matrices: `_quad_matmul` picks among the five possible parenthesisations.
function *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix)
    return _quad_matmul(A, B, C, D)
end
1231
# Multiply four matrices using whichever of the five parenthesisations has the
# lowest `_mul_cost`. Ties are resolved in the order the candidates are tested below.
function _quad_matmul(A,B,C,D)
    cost_AB_CD   = _mul_cost((A,B),(C,D))
    cost_ABC_D   = _mul_cost(((A,B),C),D)
    cost_A_BCD   = _mul_cost(A,(B,(C,D)))
    cost_A_BC__D = _mul_cost((A,(B,C)),D)
    cost_A__BC_D = _mul_cost(A,((B,C),D))
    best = min(cost_AB_CD, cost_ABC_D, cost_A_BCD, cost_A_BC__D, cost_A__BC_D)
    cost_AB_CD   == best && return (A*B) * (C*D)
    cost_ABC_D   == best && return ((A*B) * C) * D
    cost_A_BCD   == best && return A * (B * (C*D))
    cost_A_BC__D == best && return (A * (B*C)) * D
    return A * ((B*C) * D)
end

# Cost model: a nested 2-tuple stands for the product of its two parts, and a
# plain matrix costs nothing to "form".
@inline _mul_cost(A::AbstractMatrix) = 0
@inline function _mul_cost(pair::Tuple)
    lhs, rhs = pair
    return _mul_cost(lhs, rhs)
end
@inline function _mul_cost(A, B)
    # cost of forming both factors, plus size(A,1) * size(A,2) * size(B,2)
    # scalar multiplications for the product itself
    return _mul_cost(A) + _mul_cost(B) + *(_mul_sizes(A)..., last(_mul_sizes(B)))
end

# Size of a matrix, or of the product represented by a nested 2-tuple.
@inline _mul_sizes(A::AbstractMatrix) = size(A)
@inline function _mul_sizes(pair::Tuple)
    lhs, rhs = pair
    return (first(_mul_sizes(lhs)), last(_mul_sizes(rhs)))
end