Skip to content

Commit

Permalink
Fix vcat of sparse vectors with numbers (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Sep 13, 2022
1 parent d88be9f commit ead48fe
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
27 changes: 17 additions & 10 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1143,21 +1143,28 @@ const _Annotated_SparseConcatArrays = Union{_Triangular_SparseConcatArrays, _Sym
const _SparseConcatGroup = Union{_DenseConcatGroup, _SparseConcatArrays, _Annotated_SparseConcatArrays}

# Concatenations involving un/annotated sparse/special matrices/vectors should yield sparse arrays

# the output array type is determined by the first element of the to be concatenated objects
# if this is a Number, the output would be dense by the fallback abstractarray.jl code (see cat_similar)
# so make sure that if that happens, the "array" is sparse (if more sparse arrays are involved, of course)
_sparse(x::Number) = sparsevec([1], [x], 1)
_sparse(A) = _makesparse(A)
_makesparse(x::Number) = x
_makesparse(x::AbstractArray) = SparseMatrixCSC(issparse(x) ? x : sparse(x))
_makesparse(x::AbstractVector) = convert(SparseVector, issparse(x) ? x : sparse(x))::SparseVector
_makesparse(x::AbstractMatrix) = convert(SparseMatrixCSC, issparse(x) ? x : sparse(x))::SparseMatrixCSC

# `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference
Base.@constprop :aggressive function Base._cat(dims, Xin::_SparseConcatGroup...)
X = map(_makesparse, Xin)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
T = promote_eltype(Xin...)
return Base._cat_t(dims, T, X...)
end
function hcat(Xin::_SparseConcatGroup...)
X = map(_makesparse, Xin)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
return cat(X..., dims=Val(2))
end
function vcat(Xin::_SparseConcatGroup...)
X = map(_makesparse, Xin)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
return cat(X..., dims=Val(1))
end
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
Expand Down Expand Up @@ -1191,9 +1198,9 @@ Concatenate along dimension 2. Return a SparseMatrixCSC object.
the concatenation with specialized "sparse" matrix types from LinearAlgebra.jl
automatically yielded sparse output even in the absence of any SparseArray argument.
"""
sparse_hcat(Xin::Union{AbstractVecOrMat,Number}...) = cat(map(_makesparse, Xin)..., dims=Val(2))
sparse_hcat(Xin::Union{AbstractVecOrMat,Number}...) = cat(_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))..., dims=Val(2))
function sparse_hcat(X::Union{AbstractVecOrMat,UniformScaling,Number}...)
LinearAlgebra._hcat(X...; array_type = SparseMatrixCSC)
LinearAlgebra._hcat(_sparse(first(X)), map(_makesparse, Base.tail(X))...; array_type = SparseMatrixCSC)
end

"""
Expand All @@ -1206,9 +1213,9 @@ Concatenate along dimension 1. Return a SparseMatrixCSC object.
the concatenation with specialized "sparse" matrix types from LinearAlgebra.jl
automatically yielded sparse output even in the absence of any SparseArray argument.
"""
sparse_vcat(Xin::Union{AbstractVecOrMat,Number}...) = cat(map(_makesparse, Xin)..., dims=Val(1))
sparse_vcat(Xin::Union{AbstractVecOrMat,Number}...) = cat(_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))..., dims=Val(1))
function sparse_vcat(X::Union{AbstractVecOrMat,UniformScaling,Number}...)
LinearAlgebra._vcat(X...; array_type = SparseMatrixCSC)
LinearAlgebra._vcat(_sparse(first(X)), map(_makesparse, Base.tail(X))...; array_type = SparseMatrixCSC)
end

"""
Expand All @@ -1224,10 +1231,10 @@ arguments to concatenate in each block row.
automatically yielded sparse output even in the absence of any SparseArray argument.
"""
function sparse_hvcat(rows::Tuple{Vararg{Int}}, Xin::Union{AbstractVecOrMat,Number}...)
hvcat(rows, map(_makesparse, Xin)...)
hvcat(rows, _sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
end
function sparse_hvcat(rows::Tuple{Vararg{Int}}, X::Union{AbstractVecOrMat,UniformScaling,Number}...)
LinearAlgebra._hvcat(rows, X...; array_type = SparseMatrixCSC)
LinearAlgebra._hvcat(rows, _sparse(first(X)), map(_makesparse, Base.tail(X))...; array_type = SparseMatrixCSC)
end

### math functions
Expand Down
16 changes: 16 additions & 0 deletions test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,22 @@ end
@test length(V) == m * n
Vr = vec(Hr)
@test Array(V) == Vr
Vnum = vcat(A..., zero(Float64))
Vnum2 = sparse_vcat(map(Array, A)..., zero(Float64))
@test Vnum isa SparseVector{Float64,Int}
@test Vnum2 isa SparseVector{Float64,Int}
@test length(Vnum) == length(Vnum2) == m*n + 1
@test Array(Vnum) == Array(Vnum2) == [Vr; 0]
Vnum = vcat(zero(Float64), A...)
Vnum2 = sparse_vcat(zero(Float64), map(Array, A)...)
@test Vnum isa SparseVector{Float64,Int}
@test Vnum2 isa SparseVector{Float64,Int}
@test length(Vnum) == length(Vnum2) == m*n + 1
@test Array(Vnum) == Array(Vnum2) == [0; Vr]
# case with rowwise a Number as first element, should still yield a sparse matrix
x = sparsevec([1], [3.0], 1)
X = [3.0 x; 3.0 x]
@test issparse(X)
end

@testset "concatenation of sparse vectors with other types" begin
Expand Down

0 comments on commit ead48fe

Please sign in to comment.