Skip to content

Commit

Permalink
Fix: Complex SubArray times real Matrix (#29246)
Browse files Browse the repository at this point in the history
* gemm_wrapper! -> mul! (#29224)

* testing for #29224

* code review update: More tests

* Complex times real reinterpret trick fix for vectors and transposed matrices
  • Loading branch information
jarlebring authored and andreasnoack committed Sep 25, 2018
1 parent 94aa39b commit 82503cd
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
11 changes: 7 additions & 4 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ end
(*)(a::AbstractVector, B::AbstractMatrix) = reshape(a,length(a),1)*B

mul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where {T<:BlasFloat} = gemv!(y, 'N', A, x)
# Complex matrix times real vector. Reinterpret the matrix as a real matrix and do real matvec compuation.
for elty in (Float32,Float64)
@eval begin
function mul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty})
Afl = reinterpret($elty,A)
yfl = reinterpret($elty,y)
gemv!(yfl,'N',Afl,x)
mul!(yfl,Afl,x)
return y
end
end
Expand Down Expand Up @@ -141,12 +142,14 @@ function (*)(A::AbstractMatrix, B::AbstractMatrix)
mul!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) where {T<:BlasFloat} = gemm_wrapper!(C, 'N', 'N', A, B)
# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
# first matrix as a real matrix and carry out real matrix matrix multiply
for elty in (Float32,Float64)
@eval begin
function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
gemm_wrapper!(Cfl, 'N', 'N', Afl, B)
mul!(Cfl, Afl, B)
return C
end
end
Expand Down Expand Up @@ -234,13 +237,13 @@ function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, transB::Transpose{<:An
return gemm_wrapper!(C, 'N', 'T', A, B)
end
end
# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency.
for elty in (Float32,Float64)
@eval begin
function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, transB::Transpose{<:Any,<:StridedVecOrMat{$elty}})
B = transB.parent
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
gemm_wrapper!(Cfl, 'N', 'T', Afl, B)
mul!(Cfl,Afl,transB)
return C
end
end
Expand Down
34 changes: 34 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,40 @@ end
@test *(Asub, adjoint(Asub)) == *(Aref, adjoint(Aref))
end

@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T1 in (Float32,Float64)
for T2 in (Float32,Float64)
for arg1_real in (true,false)
@testset "Combination $T1 $T2 $arg1_real $arg2_real" for arg2_real in (true,false)
A0 = reshape(Vector{T1}(1:25),5,5) .+
(arg1_real ? 0 : 1im*reshape(Vector{T1}(-3:21),5,5))
A = view(A0,1:2,1:2)
B = Matrix{T2}([1.0 3.0; -1.0 2.0]).+
(arg2_real ? 0 : 1im*Matrix{T2}([3.0 4; -1 10]))
AB_correct = copy(A)*B
AB = A*B; # view times matrix
@test AB AB_correct
A1 = view(A0,:,1:2) # rectangular view times matrix
@test A1*B copy(A1)*B
B1 = view(B,1:2,1:2);
AB1 = A*B1; # view times view
@test AB1 AB_correct
x = Vector{T2}([1.0;10.0]) .+ (arg2_real ? 0 : 1im*Vector{T2}([3;-1]))
Ax_exact = copy(A)*x
Ax = A*x # view times vector
@test Ax Ax_exact
x1 = view(x,1:2)
Ax1 = A*x1 # view times viewed vector
@test Ax1 Ax_exact
@test copy(A)*x1 Ax_exact # matrix times viewed vector
# View times transposed matrix
Bt = transpose(B);
@test A*Bt A*copy(Bt)
end
end
end
end


@testset "issue #15286" begin
A = reshape(map(Float64, 1:20), 5, 4)
C = zeros(8, 8)
Expand Down

0 comments on commit 82503cd

Please sign in to comment.