Skip to content

Commit

Permalink
Move reinterpret-based optimization for complex matrix * real vec/m…
Browse files Browse the repository at this point in the history
…at to lower level. (JuliaLang#44052)
  • Loading branch information
N5N3 authored and LilithHafner committed Feb 22, 2022
1 parent de6911f commit 042e7b5
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 76 deletions.
130 changes: 82 additions & 48 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,16 @@ end
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, 'N', A, x, alpha, beta)

# Complex matrix times real vector. Reinterpret the matrix as a real matrix and do real matvec compuation.
for elty in (Float32, Float64)
@eval begin
@inline function mul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty},
alpha::Real, beta::Real)
Afl = reinterpret($elty, A)
yfl = reinterpret($elty, y)
mul!(yfl, Afl, x, alpha, beta)
return y
end
end
end
# Complex matrix times real vector.
# Reinterpret the matrix as a real matrix and do real matvec compuation.
@inline mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, 'N', A, x, alpha, beta)

# Real matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, A isa StridedArray ? 'N' : 'T', A isa StridedArray ? A : parent(A), x, alpha, beta)

@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
Expand Down Expand Up @@ -192,18 +185,6 @@ end
(*)(A::AdjOrTransStridedMat{<:BlasReal}, B::StridedMatrix{<:BlasComplex}) = copy(transpose(transpose(B) * parent(A)))
(*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::AdjOrTransStridedMat{<:BlasComplex}) = copy(wrapperop(B)(parent(B) * transpose(A)))

for elty in (Float32,Float64)
@eval begin
@inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty},
alpha::Real, beta::Real)
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
mul!(Cfl, Afl, B, alpha, beta)
return C
end
end
end

"""
muladd(A, y, z)
Expand Down Expand Up @@ -410,18 +391,14 @@ end
return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta))
end
end
# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency.
for elty in (Float32,Float64)
@eval begin
@inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, tB::Transpose{<:Any,<:StridedVecOrMat{$elty}},
alpha::Real, beta::Real)
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
mul!(Cfl, Afl, tB, alpha, beta)
return C
end
end
end
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemm_wrapper!(C, 'N', 'T', A, parent(tB), MulAddMul(alpha, beta))

# collapsing the following two defs with C::AbstractVecOrMat yields ambiguities
@inline mul!(C::AbstractVector, A::AbstractVecOrMat, tB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
Expand Down Expand Up @@ -513,22 +490,36 @@ end
function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{T},
α::Number=true, β::Number=false) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
if nA != length(x)
nA != length(x) &&
throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
end
if mA != length(y)
mA != length(y) &&
throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
return BLAS.gemv!(tA, alpha, A, x, beta, y)
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
if mA == 0
return y
end
if nA == 0
return _rmul_or_fill!(y, β)
end
end

function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
α::Number = true, β::Number = false) where {T<:BlasReal}
mA, nA = lapack_size(tA, A)
nA != length(x) &&
throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
mA != length(y) &&
throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
return BLAS.gemv!(tA, alpha, A, x, beta, y)
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) &&
stride(y, 1) == 1 && tA == 'N' # reinterpret-based optimization is valid only for contiguous `y`
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
return y
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
Expand Down Expand Up @@ -681,6 +672,49 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
generic_matmatmul!(C, tA, tB, A, B, _add)
end

function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasReal}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)

if nA != mB
throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
end

if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end

if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
if size(C) != (mA, nB)
throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
end
return _rmul_or_fill!(C, _add.beta)
end

if mA == 2 && nA == 2 && nB == 2
return matmul2x2!(C, tA, tB, A, B, _add)
end
if mA == 3 && nA == 3 && nB == 3
return matmul3x3!(C, tA, tB, A, B, _add)
end

alpha, beta = promote(_add.alpha, _add.beta, zero(T))

# Make-sure reinterpret-based optimization is BLAS-compatible.
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
stride(A, 2) >= size(A, 1) &&
stride(B, 2) >= size(B, 1) &&
stride(C, 2) >= size(C, 1)) && tA == 'N'
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
generic_matmatmul!(C, tA, tB, A, B, _add)
end

# blas.jl defines matmul for floats; other integer and mixed precision
# cases are handled here

Expand Down
59 changes: 31 additions & 28 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,34 +226,37 @@ end
end
end

@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T1 in (Float32, Float64)
for T2 in (Float32, Float64)
for arg1_real in (true, false)
@testset "Combination $T1 $T2 $arg1_real $arg2_real" for arg2_real in (true, false)
A0 = reshape(Vector{T1}(1:25), 5, 5) .+
(arg1_real ? 0 : 1im * reshape(Vector{T1}(-3:21), 5, 5))
A = view(A0, 1:2, 1:2)
B = Matrix{T2}([1.0 3.0; -1.0 2.0]) .+
(arg2_real ? 0 : 1im * Matrix{T2}([3.0 4; -1 10]))
AB_correct = copy(A) * B
AB = A * B # view times matrix
@test AB AB_correct
A1 = view(A0, :, 1:2) # rectangular view times matrix
@test A1 * B copy(A1) * B
B1 = view(B, 1:2, 1:2)
AB1 = A * B1 # view times view
@test AB1 AB_correct
x = Vector{T2}([1.0; 10.0]) .+ (arg2_real ? 0 : 1im * Vector{T2}([3; -1]))
Ax_exact = copy(A) * x
Ax = A * x # view times vector
@test Ax Ax_exact
x1 = view(x, 1:2)
Ax1 = A * x1 # view times viewed vector
@test Ax1 Ax_exact
@test copy(A) * x1 Ax_exact # matrix times viewed vector
# View times transposed matrix
Bt = transpose(B)
@test A * Bt A * copy(Bt)
@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T in (Float32, Float64)
A0 = randn(complex(T), 10, 10)
B0 = randn(T, 10, 10)
@testset "Combination Mat{$(complex(T))} Mat{$T}" for Bax1 in (1:5, 2:2:10), Bax2 in (1:5, 2:2:10)
B = view(A0, Bax1, Bax2)
tB = transpose(B)
Bd, tBd = copy(B), copy(tB)
for Aax1 in (1:5, 2:2:10, (:)), Aax2 in (1:5, 2:2:10)
A = view(A0, Aax1, Aax2)
AB_correct = copy(A) * Bd
AtB_correct = copy(A) * tBd
@test A*Bd AB_correct # view times matrix
@test A*B AB_correct # view times view
@test A*tBd AtB_correct # view times transposed matrix
@test A*tB AtB_correct # view times transposed view
end
end
x = randn(T, 10)
y0 = similar(A0, 20)
@testset "Combination Mat{$(complex(T))} Vec{$T}" for Aax1 in (1:5, 2:2:10, (:)), Aax2 in (1:5, 2:2:10)
A = view(A0, Aax1, Aax2)
Ad = copy(A)
for indx in (1:5, 1:2:10, 6:-1:2)
vx = view(x, indx)
dx = x[indx]
Ax_correct = Ad*dx
@test A*vx A*dx Ad*vx Ax_correct # view/matrix times view/vector
for indy in (1:2:2size(A,1), size(A,1):-1:1)
y = view(y0, indy)
@test mul!(y, A, vx) mul!(y, A, dx) mul!(y, Ad, vx)
mul!(y, Ad, dx) Ax_correct # test for uncontiguous dest
end
end
end
Expand Down

0 comments on commit 042e7b5

Please sign in to comment.