Skip to content

Commit

Permalink
sparse vector views remain sparse (#416)
Browse files Browse the repository at this point in the history
Co-authored-by: Marco Cognetta <[email protected]>
Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
3 people authored Sep 9, 2023
1 parent c93065c commit ada9edd
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 1 deletion.
55 changes: 54 additions & 1 deletion src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ const SparseVectorView{Tv,Ti} = SubArray{Tv,1,<:AbstractSparseVector{Tv,Ti},Tup
const SparseVectorUnion{Tv,Ti} = Union{AbstractCompressedVector{Tv,Ti}, SparseColumnView{Tv,Ti}, SparseVectorView{Tv,Ti}}
const AdjOrTransSparseVectorUnion{Tv,Ti} = LinearAlgebra.AdjOrTrans{Tv, <:SparseVectorUnion{Tv,Ti}}

# allows for views of a subset of the sparse vector indices
const SparseVectorPartialView{Tv,Ti} = SubArray{Tv,1,<:AbstractSparseVector{Tv,Ti},<:Tuple{AbstractUnitRange},false}

### Basic properties

length(x::SparseVector) = getfield(x, :n)
Expand All @@ -114,6 +117,7 @@ function nnz(x::SparseColumnView)
return length(nzrange(parent(x), colidx))
end
nnz(x::SparseVectorView) = nnz(x.parent)
nnz(x::SparseVectorPartialView) = length(nonzeroinds(x))

"""
nzrange(x::SparseVectorUnion, col)
Expand All @@ -135,6 +139,12 @@ function nonzeros(x::SparseColumnView)
end
nonzeros(x::SparseVectorView) = nonzeros(parent(x))

function nonzeros(x::SparseVectorPartialView)
(first_idx, last_idx) = _partialview_end_indices(x)
nzvals = nonzeros(parent(x))
return view(nzvals, first_idx:last_idx)
end

nonzeroinds(x::SparseVector) = getfield(x, :nzind)
nonzeroinds(x::FixedSparseVector) = getfield(x, :nzind)
function nonzeroinds(x::SparseColumnView)
Expand All @@ -145,10 +155,37 @@ function nonzeroinds(x::SparseColumnView)
end
nonzeroinds(x::SparseVectorView) = nonzeroinds(parent(x))

# return the first and last nonzero indices of the parent that belong to the view
# return end+1:end if no nonzero in the parent
function _partialview_end_indices(x::SparseVectorPartialView)
p = parent(x)
nzinds = nonzeroinds(p)
if isempty(nzinds)
last_idx = length(nzinds)
first_idx = last_idx + 1
else
first_idx = findfirst(>=(x.indices[1][begin]), nzinds)
last_idx = findlast(<=(x.indices[1][end]), nzinds)
# empty view
if first_idx === nothing || last_idx === nothing
last_idx = length(nzinds)
first_idx = last_idx+1
end
end
return (first_idx, last_idx)
end

function nonzeroinds(x::SparseVectorPartialView)
isempty(x.indices[1]) && return indtype(parent(x))[]
(first_idx, last_idx) = _partialview_end_indices(x)
nzinds = nonzeroinds(parent(x))
return @view(nzinds[first_idx:last_idx]) .- (x.indices[1][begin] - 1)
end

rowvals(x::SparseVectorUnion) = nonzeroinds(x)

indtype(x::SparseColumnView) = indtype(parent(x))
indtype(x::SparseVectorView) = indtype(parent(x))
indtype(x::Union{SparseVectorView, SparseVectorPartialView}) = indtype(parent(x))


function Base.sizehint!(v::SparseVector, newlen::Integer)
Expand Down Expand Up @@ -1573,6 +1610,22 @@ for (fun, mode) in [(:+, 1), (:-, 1), (:*, 0), (:min, 2), (:max, 2)]
end
end

for fun in (:+, :-)
@eval @propagate_inbounds function $(fun)(x::Union{SparseVectorUnion{Tx},SparseVectorPartialView{Tx}}, y::Union{SparseVectorUnion{Ty},SparseVectorPartialView{Ty}}) where {Tx, Ty}
@boundscheck axes(x) == axes(y) || throw(DimensionMismatch("$(axes(x)), $(axes(y))"))
T = promote_type(Tx, Ty)
res = spzeros(T, length(x))
copyto!(res, x)
nzinds = nonzeroinds(y)
nzvals = nonzeros(y)
@inbounds for nzidx in eachindex(nzinds)
res[nzinds[nzidx]] = $fun(res[nzinds[nzidx]], nzvals[nzidx])
end
dropzeros!(res)
return res
end
end

### Reduction
Base.reducedim_initarray(A::SparseVectorUnion, region, v0, ::Type{R}) where {R} =
fill!(Array{R}(undef, Base.to_shape(Base.reduced_indices(A, region))), v0)
Expand Down
13 changes: 13 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,19 @@ end
end
end

@testset "Adding sparse-backed SymTridiagonal (#46355)" begin
a = SymTridiagonal(sparsevec(Int[1]), sparsevec(Int[]))
@test a + a == Matrix(a) + Matrix(a)

# symtridiagonal with non-empty off-diagonal
b = SymTridiagonal(sparsevec(Int[1, 2, 3]), sparsevec(Int[1, 2]))
@test b + b == Matrix(b) + Matrix(b)

# a symtridiagonal with an additional off-diagonal element
c = SymTridiagonal(sparsevec(Int[1, 2, 3]), sparsevec(Int[1, 2, 3]))
@test c + c == Matrix(c) + Matrix(c)
end

@testset "kronecker product" begin
for (m,n) in ((5,10), (13,8))
a = sprand(m, 5, 0.4); a_d = Matrix(a)
Expand Down
29 changes: 29 additions & 0 deletions test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,35 @@ spv_x2 = SparseVector(8, [1, 2, 6, 7], [3.25, 4.0, -5.5, -6.0])
SparseVector(8, Int[1, 2, 6], Float64[3.25, 4.0, 3.5]))
@test exact_equal(min.(x, x2),
SparseVector(8, Int[2, 5, 6, 7], Float64[1.25, -0.75, -5.5, -6.0]))

end

test_vectors = [
(spv_x1, spv_x2),
(SparseVector(5, [1], [1.0]), SparseVector(5, [1], [3.0])),
(SparseVector(5, [2], [1.0]), SparseVector(5, [1], [3.0])),
(SparseVector(5, [1], [1.0]), SparseVector(5, [2], [3.0])),
(SparseVector(5, [3], [2.0]), SparseVector(5, [4], [3.0])),
(SparseVector(5, Int[], Float64[]), SparseVector(5, [4], [3.0])),
]
@testset "View operations $((collect(xa), collect(xb))), op $op" for (xa, xb) in test_vectors, op in (-, +)
r1 = op(@view(xa[1:end]), @view(xb[1:end]))
@test r1 == op(xa, xb)
@test r1 isa SparseVector
r2 = op(@view(xa[2:end-1]), @view(xb[2:end-1]))
@test r2 == op(xa, xb)[2:end-1]
@test r2 isa SparseVector
r3 = op(@view(xb[1:end]), big.(xa))
@test r3 == big.(op(xb, xa))
@test r3 isa SparseVector
# empty views
r4 = op(@view(xa[1:2]), @view(xb[1:2]))
@test r4 == op(xa, xb)[1:2]
@test r4 == @view(op(xa, xb)[1:2])
@test r4 isa SparseVector
r5 = op(@view(xa[end-1:end]), @view(xb[end-1:end]))
@test r5 == op(xa, xb)[end-1:end]
@test r5 isa SparseVector
end

### Complex
Expand Down

0 comments on commit ada9edd

Please sign in to comment.