diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 73dc5ebaf7fad..85f1ea7c8db14 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1883,17 +1883,9 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto A end - # Sparse concatenation function vcat(X::SparseMatrixCSC...) - Tv = promote_type(map(x->eltype(x.nzval), X)...) - Ti = promote_type(map(x->eltype(x.rowval), X)...) - - vcat(map(x->convert(SparseMatrixCSC{Tv,Ti}, x), X)...) -end - -function vcat{Tv,Ti<:Integer}(X::SparseMatrixCSC{Tv,Ti}...) num = length(X) mX = [ size(x, 1) for x in X ] nX = [ size(x, 2) for x in X ] @@ -1906,6 +1898,13 @@ function vcat{Tv,Ti<:Integer}(X::SparseMatrixCSC{Tv,Ti}...) end end + Tv = eltype(X[1].nzval) + Ti = eltype(X[1].rowval) + for i = 2:length(X) + Tv = promote_type(Tv, eltype(X[i].nzval)) + Ti = promote_type(Ti, eltype(X[i].rowval)) + end + nnzX = [ nnz(x) for x in X ] nnz_res = sum(nnzX) colptr = Array(Ti, n + 1) @@ -1917,18 +1916,12 @@ function vcat{Tv,Ti<:Integer}(X::SparseMatrixCSC{Tv,Ti}...) mX_sofar = 0 ptr_res = colptr[c] for i = 1 : num - Xi = X[i] - colptrXi = Xi.colptr - rowvalXi = Xi.rowval - nzvalXi = Xi.nzval - + colptrXi = X[i].colptr col_length = (colptrXi[c + 1] - 1) - colptrXi[c] - ptrXi = colptrXi[c] - for k=ptr_res:(ptr_res + col_length) - @inbounds rowval[k] = rowvalXi[ptrXi] + mX_sofar - @inbounds nzval[k] = nzvalXi[ptrXi] - ptrXi += 1 - end + ptr_Xi = colptrXi[c] + + stuffcol!(X[i], colptr, rowval, nzval, + ptr_res, ptr_Xi, col_length, mX_sofar) ptr_res += col_length + 1 mX_sofar += mX[i] @@ -1938,6 +1931,19 @@ function vcat{Tv,Ti<:Integer}(X::SparseMatrixCSC{Tv,Ti}...) SparseMatrixCSC(m, n, colptr, rowval, nzval) end +@inline function stuffcol!(Xi::SparseMatrixCSC, colptr, rowval, nzval, + ptr_res, ptr_Xi, col_length, mX_sofar) + colptrXi = Xi.colptr + rowvalXi = Xi.rowval + nzvalXi = Xi.nzval + + for k=ptr_res:(ptr_res + col_length) + @inbounds rowval[k] = rowvalXi[ptr_Xi] + mX_sofar + @inbounds nzval[k] = nzvalXi[ptr_Xi] + ptr_Xi += 1 + end +end + function hcat(X::SparseMatrixCSC...) num = length(X)