Skip to content

Commit

Permalink
Optimize real matrix * complex vector by reintrepreting as real (Juli…
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored and LilithHafner committed Feb 22, 2022
1 parent 85d35b3 commit 6625dad
Show file tree
Hide file tree
Showing 2 changed files with 397 additions and 331 deletions.
43 changes: 36 additions & 7 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Matrix-matrix multiplication

AdjOrTransStridedMat{T} = Union{Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}

# matmul.jl: Everything to do with dense matrix multiplication

matprod(x, y) = x*y + x*y
Expand Down Expand Up @@ -59,18 +64,26 @@ end
@inline mul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, 'N', A, x, alpha, beta)

# Complex matrix times real vector. Reinterpret the matrix as a real matrix and do real matvec compuation.
for elty in (Float32,Float64)
for elty in (Float32, Float64)
@eval begin
@inline function mul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty},
alpha::Real, beta::Real)
alpha::Real, beta::Real)
Afl = reinterpret($elty, A)
yfl = reinterpret($elty, y)
mul!(yfl, Afl, x, alpha, beta)
return y
end
end
end

# Real matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, A isa StridedArray ? 'N' : 'T', A isa StridedArray ? A : parent(A), x, alpha, beta)

@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number, beta::Number) =
generic_matvecmul!(y, 'N', A, x, MulAddMul(alpha, beta))
Expand Down Expand Up @@ -113,11 +126,6 @@ end
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
(*)(x::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A)*transpose(x))

# Matrix-matrix multiplication

AdjOrTransStridedMat{T} = Union{Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}

_parent(A) = A
_parent(A::Adjoint) = parent(A)
_parent(A::Transpose) = parent(A)
Expand Down Expand Up @@ -525,6 +533,27 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
end
end

function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
α::Number = true, β::Number = false) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
nA != length(x) &&
throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
mA != length(y) &&
throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
xfl = reinterpret(reshape, T, x) # Use reshape here.
yfl = reinterpret(reshape, T, y)
BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :])
BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :])
return y
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
end

function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasFloat}
nC = checksquare(C)
Expand Down
Loading

0 comments on commit 6625dad

Please sign in to comment.