Skip to content

Commit

Permalink
Support each iteration over whole sparse array
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed May 15, 2016
1 parent ee82348 commit 6736ec2
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 7 deletions.
93 changes: 88 additions & 5 deletions src/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ end
# val
# end

immutable IteratorCSC{isstored,S<:ContiguousCSC}
A::S
cscrange::UnitRange{Int}

IteratorCSC(A::SparseMatrixCSC, ::Colon, ::Colon) = new(A, 1:length(A.nzval))
end
IteratorCSC(A::ContiguousCSC, ::Colon, j) = IteratorCSC{false,typeof(A)}(A, Colon(), j)
(::Type{IteratorCSC{E}}){E}(A::ContiguousCSC, ::Colon, j) = IteratorCSC{E,typeof(A)}(A, Colon(), j)

immutable ColIteratorCSC{isstored,S<:ContiguousCSC}
A::S
col::Int
Expand Down Expand Up @@ -80,34 +89,108 @@ ColIteratorCSC(A::ContiguousCSC, i, col::Integer) = ColIteratorCSC{false,typeof(
(::Type{ColIteratorCSC{E}}){E}(A::ContiguousCSC, i, col::Integer) = ColIteratorCSC{E,typeof(A)}(A, i, col)

# Iteration when you're visiting every entry
# The iterator state has the following structure:
# The column iterator state has the following structure:
# (row::Int, nextrowval::Ti<:Integer, cscindex::Int)
# nextrowval = A.rowval[cscindex], but we cache it in the state to
# avoid looking it up each time. We use it to decide when the cscindex
# needs to be incremented.
# The full-matrix iterator is similar, except it adds the column:
# (row::Int, col::Int, nextrowval::Ti<:Integer, nextcolval::Ti, cscindex::Int)
length(iter::IteratorCSC{false}) = length(iter.A)
length(iter::ColIteratorCSC{false}) = size(iter.A, 1)
function start(iter::IteratorCSC{false})
cscindex = start(iter.cscrange)
nextrow, nextcol = _nextrowcolval(iter, 0, cscindex)
(1, 1, nextrow, nextcol, cscindex)
end
function start(iter::ColIteratorCSC{false})
cscindex = start(iter.cscrange)
nextrowval = _nextrowval(iter, cscindex)
(1, nextrowval, cscindex)
end
done(iter::IteratorCSC{false}, s) = s[2] > size(iter.A, 2)
done(iter::ColIteratorCSC{false}, s) = s[1] > size(iter.A, 1)
function next{S<:SparseMatrixCSC}(iter::IteratorCSC{false,S}, s)
row, col, nextrowval, nextcolval, cscindex = s
item = IndexCSC(row, col, row==nextrowval && col==nextcolval, cscindex)
newrow = row+1
newcol = col
if newrow > size(iter.A, 1)
newrow = 1
newcol += 1
end
if item.stored
nrv, ncv = _nextrowcolval(iter, col, cscindex+1)
return (item, (newrow, newcol, nrv, ncv, cscindex+1))
end
return (item, (newrow, newcol, nextrowval, nextcolval, cscindex))
end
function next{S<:SparseMatrixCSC}(iter::ColIteratorCSC{false,S}, s)
row, nextrowval, cscindex = s
item = IndexCSC(row, iter.col, row==nextrowval, cscindex)
item.stored ? (item, (row+1, _nextrowval(iter, cscindex+1), cscindex+1)) :
(item, (row+1, nextrowval, cscindex))
end
_nextrowval(iter::ColIteratorCSC, cscindex) = cscindex <= last(iter.cscrange) ? iter.A.rowval[cscindex] : convert(indextype(iter.A), size(iter.A, 1)+1)
function _nextrowval(iter::ColIteratorCSC, cscindex)
if cscindex <= last(iter.cscrange)
return iter.A.rowval[cscindex]
end
convert(indextype(iter.A), size(iter.A, 1)+1) # out-of-bounds fallback
end
function _nextrowcolval(iter::IteratorCSC, col, cscindex)
if cscindex <= last(iter.cscrange)
nextcol = col
nextcscindex = iter.A.colptr[col+1]
if cscindex >= nextcscindex
nextcol = findnext(j->j!=nextcscindex, iter.A.colptr, col+1)-1
end
return (iter.A.rowval[cscindex], nextcol)
end
# out-of-bounds fallback
convert(indextype(iter.A), size(iter.A, 1)+1), size(iter.A, 2)+1
end


# Iteration when you're visting just the stored entries
# We use similar caching tricks with nextcol and nextcolptrindex for IteratorCSC
length(iter::IteratorCSC{true}) = length(iter.cscrange)
length(iter::ColIteratorCSC{true}) = length(iter.cscrange)
function start(iter::IteratorCSC{true})
nextcol = findfirst(j->j!=1, iter.A.colptr)-1
nextcolptrindex = iter.A.colptr[nextcol+1]
(nextcol, nextcolptrindex, start(iter.cscrange))
end
start(iter::ColIteratorCSC{true}) = start(iter.cscrange)
done(iter::IteratorCSC{true}, s) = done(iter.cscrange, s[3])
done(iter::ColIteratorCSC{true}, s) = done(iter.cscrange, s)
next{S<:SparseMatrixCSC}(iter::ColIteratorCSC{true,S}, s) = (@inbounds row = iter.A.rowval[s]; idx = IndexCSC(row, iter.col, true, s); (idx, s+1))
next{S<:SubSparseMatrixCSC}(iter::ColIteratorCSC{true,S}, s) = (@inbounds row = iter.A.parent.rowval[s]; idx = IndexCSC(row, iter.col, true, s); (idx, s+1))
function next{S<:SparseMatrixCSC}(iter::IteratorCSC{true,S}, s)
@inbounds begin
col, nextcolptrindex, cscindex = s
row = iter.A.rowval[cscindex]
end
nextcol = col
if s == nextcolptrindex
tmp = nextcolptrindex # work around julia #15276
nextcol = findnext(j->j!=tmp, iter.A.colptr, col)-1
nextcolptrindex = iter.A.colptr[nextcol+1]
end
idx = IndexCSC(row, col, true, cscindex)
(idx, (nextcol, nextcolptrindex, cscindex+1))
end
function next{S<:SparseMatrixCSC}(iter::ColIteratorCSC{true,S}, s)
@inbounds row = iter.A.rowval[s]
idx = IndexCSC(row, iter.col, true, s)
(idx, s+1)
end
function next{S<:SubSparseMatrixCSC}(iter::ColIteratorCSC{true,S}, s)
@inbounds row = iter.A.parent.rowval[s]
idx = IndexCSC(row, iter.col, true, s)
(idx, s+1)
end

# nextstored{S<:SparseMatrixCSC}(iter::ColIteratorCSC{S}, s, index::Integer) =

each{A<:SparseMatrixCSC,N,isstored}(w::ArrayIndexingWrapper{A,NTuple{N,Colon},true,isstored}) = ColIteratorCSC{isstored}(w.data, w.indexes...) # ambig.
each{A<:SparseMatrixCSC,N,isstored}(w::ArrayIndexingWrapper{A,NTuple{N,Colon},true,isstored}) = IteratorCSC{isstored}(w.data, w.indexes...)
each{A<:SparseMatrixCSC,I,isstored}(w::ArrayIndexingWrapper{A,I,true,isstored}) = ColIteratorCSC{isstored}(w.data, w.indexes...)
each{A<:SparseMatrixCSC,N,isstored}(w::ArrayIndexingWrapper{A,NTuple{N,Colon},false,isstored}) = ValueIterator(w.data, IteratorCSC{isstored}(w.data, w.indexes...))
each{A<:SparseMatrixCSC,I,isstored}(w::ArrayIndexingWrapper{A,I,false,isstored}) = ValueIterator(w.data, ColIteratorCSC{isstored}(w.data, w.indexes...))
24 changes: 22 additions & 2 deletions test/sparse.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,47 @@
A = sparse([1,4,3],[1,1,2],[0.2,0.4,0.6])
A = sparse([2,4,3],[2,2,4],[0.2,0.4,0.6])
Af = full(A)

k = 0
for I in eachindex(stored(A))
@test A[I] == A.nzval[k+=1]
end

k = 0
for j = inds(A, 2)
for I in eachindex(stored(A, :, j))
@test A[I] == A.nzval[k+=1]
end
end

k = 0
for v in each(stored(A))
@test v == A.nzval[k+=1]
end

k = 0
for j = inds(A, 2)
for v in each(stored(A, :, j))
@test v == A.nzval[k+=1]
end
end

k = 0
for I in each(index(A))
@test A[I] == Af[k+=1]
end

k = 0
for j = inds(A, 2)
for I in each(index(A, :, j))
@test A[I] == Af[k+=1]
end
end

k = 0
for v in each(A)
@test v == Af[k+=1]
end

k = 0
for j = inds(A, 2)
for v in each(A, :, j)
Expand Down Expand Up @@ -53,7 +73,7 @@ function matvecmul_val!(b::AbstractVector, A::AbstractMatrix, x::AbstractVector)
b
end

x = [1,-5]
x = [1,-5,7,-13]
btrue = A*x
b = similar(btrue)
matvecmul_ind!(b, A, x)
Expand Down

0 comments on commit 6736ec2

Please sign in to comment.