-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sparse vector/matrix: add fast implementation of find_next and find_prev (fixed) #23317
Changes from 9 commits
d45df8d
14f443a
92558be
4b54020
5dde4af
daee267
1ac4141
fe4b76e
132ff27
85bc773
a33abbe
db62ae4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,3 +28,24 @@ function Base.reinterpret(::Type, A::AbstractSparseArray) | |
Try reinterpreting the value itself instead. | ||
""") | ||
end | ||
|
||
# The following two methods should be overloaded by concrete types to avoid | ||
# allocating the I = find(...) | ||
_sparse_findnextnz(v::AbstractSparseArray, i) = (I = find(!iszero, v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : 0) | ||
_sparse_findprevnz(v::AbstractSparseArray, i) = (I = find(!iszero, v); n = searchsortedlast(I, i); n>0 ? I[n] : 0) | ||
|
||
function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Int) | ||
j = _sparse_findnextnz(v, i) | ||
while j != 0 && !f(v[j]) | ||
j = _sparse_findnextnz(v, j+1) | ||
end | ||
return j | ||
end | ||
|
||
function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Int) | ||
j = _sparse_findprevnz(v, i) | ||
while j != 0 && !f(v[j]) | ||
j = _sparse_findprevnz(v, j-1) | ||
end | ||
return j | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly? :) function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Int)
j = _sparse_findprevnz(v, i)
while j != 0 && iszero(v[j])
j = _sparse_findprevnz(v, j-1)
end
return j
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -737,6 +737,24 @@ function findnz(x::SparseVector{Tv,Ti}) where {Tv,Ti} | |
return (I, V) | ||
end | ||
|
||
function _sparse_findnextnz(v::SparseVector, i::Int) | ||
n = searchsortedfirst(v.nzind, i) | ||
if n > length(v.nzind) | ||
return 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For type stability, |
||
else | ||
return v.nzind[n] | ||
end | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A compact alternative: function _sparse_findnextnz(v::SparseVector, i::Int)
n = searchsortedfirst(v.nzind, i)
return n <= length(v.nzind) ? v.nzind[n] : 0
end |
||
|
||
function _sparse_findprevnz(v::SparseVector, i::Int) | ||
n = searchsortedlast(v.nzind, i) | ||
if n < 1 | ||
return 0 | ||
else | ||
return v.nzind[n] | ||
end | ||
end | ||
|
||
### Generic functions operating on AbstractSparseVector | ||
|
||
### getindex | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2148,3 +2148,34 @@ end | |
# count should run only over S.nzval[1:nnz(S)], not S.nzval in full | ||
@test count(SparseMatrixCSC(2, 2, Int[1, 2, 3], Int[1, 2], Bool[true, true, true])) == 2 | ||
end | ||
|
||
@testset "sparse findprev/findnext operations" begin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be good to include some sparse test arrays with stored zeros |
||
|
||
x = [0,0,0,0,1,0,1,0,1,1,0] | ||
x_sp = sparse(x) | ||
|
||
for i=1:length(x) | ||
@test findnext(!iszero, x,i) == findnext(!iszero, x_sp,i) | ||
@test findprev(!iszero, x,i) == findprev(!iszero, x_sp,i) | ||
end | ||
|
||
y = [0 0 0 0 0; | ||
1 0 1 0 0; | ||
1 0 0 0 1; | ||
0 0 1 0 0; | ||
1 0 1 1 0] | ||
y_sp = sparse(y) | ||
|
||
for i=1:length(y) | ||
@test findnext(!iszero, y,i) == findnext(!iszero, y_sp,i) | ||
@test findprev(!iszero, y,i) == findprev(!iszero, y_sp,i) | ||
end | ||
|
||
z_sp = sparsevec(Dict(1=>1, 5=>1, 8=>0, 10=>1)) | ||
z = collect(z_sp) | ||
|
||
for i=1:length(z) | ||
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i) | ||
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i) | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps the following? :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! In the interest of expediency, feel free to commit+merge that if you have time.