-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from timholy/teh/implementation
Partial implemention of general API and SparseMatrixCSC iterators
- Loading branch information
Showing
6 changed files
with
202 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,6 @@ os: | |
- linux | ||
- osx | ||
julia: | ||
- release | ||
- nightly | ||
notifications: | ||
email: false | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,66 @@ | ||
module ArrayIterationPlayground | ||
|
||
# package code goes here | ||
using Base: ViewIndex | ||
import Base: getindex, setindex!, start, next, done, eachindex | ||
|
||
export inds, index, stored, each | ||
|
||
# General API | ||
|
||
inds(A::AbstractArray, d) = 1:size(A, d) | ||
inds{T,N}(A::AbstractArray{T,N}) = ntuple(d->inds(A,d), Val{N}) | ||
|
||
immutable ValueIterator{I} | ||
iter::I | ||
end | ||
start(iter::ValueIterator) = start(iter.iter) | ||
done(iter::ValueIterator, s) = done(iter.iter, s) | ||
next(iter::ValueIterator, s) = ((item, s) = next(iter.iter, s); (value(iter.iter, item), s)) | ||
|
||
eachindex(x...) = each(index(x...)) | ||
|
||
# isindex == true => want the indexes (keys) of the array | ||
# isindex == false => want the values of the array | ||
# isstored == true => visit only stored entries | ||
# isstored == false => visit all indexes | ||
immutable ArrayIndexingWrapper{A, I<:Tuple{Vararg{ViewIndex}}, isindex, isstored} | ||
data::A | ||
indexes::I | ||
end | ||
|
||
index{A,I,isindex,isstored}(w::ArrayIndexingWrapper{A,I,isindex,isstored}) = ArrayIndexingWrapper{A,I,true,isstored}(w.data, w.indexes) | ||
stored{A,I,isindex,isstored}(w::ArrayIndexingWrapper{A,I,isindex,isstored}) = ArrayIndexingWrapper{A,I,isindex,true}(w.data, w.indexes) | ||
|
||
allindexes{T,N}(A::AbstractArray{T,N}) = ntuple(d->Colon(),Val{N}) | ||
|
||
index(A::AbstractArray) = index(A, allindexes(A)) | ||
index(A::AbstractArray, I::ViewIndex...) = index(A, I) | ||
index{T,N}(A::AbstractArray{T,N}, indexes::NTuple{N,ViewIndex}) = ArrayIndexingWrapper{typeof(A),typeof(indexes),true,false}(A, indexes) | ||
|
||
stored(A::AbstractArray) = stored(A, allindexes(A)) | ||
stored(A::AbstractArray, I::ViewIndex...) = stored(A, I) | ||
stored{T,N}(A::AbstractArray{T,N}, indexes::NTuple{N,ViewIndex}) = ArrayIndexingWrapper{typeof(A),typeof(indexes),false,true}(A, indexes) | ||
|
||
each(A::AbstractArray, indexes...) = ValueIterator(each(index(A, indexes))) | ||
|
||
immutable SyncedIterator{I,F<:Tuple{Vararg{Function}}} | ||
iter::I | ||
itemfuns::F | ||
end | ||
|
||
start(iter::SyncedIterator) = start(iter.iter) | ||
next(iter::SyncedIterator, state) = mapf(iter.itemfuns, state), next(iter.iter, state) | ||
done(iter::SyncedIterator, state) = done(iter.iter, state) | ||
|
||
""" | ||
`mapf(fs, x)` is similar to `map`, except instead of mapping one | ||
function over many objects, it maps many functions over one | ||
object. `fs` should be a tuple-of-functions. | ||
""" | ||
@inline mapf(fs::Tuple, x) = _mapf((), x, fs...) | ||
_mapf(out, x) = out | ||
@inline _mapf(out, x, f, fs...) = _mapf((out..., f(x)), x, fs...) | ||
|
||
include("sparse.jl") | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
### Sparse-array iterators | ||
|
||
## SparseMatrixCSC | ||
|
||
typealias SubSparseMatrixCSC{I,T,N,P<:SparseMatrixCSC} SubArray{T,N,P,I,false} | ||
typealias ContiguousCSC{I<:Tuple{Union{Colon,UnitRange{Int}},Any},T,N,P<:SparseMatrixCSC} Union{P,SubSparseMatrixCSC{I,T,N,P}} | ||
|
||
indextype{Tv,Ti}(::Type{SparseMatrixCSC{Tv,Ti}}) = Ti | ||
indextype(A::SparseMatrixCSC) = indextype(typeof(A)) | ||
|
||
# Indexing along a particular column | ||
immutable ColIndexCSC | ||
row::Int # where you are currently (might not be stored) | ||
stored::Bool # true if this represents a stored value | ||
cscindex::Int # for stored value, the index into rowval & nzval | ||
end | ||
|
||
@inline getindex(A::SparseMatrixCSC, i::ColIndexCSC, j::Integer) = (@inbounds ret = i.stored ? A.nzval[i.cscindex] : zero(eltype(A)); ret) | ||
@inline getindex(A::SubSparseMatrixCSC, i::ColIndexCSC, j::Integer) = A.parent[i, j] | ||
# @inline function getindex(a::AbstractVector, i::ColIndexCSC) | ||
# @boundscheck 1 <= i.rowval <= length(a) | ||
# @inbounds ret = a[i.rowval] | ||
# ret | ||
# end | ||
|
||
@inline setindex!(A::SparseMatrixCSC, val, i::ColIndexCSC, j::Integer) = (@inbounds A.nzval[i.cscindex] = val; val) | ||
@inline setindex!(A::SubSparseMatrixCSC, val, i::ColIndexCSC, j::Integer) = A.parent[i,j] = val | ||
# @inline function setindex!(a::AbstractVector, val, i::ColIndexCSC) | ||
# @boundscheck 1 <= i.rowval <= length(a) || throw(BoundsError(a, i.rowval)) | ||
# @inbounds a[i.rowval] = val | ||
# val | ||
# end | ||
|
||
immutable ColIteratorCSC{isstored,S<:ContiguousCSC} | ||
A::S | ||
col::Int | ||
cscrange::UnitRange{Int} | ||
|
||
function ColIteratorCSC(A::SparseMatrixCSC, ::Colon, col::Integer) | ||
@boundscheck 1 <= col <= size(A, 2) || throw(BoundsError(A, (:,col))) | ||
@inbounds r = A.colptr[col]:A.colptr[col+1]-1 | ||
new(A, col, r) | ||
end | ||
function ColIteratorCSC{I<:Tuple{Colon,Any}}(A::SubSparseMatrixCSC{I}, ::Colon, col::Integer) | ||
@boundscheck 1 <= col <= size(A, 2) || throw(BoundsError(A, (:,col))) | ||
@inbounds j = A.indexes[2][col] | ||
@inbounds r = A.parent.colptr[j]:A.parent.colptr[j+1]-1 | ||
new(A, col, r) | ||
end | ||
function ColIteratorCSC{I<:Tuple{UnitRange{Int},Any}}(A::SubSparseMatrixCSC{I}, ::Colon, col::Integer) | ||
@boundscheck 1 <= col <= size(A, 2) || throw(BoundsError(A, (:,col))) | ||
@inbounds j = A.indexes[2][col] | ||
@inbounds r1, r2 = Int(A.parent.colptr[j]), Int(A.parent.colptr[j+1]-1) | ||
rowval = A.parent.rowval | ||
i = A.indexes[1] | ||
r1 = searchsortedfirst(rowval, first(i), r1, r2, Forward) | ||
r1 <= r2 && (r2 = searchsortedlast(rowval, last(i), r1, r2, Forward)) | ||
new(A, col, r1:r2) | ||
end | ||
function ColIteratorCSC(A::SparseMatrixCSC, i::UnitRange, col::Integer) | ||
@boundscheck 1 <= col <= size(A, 2) || throw(BoundsError(A, (i,col))) | ||
@boundscheck (1 <= first(i) && last(i) <= size(A, 1)) || throw(BoundsError(A, (i,col))) | ||
@inbounds r1, r2 = Int(A.parent.colptr[j]), Int(A.parent.colptr[j+1]-1) | ||
rowval = A.parent.rowval | ||
r1 = searchsortedfirst(rowval, first(i), r1, r2, Forward) | ||
r1 <= r2 && (r2 = searchsortedlast(rowval, last(i), r1, r2, Forward)) | ||
new(A, col, r1:r2) | ||
end | ||
end | ||
# Default is to visit each site, not just the stored sites | ||
ColIteratorCSC(A::ContiguousCSC, i, col::Integer) = ColIteratorCSC{false,typeof(A)}(A, i, col) | ||
# Choose with ColIteratorCSC{true/false}(A, col) | ||
(::Type{ColIteratorCSC{E}}){E}(A::ContiguousCSC, i, col::Integer) = ColIteratorCSC{E,typeof(A)}(A, i, col) | ||
|
||
# Iteration when you're visiting every entry | ||
# The iterator state has the following structure: | ||
# (row::Int, nextrowval::Ti<:Integer, cscindex::Int) | ||
# nextrowval = A.rowval[cscindex], but we cache it in the state to | ||
# avoid looking it up each time. We use it to decide when the cscindex | ||
# needs to be incremented. | ||
length(iter::ColIteratorCSC{false}) = size(iter.A, 1) | ||
function start(iter::ColIteratorCSC{false}) | ||
cscindex = start(iter.cscrange) | ||
nextrowval = _nextrowval(iter, cscindex) | ||
(1, nextrowval, cscindex) | ||
end | ||
done(iter::ColIteratorCSC{false}, s) = s[1] > size(iter.A, 1) | ||
function next{S<:SparseMatrixCSC}(iter::ColIteratorCSC{false,S}, s) | ||
row, nextrowval, cscindex = s | ||
item = ColIndexCSC(row, row==nextrowval, cscindex) | ||
item.stored ? (item, (row+1, _nextrowval(iter, cscindex+1), cscindex+1)) : | ||
(item, (row+1, nextrowval, cscindex)) | ||
end | ||
_nextrowval(iter::ColIteratorCSC, cscindex) = cscindex <= last(iter.cscrange) ? iter.A.rowval[cscindex] : convert(indextype(iter.A), size(iter.A, 1)+1) | ||
|
||
length(iter::ColIteratorCSC{true}) = length(iter.cscrange) | ||
start(iter::ColIteratorCSC{true}) = start(iter.cscrange) | ||
done(iter::ColIteratorCSC{true}, s) = done(iter.cscrange, s) | ||
next{S<:SparseMatrixCSC}(iter::ColIteratorCSC{true,S}, s) = (@inbounds row = iter.A.rowval[s]; idx = ColIndexCSC(row, true, s); (idx, s+1)) | ||
next{S<:SubSparseMatrixCSC}(iter::ColIteratorCSC{true,S}, s) = (@inbounds row = iter.A.parent.rowval[s]; idx = ColIndexCSC(row, true, s); (idx, s+1)) | ||
|
||
value(iter::ColIteratorCSC, i) = iter.A[i, iter.col] | ||
|
||
# nextstored{S<:SparseMatrixCSC}(iter::ColIteratorCSC{S}, s, index::Integer) = | ||
|
||
each{A<:SparseMatrixCSC,I}(w::ArrayIndexingWrapper{A,I,true,false}) = ColIteratorCSC{false}(w.data, w.indexes...) | ||
each{A<:SparseMatrixCSC,I}(w::ArrayIndexingWrapper{A,I,true,true}) = ColIteratorCSC{true}(w.data, w.indexes...) | ||
each{A<:SparseMatrixCSC,I}(w::ArrayIndexingWrapper{A,I,false}) = ValueIterator(ColIteratorCSC{true}(w.data, w.indexes...)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
using ArrayIterationPlayground | ||
using Base.Test | ||
|
||
# write your own tests here | ||
@test 1 == 1 | ||
include("sparse.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
A = sparse([1,4,3],[1,1,2],[0.2,0.4,0.6]) | ||
Af = full(A) | ||
|
||
k = 0 | ||
for j = 1:2 | ||
for i in eachindex(stored(A, :, j)) | ||
@test A[i,j] == A.nzval[k+=1] | ||
end | ||
end | ||
|
||
k = 0 | ||
for j = 1:2 | ||
for v in each(stored(A, :, j)) | ||
@test v == A.nzval[k+=1] | ||
end | ||
end | ||
|
||
k = 0 | ||
for j = 1:2 | ||
for i in each(index(A, :, j)) | ||
@test A[i,j] == Af[k+=1] | ||
end | ||
end | ||
|
||
k = 0 | ||
for j = 1:2 | ||
for v in each(A, :, j) | ||
@test v == Af[k+=1] | ||
end | ||
end |