diff --git a/src/linalg.jl b/src/linalg.jl
index f315300f..9e44de17 100644
--- a/src/linalg.jl
+++ b/src/linalg.jl
@@ -343,28 +343,32 @@ function dot(A::AbstractSparseMatrixCSC{T1,S1},B::AbstractSparseMatrixCSC{T2,S2}
     return r
 end
 
-function dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector)
+# Three-argument `dot` for a sparse CSC matrix `A` between vectors `x` and `y`:
+# accumulates `dot(x[row], val, y[col])` over every stored entry `val` of `A`.
+# Throws `DimensionMismatch` unless `length(x)` and `length(y)` match `size(A)`.
+# The accumulator starts from the three-argument `dot` of the element-type zeros,
+# which is also the value returned when `A` has no rows or no columns.
+function dot(x::AbstractVector{T1}, A::AbstractSparseMatrixCSC{T2}, y::AbstractVector{T3}) where {T1,T2,T3}
     require_one_based_indexing(x, y)
     m, n = size(A)
     (length(x) == m && n == length(y)) || throw(DimensionMismatch())
-    if iszero(m) || iszero(n)
-        return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
-    end
-    T = promote_type(eltype(x), eltype(A), eltype(y))
-    r = zero(T)
-    rvals = getrowval(A)
+    s = dot(zero(T1), zero(T2), zero(T3))
+    (iszero(m) || iszero(n)) && return s
+
+    rowvals = getrowval(A)
     nzvals = getnzval(A)
+
+    # No `@simd` here: this loop wraps another loop, and `@simd` is documented
+    # to require an innermost loop.
     @inbounds for col in 1:n
         ycol = y[col]
-        if _isnotzero(ycol)
-            temp = zero(T)
-            for k in nzrange(A, col)
-                temp += adjoint(x[rvals[k]]) * nzvals[k]
-            end
-            r += temp * ycol
+        for j in nzrange(A, col)
+            row = rowvals[j]
+            val = nzvals[j]
+            s += dot(x[row], val, ycol)
         end
     end
-    return r
+    return s
 end
 function dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector)
     m, n = size(A)