From 042e7b5183311340b8211b589bef0818e99b01b1 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 9 Feb 2022 17:11:36 +0800 Subject: [PATCH] Move `reinterpret`-based optimization for complex matrix * real vec/mat to lower level. (#44052) --- stdlib/LinearAlgebra/src/matmul.jl | 130 ++++++++++++++++++---------- stdlib/LinearAlgebra/test/matmul.jl | 59 +++++++------ 2 files changed, 113 insertions(+), 76 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index cb7bb9f74bdbf2..f27a3a768b8669 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -65,23 +65,16 @@ end alpha::Number, beta::Number) where {T<:BlasFloat} = gemv!(y, 'N', A, x, alpha, beta) -# Complex matrix times real vector. Reinterpret the matrix as a real matrix and do real matvec compuation. -for elty in (Float32, Float64) - @eval begin - @inline function mul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty}, - alpha::Real, beta::Real) - Afl = reinterpret($elty, A) - yfl = reinterpret($elty, y) - mul!(yfl, Afl, x, alpha, beta) - return y - end - end -end +# Complex matrix times real vector. +# Reinterpret the matrix as a real matrix and do real matvec compuation. +@inline mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, + alpha::Number, beta::Number) where {T<:BlasReal} = + gemv!(y, 'N', A, x, alpha, beta) # Real matrix times complex vector. # Multiply the matrix with the real and imaginary parts separately @inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}}, - alpha::Number, beta::Number) where {T<:BlasFloat} = + alpha::Number, beta::Number) where {T<:BlasReal} = gemv!(y, A isa StridedArray ? 'N' : 'T', A isa StridedArray ? A : parent(A), x, alpha, beta) @inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector, @@ -192,18 +185,6 @@ end (*)(A::AdjOrTransStridedMat{<:BlasReal}, B::StridedMatrix{<:BlasComplex}) = copy(transpose(transpose(B) * parent(A))) (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::AdjOrTransStridedMat{<:BlasComplex}) = copy(wrapperop(B)(parent(B) * transpose(A))) -for elty in (Float32,Float64) - @eval begin - @inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty}, - alpha::Real, beta::Real) - Afl = reinterpret($elty, A) - Cfl = reinterpret($elty, C) - mul!(Cfl, Afl, B, alpha, beta) - return C - end - end -end - """ muladd(A, y, z) @@ -410,18 +391,14 @@ end return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta)) end end -# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency. -for elty in (Float32,Float64) - @eval begin - @inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, tB::Transpose{<:Any,<:StridedVecOrMat{$elty}}, - alpha::Real, beta::Real) - Afl = reinterpret($elty, A) - Cfl = reinterpret($elty, C) - mul!(Cfl, Afl, tB, alpha, beta) - return C - end - end -end +# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. +@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + alpha::Number, beta::Number) where {T<:BlasReal} = + gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) +@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}}, + alpha::Number, beta::Number) where {T<:BlasReal} = + gemm_wrapper!(C, 'N', 'T', A, parent(tB), MulAddMul(alpha, beta)) + # collapsing the following two defs with C::AbstractVecOrMat yields ambiguities @inline mul!(C::AbstractVector, A::AbstractVecOrMat, tB::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = @@ -513,22 +490,36 @@ end function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{T}, α::Number=true, β::Number=false) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) - if nA != length(x) + nA != length(x) && throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))")) - end - if mA != length(y) + mA != length(y) && throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))")) + mA == 0 && return y + nA == 0 && return _rmul_or_fill!(y, β) + alpha, beta = promote(α, β, zero(T)) + if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && + stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) + return BLAS.gemv!(tA, alpha, A, x, beta, y) + else + return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end - if mA == 0 - return y - end - if nA == 0 - return _rmul_or_fill!(y, β) - end +end +function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, + α::Number = true, β::Number = false) where {T<:BlasReal} + mA, nA = lapack_size(tA, A) + nA != length(x) && + throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))")) + mA != length(y) && + throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))")) + mA == 0 && return y + nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) - return BLAS.gemv!(tA, alpha, A, x, beta, y) + if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && + stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) && + stride(y, 1) == 1 && tA == 'N' # reinterpret-based optimization is valid only for contiguous `y` + BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) + return y else return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end @@ -681,6 +672,49 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar generic_matmatmul!(C, tA, tB, A, B, _add) end +function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, + A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + _add = MulAddMul()) where {T<:BlasReal} + mA, nA = lapack_size(tA, A) + mB, nB = lapack_size(tB, B) + + if nA != mB + throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)")) + end + + if C === A || B === C + throw(ArgumentError("output matrix must not be aliased with input matrix")) + end + + if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha) + if size(C) != (mA, nB) + throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) + end + return _rmul_or_fill!(C, _add.beta) + end + + if mA == 2 && nA == 2 && nB == 2 + return matmul2x2!(C, tA, tB, A, B, _add) + end + if mA == 3 && nA == 3 && nB == 3 + return matmul3x3!(C, tA, tB, A, B, _add) + end + + alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + + # Make-sure reinterpret-based optimization is BLAS-compatible. + if (alpha isa Union{Bool,T} && + beta isa Union{Bool,T} && + stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && + stride(A, 2) >= size(A, 1) && + stride(B, 2) >= size(B, 1) && + stride(C, 2) >= size(C, 1)) && tA == 'N' + BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) + return C + end + generic_matmatmul!(C, tA, tB, A, B, _add) +end + # blas.jl defines matmul for floats; other integer and mixed precision # cases are handled here diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index c53ba15d7c2d60..1c482f8cae97ac 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -226,34 +226,37 @@ end end end -@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T1 in (Float32, Float64) - for T2 in (Float32, Float64) - for arg1_real in (true, false) - @testset "Combination $T1 $T2 $arg1_real $arg2_real" for arg2_real in (true, false) - A0 = reshape(Vector{T1}(1:25), 5, 5) .+ - (arg1_real ? 0 : 1im * reshape(Vector{T1}(-3:21), 5, 5)) - A = view(A0, 1:2, 1:2) - B = Matrix{T2}([1.0 3.0; -1.0 2.0]) .+ - (arg2_real ? 0 : 1im * Matrix{T2}([3.0 4; -1 10])) - AB_correct = copy(A) * B - AB = A * B # view times matrix - @test AB ≈ AB_correct - A1 = view(A0, :, 1:2) # rectangular view times matrix - @test A1 * B ≈ copy(A1) * B - B1 = view(B, 1:2, 1:2) - AB1 = A * B1 # view times view - @test AB1 ≈ AB_correct - x = Vector{T2}([1.0; 10.0]) .+ (arg2_real ? 0 : 1im * Vector{T2}([3; -1])) - Ax_exact = copy(A) * x - Ax = A * x # view times vector - @test Ax ≈ Ax_exact - x1 = view(x, 1:2) - Ax1 = A * x1 # view times viewed vector - @test Ax1 ≈ Ax_exact - @test copy(A) * x1 ≈ Ax_exact # matrix times viewed vector - # View times transposed matrix - Bt = transpose(B) - @test A * Bt ≈ A * copy(Bt) +@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T in (Float32, Float64) + A0 = randn(complex(T), 10, 10) + B0 = randn(T, 10, 10) + @testset "Combination Mat{$(complex(T))} Mat{$T}" for Bax1 in (1:5, 2:2:10), Bax2 in (1:5, 2:2:10) + B = view(A0, Bax1, Bax2) + tB = transpose(B) + Bd, tBd = copy(B), copy(tB) + for Aax1 in (1:5, 2:2:10, (:)), Aax2 in (1:5, 2:2:10) + A = view(A0, Aax1, Aax2) + AB_correct = copy(A) * Bd + AtB_correct = copy(A) * tBd + @test A*Bd ≈ AB_correct # view times matrix + @test A*B ≈ AB_correct # view times view + @test A*tBd ≈ AtB_correct # view times transposed matrix + @test A*tB ≈ AtB_correct # view times transposed view + end + end + x = randn(T, 10) + y0 = similar(A0, 20) + @testset "Combination Mat{$(complex(T))} Vec{$T}" for Aax1 in (1:5, 2:2:10, (:)), Aax2 in (1:5, 2:2:10) + A = view(A0, Aax1, Aax2) + Ad = copy(A) + for indx in (1:5, 1:2:10, 6:-1:2) + vx = view(x, indx) + dx = x[indx] + Ax_correct = Ad*dx + @test A*vx ≈ A*dx ≈ Ad*vx ≈ Ax_correct # view/matrix times view/vector + for indy in (1:2:2size(A,1), size(A,1):-1:1) + y = view(y0, indy) + @test mul!(y, A, vx) ≈ mul!(y, A, dx) ≈ mul!(y, Ad, vx) ≈ + mul!(y, Ad, dx) ≈ Ax_correct # test for uncontiguous dest end end end