Skip to content

Commit

Permalink
Fix the type instability causing slowdown and extra memory
Browse files Browse the repository at this point in the history
allocation in sparse vcat. Fixes #7926.
  • Loading branch information
ViralBShah committed Feb 15, 2015
1 parent 07f3ee7 commit 8bd5e5e
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1884,48 +1884,61 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
end



# Sparse concatenation

function vcat(X::SparseMatrixCSC...)
Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

myvcat(map(x->convert(SparseMatrixCSC{Tv,Ti}, x), X)...)
end

function vcat{Tv,Ti<:Integer}(X::SparseMatrixCSC{Tv,Ti}...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
m = sum(mX)
n = nX[1]

for i = 2 : num
if nX[i] != n; throw(DimensionMismatch("")); end
if nX[i] != n
throw(DimensionMismatch("All inputs to vcat should have the same number of columns"))
end
end
m = sum(mX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

colptr = Array(Ti, n + 1)
nnzX = [ nnz(x) for x in X ]
nnz_res = sum(nnzX)
colptr = Array(Ti, n + 1)
rowval = Array(Ti, nnz_res)
nzval = Array(Tv, nnz_res)
nzval = Array(Tv, nnz_res)

colptr[1] = 1
@inbounds for c = 1 : n
for c = 1:n
mX_sofar = 0
rr1 = colptr[c]
ptr_res = colptr[c]
for i = 1 : num
XI = X[i]
rX1 = XI.colptr[c]
rX2 = XI.colptr[c + 1] - 1
rr2 = rr1 + (rX2 - rX1)
Xi = X[i]
colptrXi = Xi.colptr
rowvalXi = Xi.rowval
nzvalXi = Xi.nzval

col_length = (colptrXi[c + 1] - 1) - colptrXi[c]
ptrXi = colptrXi[c]
for k=ptr_res:(ptr_res + col_length)
@inbounds rowval[k] = rowvalXi[ptrXi] + mX_sofar
@inbounds nzval[k] = nzvalXi[ptrXi]
ptrXi += 1
end

rowval[rr1 : rr2] = XI.rowval[rX1 : rX2] .+ mX_sofar
nzval[rr1 : rr2] = XI.nzval[rX1 : rX2]
ptr_res += col_length + 1
mX_sofar += mX[i]
rr1 = rr2 + 1
end
colptr[c + 1] = rr1
colptr[c + 1] = ptr_res
end
SparseMatrixCSC(m, n, colptr, rowval, nzval)
end


function hcat(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
Expand Down

0 comments on commit 8bd5e5e

Please sign in to comment.