Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sparse vector views remain sparse #416

Merged
merged 32 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
6b0353d
sparse views behaviour
matbesancon Jul 21, 2023
db3a59f
fix function
matbesancon Jul 21, 2023
1bb83cf
supertype on SparseVectorView
matbesancon Jul 21, 2023
11b6546
fix fun call
matbesancon Jul 21, 2023
3db1e8b
fix subtyping
matbesancon Jul 21, 2023
41461ae
fix interpolation
matbesancon Jul 21, 2023
1ae7634
test type
matbesancon Jul 21, 2023
643cc7c
fix view
matbesancon Jul 21, 2023
6d9e835
dropzeros! to keep behavior
matbesancon Jul 22, 2023
b9c9324
indexing
matbesancon Jul 22, 2023
eb1d53a
fix behavior
matbesancon Jul 22, 2023
aa87a65
more tests
matbesancon Jul 22, 2023
7af3b32
formatted tests, fixed indexing when empty
matbesancon Jul 22, 2023
5136cc0
fix indexing
matbesancon Jul 22, 2023
4dd4a6a
fixed last index
matbesancon Jul 22, 2023
796f058
check axes
matbesancon Jul 24, 2023
ac68812
throw on mismatch
matbesancon Jul 26, 2023
c8f2c46
Merge branch 'main' into sparse-views
SobhanMP Aug 2, 2023
0566470
bikeshedding
matbesancon Aug 5, 2023
896b37b
fix view
matbesancon Aug 5, 2023
db4b77a
added test on empty vec
matbesancon Aug 5, 2023
9e26345
cat: ensure vararg is more inferrable
vtjnash Aug 1, 2023
c2acf22
fix inference of SparseVector cat
vtjnash Aug 1, 2023
ed41a51
test: restore ambiguous test
vtjnash Aug 1, 2023
96b9126
Test suite: activate a temp project if we need to install Aqua.jl dur…
DilumAluthge Aug 23, 2023
8a792e5
Respect `IOContext` while displaying a `SparseMatrixCSC` (#423)
jishnub Aug 24, 2023
39548e6
Cleanup reloaded (#426)
matbesancon Aug 24, 2023
acbfaf8
Fix docs conflict when building as part of full Julia docs (#430)
ViralBShah Aug 26, 2023
df73d70
faster cat performance (#432)
vtjnash Sep 1, 2023
df1d2cb
Merge branch 'main' into sparse-views
dkarrasch Sep 9, 2023
71c7174
add symtridiagonal tests
dkarrasch Sep 9, 2023
536d9fe
fix nonzeroinds for empty views
dkarrasch Sep 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ const SparseVectorView{Tv,Ti} = SubArray{Tv,1,<:AbstractSparseVector{Tv,Ti},Tup
const SparseVectorUnion{Tv,Ti} = Union{AbstractCompressedVector{Tv,Ti}, SparseColumnView{Tv,Ti}, SparseVectorView{Tv,Ti}}
const AdjOrTransSparseVectorUnion{Tv,Ti} = LinearAlgebra.AdjOrTrans{Tv, <:SparseVectorUnion{Tv,Ti}}

# allows for views of a subset of the sparse vector indices
const SparseVectorPartialView{Tv,Ti} = SubArray{Tv,1,<:AbstractSparseVector{Tv,Ti},<:Tuple{AbstractUnitRange},false}

### Basic properties

length(x::SparseVector) = getfield(x, :n)
Expand All @@ -114,6 +117,7 @@ function nnz(x::SparseColumnView)
return length(nzrange(parent(x), colidx))
end
nnz(x::SparseVectorView) = nnz(x.parent)
nnz(x::SparseVectorPartialView) = length(nonzeroinds(x))

"""
nzrange(x::SparseVectorUnion, col)
Expand All @@ -135,6 +139,12 @@ function nonzeros(x::SparseColumnView)
end
nonzeros(x::SparseVectorView) = nonzeros(parent(x))

function nonzeros(x::SparseVectorPartialView)
(first_idx, last_idx) = _partialview_end_indices(x)
nzvals = nonzeros(parent(x))
return view(nzvals, first_idx:last_idx)
end

nonzeroinds(x::SparseVector) = getfield(x, :nzind)
nonzeroinds(x::FixedSparseVector) = getfield(x, :nzind)
function nonzeroinds(x::SparseColumnView)
Expand All @@ -145,10 +155,37 @@ function nonzeroinds(x::SparseColumnView)
end
nonzeroinds(x::SparseVectorView) = nonzeroinds(parent(x))

# return the first and last nonzero indices of the parent that belong to the view
# return end+1:end if no nonzero in the parent
function _partialview_end_indices(x::SparseVectorPartialView)
p = parent(x)
nzinds = nonzeroinds(p)
if isempty(nzinds)
last_idx = length(nzinds)
first_idx = last_idx + 1
else
first_idx = findfirst(>=(x.indices[1][begin]), nzinds)
last_idx = findlast(<=(x.indices[1][end]), nzinds)
# empty view
if first_idx === nothing || last_idx === nothing
last_idx = length(nzinds)
first_idx = last_idx+1
end
end
return (first_idx, last_idx)
end

function nonzeroinds(x::SparseVectorPartialView)
isempty(x.indices[1]) && return indtype(parent(x))[]
(first_idx, last_idx) = _partialview_end_indices(x)
nzinds = nonzeroinds(parent(x))
return @view(nzinds[first_idx:last_idx]) .- (x.indices[1][begin] - 1)
end

rowvals(x::SparseVectorUnion) = nonzeroinds(x)

indtype(x::SparseColumnView) = indtype(parent(x))
indtype(x::SparseVectorView) = indtype(parent(x))
indtype(x::Union{SparseVectorView, SparseVectorPartialView}) = indtype(parent(x))


function Base.sizehint!(v::SparseVector, newlen::Integer)
Expand Down Expand Up @@ -1573,6 +1610,22 @@ for (fun, mode) in [(:+, 1), (:-, 1), (:*, 0), (:min, 2), (:max, 2)]
end
end

for fun in (:+, :-)
@eval @propagate_inbounds function $(fun)(x::Union{SparseVectorUnion{Tx},SparseVectorPartialView{Tx}}, y::Union{SparseVectorUnion{Ty},SparseVectorPartialView{Ty}}) where {Tx, Ty}
@boundscheck axes(x) == axes(y) || throw(DimensionMismatch("$(axes(x)), $(axes(y))"))
T = promote_type(Tx, Ty)
res = spzeros(T, length(x))
copyto!(res, x)
nzinds = nonzeroinds(y)
nzvals = nonzeros(y)
@inbounds for nzidx in eachindex(nzinds)
res[nzinds[nzidx]] = $fun(res[nzinds[nzidx]], nzvals[nzidx])
end
dropzeros!(res)
return res
end
end

### Reduction
Base.reducedim_initarray(A::SparseVectorUnion, region, v0, ::Type{R}) where {R} =
fill!(Array{R}(undef, Base.to_shape(Base.reduced_indices(A, region))), v0)
Expand Down
13 changes: 13 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,19 @@ end
end
end

@testset "Adding sparse-backed SymTridiagonal (#46355)" begin
a = SymTridiagonal(sparsevec(Int[1]), sparsevec(Int[]))
@test a + a == Matrix(a) + Matrix(a)

# symtridiagonal with non-empty off-diagonal
b = SymTridiagonal(sparsevec(Int[1, 2, 3]), sparsevec(Int[1, 2]))
@test b + b == Matrix(b) + Matrix(b)

# a symtridiagonal with an additional off-diagonal element
c = SymTridiagonal(sparsevec(Int[1, 2, 3]), sparsevec(Int[1, 2, 3]))
@test c + c == Matrix(c) + Matrix(c)
end

@testset "kronecker product" begin
for (m,n) in ((5,10), (13,8))
a = sprand(m, 5, 0.4); a_d = Matrix(a)
Expand Down
29 changes: 29 additions & 0 deletions test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,35 @@ spv_x2 = SparseVector(8, [1, 2, 6, 7], [3.25, 4.0, -5.5, -6.0])
SparseVector(8, Int[1, 2, 6], Float64[3.25, 4.0, 3.5]))
@test exact_equal(min.(x, x2),
SparseVector(8, Int[2, 5, 6, 7], Float64[1.25, -0.75, -5.5, -6.0]))

end

test_vectors = [
(spv_x1, spv_x2),
(SparseVector(5, [1], [1.0]), SparseVector(5, [1], [3.0])),
(SparseVector(5, [2], [1.0]), SparseVector(5, [1], [3.0])),
(SparseVector(5, [1], [1.0]), SparseVector(5, [2], [3.0])),
(SparseVector(5, [3], [2.0]), SparseVector(5, [4], [3.0])),
(SparseVector(5, Int[], Float64[]), SparseVector(5, [4], [3.0])),
]
@testset "View operations $((collect(xa), collect(xb))), op $op" for (xa, xb) in test_vectors, op in (-, +)
r1 = op(@view(xa[1:end]), @view(xb[1:end]))
@test r1 == op(xa, xb)
@test r1 isa SparseVector
r2 = op(@view(xa[2:end-1]), @view(xb[2:end-1]))
@test r2 == op(xa, xb)[2:end-1]
@test r2 isa SparseVector
r3 = op(@view(xb[1:end]), big.(xa))
@test r3 == big.(op(xb, xa))
@test r3 isa SparseVector
# empty views
r4 = op(@view(xa[1:2]), @view(xb[1:2]))
@test r4 == op(xa, xb)[1:2]
@test r4 == @view(op(xa, xb)[1:2])
@test r4 isa SparseVector
r5 = op(@view(xa[end-1:end]), @view(xb[end-1:end]))
@test r5 == op(xa, xb)[end-1:end]
@test r5 isa SparseVector
end

### Complex
Expand Down