Skip to content

Commit

Permalink
Add zero stride-check to LinearAlgebra.gemv!
Browse files Browse the repository at this point in the history
Also call BLAS for negative `lda` (if possible)
  • Loading branch information
N5N3 committed Feb 9, 2022
1 parent 6d4f8b9 commit 6f7daea
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
12 changes: 8 additions & 4 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
!iszero(stride(x, 1)) # We only check input's stride here.
return BLAS.gemv!(tA, alpha, A, x, beta, y)
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
Expand All @@ -516,8 +517,9 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) &&
stride(y, 1) == 1 && tA == 'N' # reinterpret-based optimization is valid only for contiguous `y`
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
stride(y, 1) == 1 && tA == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
!iszero(stride(x, 1))
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
return y
else
Expand All @@ -535,7 +537,9 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
!iszero(stride(x, 1))
xfl = reinterpret(reshape, T, x) # Use reshape here.
yfl = reinterpret(reshape, T, y)
BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :])
Expand Down
9 changes: 9 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,15 @@ end
end
end

@testset "matrix x vector with negtive lda or 0 stride" for T in (Float32, Float64)
for TA in (T, complex(T)), TB in (T, complex(T))
A = view(randn(TA, 10, 10), 1:10, 10:-1:1) # negative lda
v = view([randn(TB)], 1 .+ 0(1:10)) # 0 stride
Ad, vd = copy(A), copy(v)
@test Ad * vd A * vd Ad * v A * v
end
end

@testset "issue #15286" begin
A = reshape(map(Float64, 1:20), 5, 4)
C = zeros(8, 8)
Expand Down

0 comments on commit 6f7daea

Please sign in to comment.