Skip to content

Commit

Permalink
Merge pull request #10206 from JuliaLang/vs/spcat
Browse files Browse the repository at this point in the history
Improve sparse vcat
  • Loading branch information
ViralBShah committed Feb 16, 2015
2 parents 9ef1720 + 11b9872 commit 491c9b6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
55 changes: 37 additions & 18 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1883,49 +1883,68 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
A
end



# Sparse concatenation

function vcat(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
m = sum(mX)
n = nX[1]

for i = 2 : num
if nX[i] != n; throw(DimensionMismatch("")); end
if nX[i] != n
throw(DimensionMismatch("All inputs to vcat should have the same number of columns"))
end
end
m = sum(mX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)
Tv = eltype(X[1].nzval)
Ti = eltype(X[1].rowval)
for i = 2:length(X)
Tv = promote_type(Tv, eltype(X[i].nzval))
Ti = promote_type(Ti, eltype(X[i].rowval))
end

colptr = Array(Ti, n + 1)
nnzX = [ nnz(x) for x in X ]
nnz_res = sum(nnzX)
colptr = Array(Ti, n + 1)
rowval = Array(Ti, nnz_res)
nzval = Array(Tv, nnz_res)
nzval = Array(Tv, nnz_res)

colptr[1] = 1
@inbounds for c = 1 : n
for c = 1:n
mX_sofar = 0
rr1 = colptr[c]
ptr_res = colptr[c]
for i = 1 : num
XI = X[i]
rX1 = XI.colptr[c]
rX2 = XI.colptr[c + 1] - 1
rr2 = rr1 + (rX2 - rX1)
colptrXi = X[i].colptr
col_length = (colptrXi[c + 1] - 1) - colptrXi[c]
ptr_Xi = colptrXi[c]

rowval[rr1 : rr2] = XI.rowval[rX1 : rX2] .+ mX_sofar
nzval[rr1 : rr2] = XI.nzval[rX1 : rX2]
stuffcol!(X[i], colptr, rowval, nzval,
ptr_res, ptr_Xi, col_length, mX_sofar)

ptr_res += col_length + 1
mX_sofar += mX[i]
rr1 = rr2 + 1
end
colptr[c + 1] = rr1
colptr[c + 1] = ptr_res
end
SparseMatrixCSC(m, n, colptr, rowval, nzval)
end

@inline function stuffcol!(Xi::SparseMatrixCSC, colptr, rowval, nzval,
ptr_res, ptr_Xi, col_length, mX_sofar)
colptrXi = Xi.colptr
rowvalXi = Xi.rowval
nzvalXi = Xi.nzval

for k=ptr_res:(ptr_res + col_length)
@inbounds rowval[k] = rowvalXi[ptr_Xi] + mX_sofar
@inbounds nzval[k] = nzvalXi[ptr_Xi]
ptr_Xi += 1
end
end


function hcat(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
Expand Down
2 changes: 2 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ do33 = ones(3)

# check vert concatenation
@test all([se33; se33] == sparse([1, 4, 2, 5, 3, 6], [1, 1, 2, 2, 3, 3], ones(6)))
se33_32bit = convert(SparseMatrixCSC{Float32,Int32}, se33)
@test all([se33; se33_32bit] == sparse([1, 4, 2, 5, 3, 6], [1, 1, 2, 2, 3, 3], ones(6)))

# check h+v concatenation
se44 = speye(4)
Expand Down

0 comments on commit 491c9b6

Please sign in to comment.