Skip to content

Commit

Permalink
Improved dense-sparse matrix multiplication kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sacha0 committed Oct 8, 2017
1 parent fce2d68 commit cf44d90
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 10 deletions.
130 changes: 120 additions & 10 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

import Base.LinAlg: checksquare
using Base: @propagate_inbounds

## Functions to switch to 0-based indexing to call external sparse solvers

Expand Down Expand Up @@ -41,8 +42,10 @@ function sppromote(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) whe
A, B
end

# In matrix-vector multiplication, the correct orientation of the vector is assumed.

### sparse-dense matrix multiplication: A[c|t]_mul_B[c|t][!]([dense,] sparse, dense)

# In matrix-vector multiplication, the correct orientation of the vector is assumed.
for (f, op, transp) in ((:A_mul_B, :identity, false),
(:Ac_mul_B, :adjoint, true),
(:At_mul_B, :transpose, true))
Expand Down Expand Up @@ -98,17 +101,124 @@ Ac_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = Ac_mul_B
At_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = At_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C)


function (*)(X::StridedMatrix{TX}, A::SparseMatrixCSC{TvA,TiA}) where {TX,TvA,TiA}
mX, nX = size(X)
nX == A.m || throw(DimensionMismatch())
Y = zeros(promote_type(TX,TvA), mX, A.n)
rowval = A.rowval
nzval = A.nzval
@inbounds for multivec_row=1:mX, col = 1:A.n, k=A.colptr[col]:(A.colptr[col+1]-1)
Y[multivec_row, col] += X[multivec_row, rowval[k]] * nzval[k]
### dense-sparse matrix multiplication: A[c|t]_mul_B[c|t][!]([dense,] dense, sparse)

# */A_mul_B[!]([dense,] dense, sparse)
function *(A::StridedMatrix, B::SparseMatrixCSC)
@boundscheck size(A, 2) == size(B, 1) || throw(DimensionMismatch())
C = zeros(promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2))
return unchecked_A_mul_B!(C, A, B)
end
function A_mul_B!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC)
@boundscheck size(A, 2) == size(B, 1) || throw(DimensionMismatch())
@boundscheck size(C, 1) == size(A, 1) || throw(DimensionMismatch())
@boundscheck size(C, 2) == size(B, 2) || throw(DimensionMismatch())
return unchecked_A_mul_B!(C, A, B)
end
function unchecked_A_mul_B!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC)
mA = size(A, 1)
nB = size(B, 2)
fill!(C, zero(eltype(C)))
@inbounds for jB in 1:nB
for kB in B.colptr[jB]:(B.colptr[jB+1] - 1)
iB = B.rowval[kB]
xB = B.nzval[kB]
@simd for iA in 1:mA
C[iA, jB] = muladd(A[iA, iB], xB, C[iA, jB])
end
end
end
return C
end

# A_mul_B(c|t)[!]([dense,] dense, sparse)
@propagate_inbounds A_mul_Bt(A::StridedMatrix, B::SparseMatrixCSC) = _A_mul_Bq(A, B, identity)
@propagate_inbounds A_mul_Bc(A::StridedMatrix, B::SparseMatrixCSC) = _A_mul_Bq(A, B, conj)
function _A_mul_Bq(A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
@boundscheck size(A, 2) == size(B, 2) || throw(DimensionMismatch())
C = zeros(promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 1))
return unchecked_A_mul_Bq!(C, A, B, op)
end
@propagate_inbounds A_mul_Bt!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC) = _A_mul_Bq!(C, A, B, identity)
@propagate_inbounds A_mul_Bc!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC) = _A_mul_Bq!(C, A, B, conj)
function _A_mul_Bq!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
@boundscheck size(A, 2) == size(B, 2) || throw(DimensionMismatch())
@boundscheck size(C, 1) == size(A, 1) || throw(DimensionMismatch())
@boundscheck size(C, 2) == size(B, 1) || throw(DimensionMismatch())
return unchecked_A_mul_Bq!(C, A, B, op)
end
function unchecked_A_mul_Bq!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
mA = size(A, 1)
nB = size(B, 2)
fill!(C, zero(eltype(C)))
@inbounds for jB in 1:nB
for kB in B.colptr[jB]:(B.colptr[jB+1] - 1)
iB = B.rowval[kB]
qxB = op(B.nzval[kB])
@simd for iA in 1:mA
C[iA, iB] = muladd(A[iA, jB], qxB, C[iA, iB])
end
end
end
Y
return C
end

# A(c|t)_mul_B[!]([dense,] dense, sparse)
@propagate_inbounds At_mul_B(A::StridedMatrix, B::SparseMatrixCSC) = _Aq_mul_B(A, B, identity)
@propagate_inbounds Ac_mul_B(A::StridedMatrix, B::SparseMatrixCSC) = _Aq_mul_B(A, B, conj)
function _Aq_mul_B(A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
@boundscheck size(A, 1) == size(B, 1) || throw(DimensionMismatch())
C = zeros(promote_type(eltype(A), eltype(B)), size(A, 2), size(B, 2))
return unchecked_Aq_mul_B!(C, A, B, op)
end
@propagate_inbounds At_mul_B!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC) = _Aq_mul_B!(C, A, B, identity)
@propagate_inbounds Ac_mul_B!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC) = _Aq_mul_B!(C, A, B, conj)
function _Aq_mul_B!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
@boundscheck size(A, 1) == size(B, 1) || throw(DimensionMismatch())
@boundscheck size(C, 1) == size(A, 2) || throw(DimensionMismatch())
@boundscheck size(C, 2) == size(B, 2) || throw(DimensionMismatch())
return unchecked_Aq_mul_B!(C, A, B, op)
end
function unchecked_Aq_mul_B!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
# Without some additional storage into which to reorder data prior to performing multiplication,
# memory access patterns for this operation aren't particularly happy. For now, perform the
# operation in two steps --- an explicit adjoint/transpose followed by an efficient multiplication.
# For the future, test the various possible access patterns to determine whether any best this
# approach all around, and if so replace this implementation.
return twostep_Aq_mul_B!(C, A, B, op)
end
twostep_Aq_mul_B!(C, A, B, ::typeof(identity)) = unchecked_A_mul_B!(C, transpose(A), B)
twostep_Aq_mul_B!(C, A, B, ::typeof(conj)) = unchecked_A_mul_B!(C, adjoint(A), B)

# A(t|c)_mul_B(t|c)[!]([dense,] dense, sparse)
@propagate_inbounds At_mul_Bt(A::StridedMatrix, B::SparseMatrixCSC) = _Aq_mul_Bq(A, B, identity)
@propagate_inbounds Ac_mul_Bc(A::StridedMatrix, B::SparseMatrixCSC) = _Aq_mul_Bq(A, B, conj)
function _Aq_mul_Bq(A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
@boundscheck size(A, 1) == size(B, 2) || throw(DimensionMismatch())
C = zeros(promote_type(eltype(A), eltype(B)), size(A, 2), size(B, 1))
return unchecked_Aq_mul_Bq!(C, A, B, op)
end
@propagate_inbounds At_mul_Bt!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC) = _Aq_mul_Bq!(C, A, B, identity)
@propagate_inbounds Ac_mul_Bc!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC) = _Aq_mul_Bq!(C, A, B, conj)
function _Aq_mul_Bq!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
@boundscheck size(A, 1) == size(B, 2) || throw(DimensionMismatch())
@boundscheck size(C, 1) == size(A, 2) || throw(DimensionMismatch())
@boundscheck size(C, 2) == size(B, 1) || throw(DimensionMismatch())
return unchecked_Aq_mul_Bq!(C, A, B, op)
end
function unchecked_Aq_mul_Bq!(C::StridedMatrix, A::StridedMatrix, B::SparseMatrixCSC, op::TF) where TF
# Without some additional storage into which to reorder data prior to performing multiplication,
# memory access patterns for this operation aren't particularly happy. For now, perform the
# operation in two steps --- a relatively efficient multiplication in reverse order
# followed by a transposition. For the future, test the various possible access patterns
# to determine whether any best this approach all around, and if so replace this implementation.
return twostep_Aq_mul_Bq!(C, A, B, op)
end
twostep_Aq_mul_Bq!(C, A, B, ::typeof(identity)) = transpose!(C, *(B, A))
twostep_Aq_mul_Bq!(C, A, B, ::typeof(conj)) = adjoint!(C, *(B, A))


### other mixed sparse-?/?-sparse matrix multiplication

function (*)(D::Diagonal, A::SparseMatrixCSC)
T = Base.promote_op(*, eltype(D), eltype(A))
Expand Down
52 changes: 52 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2041,3 +2041,55 @@ end
@test isfinite.(cov_sparse) == isfinite.(cov_dense)
end
end

@testset "dense-sparse matrix multiplication" begin
using Base.LinAlg: *, A_mul_B!,
A_mul_Bt, A_mul_Bt!, A_mul_Bc, A_mul_Bc!,
At_mul_B, At_mul_B!, Ac_mul_B, Ac_mul_B!,
At_mul_Bt, At_mul_Bt!, Ac_mul_Bc, Ac_mul_Bc!
# out-of-place dense-sparse ops, i.e. A[t|c]_mul_B[t|c](dense, sparse)
#
# exercise kernels, which are shared with corresponding in-place ops
for (m, k, n) in ((5, 5, 5), (5, 10, 15), (15, 10, 5))
densemat = rand(Complex{Float64}, m, k)
sparsemat = sprand(Complex{Float64}, k, n, 0.4)
tdensemat = transpose(densemat)
tsparsemat = transpose(sparsemat)
@test *(densemat, sparsemat) *(densemat, Matrix(sparsemat))
@test A_mul_Bt(densemat, tsparsemat) A_mul_Bt(densemat, Matrix(tsparsemat))
@test A_mul_Bc(densemat, tsparsemat) A_mul_Bc(densemat, Matrix(tsparsemat))
@test At_mul_B(tdensemat, sparsemat) At_mul_B(tdensemat, Matrix(sparsemat))
@test Ac_mul_B(tdensemat, sparsemat) Ac_mul_B(tdensemat, Matrix(sparsemat))
@test At_mul_Bt(tdensemat, tsparsemat) At_mul_Bt(tdensemat, Matrix(tsparsemat))
end
# exercise inner-dimensions-match checks
n, x = 3, 4
Cnn, Cxn, Cnx = zeros(n, n), zeros(x, n), zeros(n, x)
Ann, Axn, Anx = zeros(n, n), zeros(x, n), spzeros(n, x)
Snn, Sxn, Snx = spzeros(n, n), spzeros(x, n), spzeros(n, x)
@test_throws DimensionMismatch (*)(Ann, Sxn)
@test_throws DimensionMismatch A_mul_Bt(Ann, Snx)
@test_throws DimensionMismatch A_mul_Bc(Ann, Snx)
@test_throws DimensionMismatch At_mul_B(Axn, Snn)
@test_throws DimensionMismatch Ac_mul_B(Axn, Snn)
@test_throws DimensionMismatch At_mul_Bt(Ann, Snx)
@test_throws DimensionMismatch Ac_mul_Bc(Ann, Snx)

# in-place dense-sparse ops, i.e. A[t|c]_mul_B[t|c]!(dense, dense, sparse)
# the kernels were exercised through the out-of-place calls above,
# so below exercise only the entry points (shape checks)
#
# exercise matmul outer-dimensions-match checks
for op! in (A_mul_B!, A_mul_Bt!, A_mul_Bc!, At_mul_B!, Ac_mul_B!, At_mul_Bt!, Ac_mul_Bc!)
@test_throws DimensionMismatch op!(Cxn, Ann, Snn)
@test_throws DimensionMismatch op!(Cnx, Ann, Snn)
end
# exercise matul inner-dimensions-match checks
@test_throws DimensionMismatch A_mul_B!(Cnn, Ann, Sxn)
@test_throws DimensionMismatch A_mul_Bt!(Cnn, Ann, Snx)
@test_throws DimensionMismatch A_mul_Bc!(Cnn, Ann, Snx)
@test_throws DimensionMismatch At_mul_B!(Cnn, Axn, Snn)
@test_throws DimensionMismatch Ac_mul_B!(Cnn, Axn, Snn)
@test_throws DimensionMismatch At_mul_Bt!(Cnn, Ann, Snx)
@test_throws DimensionMismatch Ac_mul_Bc!(Cnn, Ann, Snx)
end

0 comments on commit cf44d90

Please sign in to comment.