Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse vector/matrix: add fast implementation of find_next and find_prev (fixed) #23317

Merged
merged 12 commits into from
Jan 6, 2018
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions base/sparse/abstractsparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,24 @@ function Base.reinterpret(::Type, A::AbstractSparseArray)
Try reinterpreting the value itself instead.
""")
end

# The following two methods should be overloaded by concrete types to avoid
# allocating the I = find(...)
_sparse_findnextnz(v::AbstractSparseArray, i) = (I = find(!iszero, v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : 0)
_sparse_findprevnz(v::AbstractSparseArray, i) = (I = find(!iszero, v); n = searchsortedlast(I, i); n>0 ? I[n] : 0)

function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Int)
j = _sparse_findnextnz(v, i)
while j != 0 && !f(v[j])
j = _sparse_findnextnz(v, j+1)
end
return j
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps the following? :)

function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Int)
    j = _sparse_findnextnz(v, i)
    while j != 0 && iszero(v[j])
        j = _sparse_findnextnz(v, j+1)
    end
    return j
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! In the interest of expediency, feel free to commit+merge that if you have time.


function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Int)
j = _sparse_findprevnz(v, i)
while j != 0 && !f(v[j])
j = _sparse_findprevnz(v, j-1)
end
return j
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly? :)

function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Int)
    j = _sparse_findprevnz(v, i)
    while j != 0 && iszero(v[j])
        j = _sparse_findprevnz(v, j-1)
    end
    return j
end

6 changes: 3 additions & 3 deletions base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import Base.LinAlg: At_ldiv_B!, Ac_ldiv_B!, A_rdiv_B!, A_rdiv_Bc!
import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
atan, atand, atanh, broadcast!, chol, conj!, cos, cosc, cosd, cosh, cospi, cot,
cotd, coth, count, csc, cscd, csch, adjoint!, diag, diff, done, dot, eig,
exp10, exp2, eye, findn, floor, hash, indmin, inv, issymmetric, istril, istriu,
log10, log2, lu, next, sec, secd, sech, show, sin,
sinc, sind, sinh, sinpi, squeeze, start, sum, summary, tan,
exp10, exp2, eye, findn, findprev, findnext, floor, hash, indmin, inv,
issymmetric, istril, istriu, log10, log2, lu, next, sec, secd, sech, show,
sin, sinc, sind, sinh, sinpi, squeeze, start, sum, summary, tan,
tand, tanh, trace, transpose!, tril!, triu!, trunc, vecnorm, abs, abs2,
broadcast, ceil, complex, cond, conj, convert, copy, copy!, adjoint, diagm,
exp, expm1, factorize, find, findmax, findmin, findnz, float, getindex,
Expand Down
36 changes: 36 additions & 0 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,42 @@ function findnz(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
return (I, J, V)
end

function _sparse_findnextnz(m::SparseMatrixCSC, i::Int)
if i > length(m)
return 0
end
row, col = ind2sub(m, i)
lo, hi = m.colptr[col], m.colptr[col+1]
n = searchsortedfirst(m.rowval, row, lo, hi-1, Base.Order.Forward)
if lo <= n <= hi-1
return sub2ind(m, m.rowval[n], col)
end
nextcol = findnext(c->(c>hi), m.colptr, col+1)
if nextcol == 0
return 0
end
nextlo = m.colptr[nextcol-1]
return sub2ind(m, m.rowval[nextlo], nextcol-1)
end

function _sparse_findprevnz(m::SparseMatrixCSC, i::Int)
if i < 1
return 0
end
row, col = ind2sub(m, i)
lo, hi = m.colptr[col], m.colptr[col+1]
n = searchsortedlast(m.rowval, row, lo, hi-1, Base.Order.Forward)
if lo <= n <= hi-1
return sub2ind(m, m.rowval[n], col)
end
prevcol = findprev(c->(c<lo), m.colptr, col-1)
if prevcol == 0
return 0
end
prevhi = m.colptr[prevcol+1]
return sub2ind(m, m.rowval[prevhi-1], prevcol)
end

import Base.Random.GLOBAL_RNG
function sprand_IJ(r::AbstractRNG, m::Integer, n::Integer, density::AbstractFloat)
((m < 0) || (n < 0)) && throw(ArgumentError("invalid Array dimensions"))
Expand Down
18 changes: 18 additions & 0 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,24 @@ function findnz(x::SparseVector{Tv,Ti}) where {Tv,Ti}
return (I, V)
end

function _sparse_findnextnz(v::SparseVector, i::Int)
n = searchsortedfirst(v.nzind, i)
if n > length(v.nzind)
return 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For type stability, return zero(indtype(v))? (Likewise below.)

else
return v.nzind[n]
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A compact alternative:

function _sparse_findnextnz(v::SparseVector, i::Int)
    n = searchsortedfirst(v.nzind, i)
    return n <= length(v.nzind) ? v.nzind[n] : 0
end


function _sparse_findprevnz(v::SparseVector, i::Int)
n = searchsortedlast(v.nzind, i)
if n < 1
return 0
else
return v.nzind[n]
end
end

### Generic functions operating on AbstractSparseVector

### getindex
Expand Down
31 changes: 31 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2148,3 +2148,34 @@ end
# count should run only over S.nzval[1:nnz(S)], not S.nzval in full
@test count(SparseMatrixCSC(2, 2, Int[1, 2, 3], Int[1, 2], Bool[true, true, true])) == 2
end

@testset "sparse findprev/findnext operations" begin
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be good to include some sparse test arrays with stored zeros


x = [0,0,0,0,1,0,1,0,1,1,0]
x_sp = sparse(x)

for i=1:length(x)
@test findnext(!iszero, x,i) == findnext(!iszero, x_sp,i)
@test findprev(!iszero, x,i) == findprev(!iszero, x_sp,i)
end

y = [0 0 0 0 0;
1 0 1 0 0;
1 0 0 0 1;
0 0 1 0 0;
1 0 1 1 0]
y_sp = sparse(y)

for i=1:length(y)
@test findnext(!iszero, y,i) == findnext(!iszero, y_sp,i)
@test findprev(!iszero, y,i) == findprev(!iszero, y_sp,i)
end

z_sp = sparsevec(Dict(1=>1, 5=>1, 8=>0, 10=>1))
z = collect(z_sp)

for i=1:length(z)
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
end
end