Skip to content

Commit

Permalink
Promote eltype in BlasFloat matrix multiplication (#32587)
Browse files Browse the repository at this point in the history
Co-Authored-By: Fredrik Ekre <[email protected]>
  • Loading branch information
dkarrasch and fredrikekre authored Oct 29, 2019
1 parent 962634d commit 592748a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
12 changes: 11 additions & 1 deletion stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,17 @@ function (*)(A::AbstractMatrix, B::AbstractMatrix)
TS = promote_op(matprod, eltype(A), eltype(B))
mul!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
# optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64})
# but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal})
# which is better handled by reinterpreting rather than promotion
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasFloat})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A,1), size(B,2))), convert(AbstractArray{TS}, A), convert(AbstractArray{TS}, B))
end
function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMatrix{<:BlasComplex})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A,1), size(B,2))), convert(AbstractArray{TS}, A), convert(AbstractArray{TS}, B))
end

@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number) where {T<:BlasFloat}
Expand All @@ -162,7 +173,6 @@ end
return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(α, β))
end
end

# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
# first matrix as a real matrix and carry out real matrix matrix multiply
for elty in (Float32,Float64)
Expand Down
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,18 @@ end
end
end

@testset "mixed Blas-non-Blas matmul" begin
AA = rand(-10:10,6,6)
BB = rand(Float64,6,6)
CC = zeros(Float64,6,6)
for A in (copy(AA), view(AA, 1:6, 1:6)), B in (copy(BB), view(BB, 1:6, 1:6)), C in (copy(CC), view(CC, 1:6, 1:6))
@test LinearAlgebra.mul!(C, A, B) == A*B
@test LinearAlgebra.mul!(C, transpose(A), transpose(B)) == transpose(A)*transpose(B)
@test LinearAlgebra.mul!(C, A, adjoint(B)) == A*transpose(B)
@test LinearAlgebra.mul!(C, adjoint(A), B) == transpose(A)*B
end
end

@testset "matrix algebra with subarrays of floats (stride != 1)" begin
A = reshape(map(Float64,1:20),5,4)
Aref = A[1:2:end,1:2:end]
Expand Down

0 comments on commit 592748a

Please sign in to comment.