Skip to content

Commit

Permalink
Improved the dot product between two vectors and a sparse matrix (#410)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
albertomercurio and dkarrasch authored Sep 9, 2023
1 parent 2fae1a1 commit c93065c
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,28 +343,26 @@ function dot(A::AbstractSparseMatrixCSC{T1,S1},B::AbstractSparseMatrixCSC{T2,S2}
return r
end

function dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector)
function dot(x::AbstractVector{T1}, A::AbstractSparseMatrixCSC{T2}, y::AbstractVector{T3}) where {T1,T2,T3}
require_one_based_indexing(x, y)
m, n = size(A)
(length(x) == m && n == length(y)) || throw(DimensionMismatch())
if iszero(m) || iszero(n)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
T = promote_type(eltype(x), eltype(A), eltype(y))
r = zero(T)
rvals = getrowval(A)
s = dot(zero(T1), zero(T2), zero(T3))
T = typeof(s)
(iszero(m) || iszero(n)) && return s

rowvals = getrowval(A)
nzvals = getnzval(A)
@inbounds for col in 1:n

@inbounds @simd for col in 1:n
ycol = y[col]
if _isnotzero(ycol)
temp = zero(T)
for k in nzrange(A, col)
temp += adjoint(x[rvals[k]]) * nzvals[k]
end
r += temp * ycol
for j in nzrange(A, col)
row = rowvals[j]
val = nzvals[j]
s += dot(x[row], val, ycol)
end
end
return r
return s
end
function dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector)
m, n = size(A)
Expand Down

0 comments on commit c93065c

Please sign in to comment.