Skip to content

Commit

Permalink
Added sparse*diagonal, diagonal*sparse methods and tests. Fixes #14416.
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt committed Dec 22, 2015
1 parent f0e2ec1 commit 0212afc
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
28 changes: 28 additions & 0 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,34 @@ function (*){TX,TvA,TiA}(X::StridedMatrix{TX}, A::SparseMatrixCSC{TvA,TiA})
Y
end

function (*){TvA,TiA}(X::Diagonal, A::SparseMatrixCSC{TvA,TiA})
mX, nX = size(X)
if nX != A.m
throw(DimensionMismatch("second dimension of X, $nX, must match first dimension of A, $(A.m)"))
end
Ynzval = zeros(promote_type(eltype(X),TvA), length(A.nzval))
rowval = A.rowval
nzval = A.nzval
@inbounds for k=1:length(nzval)
Ynzval[k] = X[rowval[k],rowval[k]] * nzval[k]
end
SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, Ynzval)
end

function (*){TvA,TiA}(A::SparseMatrixCSC{TvA,TiA}, X::Diagonal)
mX, nX = size(X)
if mX != A.n
throw(DimensionMismatch("second dimension of A, $(A.n), must match first dimension of X, $mX"))
end
Ynzval = zeros(promote_type(eltype(X),TvA), length(A.nzval))
rowval = A.rowval
nzval = A.nzval
@inbounds for col = 1:A.n, k=A.colptr[col]:(A.colptr[col+1]-1)
Ynzval[k] += X[col, col] * nzval[k]
end
SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, Ynzval)
end

# Sparse matrix multiplication as described in [Gustavson, 1978]:
# http://dl.acm.org/citation.cfm?id=355796

Expand Down
3 changes: 3 additions & 0 deletions test/sparsedir/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ for i = 1:5
c = sparse(rand(Float32,5,5))
d = sparse(rand(Float64,5,5))
@test full(kron(c,d)) == kron(full(c),full(d))
f = Diagonal(rand(5))
@test full(a*f) == full(a)*f
@test full(f*b) == f*full(b)
end

# scale and scale!
Expand Down

0 comments on commit 0212afc

Please sign in to comment.