Skip to content

Commit

Permalink
Merge pull request #2 from timholy/teh/dense
Browse files Browse the repository at this point in the history
Generic fallback for each (supports dense arrays)
  • Loading branch information
timholy committed Apr 18, 2016
2 parents 37af491 + c2ee440 commit cd31dd3
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 24 deletions.
41 changes: 32 additions & 9 deletions src/ArrayIterationPlayground.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module ArrayIterationPlayground

using Base: ViewIndex
import Base: getindex, setindex!, start, next, done, eachindex
import Base: getindex, setindex!, start, next, done, length, eachindex

export inds, index, stored, each

Expand All @@ -10,13 +10,6 @@ export inds, index, stored, each
inds(A::AbstractArray, d) = 1:size(A, d)
inds{T,N}(A::AbstractArray{T,N}) = ntuple(d->inds(A,d), Val{N})

immutable ValueIterator{I}
iter::I
end
start(iter::ValueIterator) = start(iter.iter)
done(iter::ValueIterator, s) = done(iter.iter, s)
next(iter::ValueIterator, s) = ((item, s) = next(iter.iter, s); (value(iter.iter, item), s))

eachindex(x...) = each(index(x...))

# isindex == true => want the indexes (keys) of the array
Expand All @@ -41,7 +34,37 @@ stored(A::AbstractArray) = stored(A, allindexes(A))
stored(A::AbstractArray, I::ViewIndex...) = stored(A, I)
stored{T,N}(A::AbstractArray{T,N}, indexes::NTuple{N,ViewIndex}) = ArrayIndexingWrapper{typeof(A),typeof(indexes),false,true}(A, indexes)

each(A::AbstractArray, indexes...) = ValueIterator(each(index(A, indexes)))
"""
`each(obj)`
`each(obj, indexes...)`
`each` instantiates the iterator associated with `obj`. In conjunction
with `index` and `stored`, you may choose to iterate over either
indexes or values, and o ver all elements or just the stored elements.
"""
each(A::AbstractArray) = each(A, allindexes(A))
each(A::AbstractArray, indexes::ViewIndex...) = each(A, indexes)
each{T,N}(A::AbstractArray{T,N}, indexes::NTuple{N,ViewIndex}) = each(ArrayIndexingWrapper{typeof(A),typeof(indexes),false,false}(A, indexes))

# Internal type for storing instantiated index iterators but returning
# array values
immutable ValueIterator{A<:AbstractArray,I}
data::A
iter::I
end

each{A,I,stored}(W::ArrayIndexingWrapper{A,I,false,stored}) = (itr = each(index(W)); ValueIterator{A,typeof(itr)}(W.data, itr))
each{A,I}(W::ArrayIndexingWrapper{A,I,true}) = CartesianRange(ranges(W))

start(vi::ValueIterator) = start(vi.iter)
done(vi::ValueIterator, s) = done(vi.iter, s)
next(vi::ValueIterator, s) = ((idx, s) = next(vi.iter, s); (vi.data[idx], s))

ranges(W) = ranges((), W.data, 1, W.indexes...)
ranges(out, A, d) = out
@inline ranges(out, A, d, i, I...) = ranges((out..., i), A, d+1, I...)
@inline ranges(out, A, d, i::Colon, I...) = ranges((out..., inds(A, d)), A, d+1, I...)


immutable SyncedIterator{I,F<:Tuple{Vararg{Function}}}
iter::I
Expand Down
12 changes: 5 additions & 7 deletions src/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ immutable ColIndexCSC
cscindex::Int # for stored value, the index into rowval & nzval
end

@inline getindex(A::SparseMatrixCSC, i::ColIndexCSC, j::Integer) = (@inbounds ret = i.stored ? A.nzval[i.cscindex] : zero(eltype(A)); ret)
@inline getindex(A::SubSparseMatrixCSC, i::ColIndexCSC, j::Integer) = A.parent[i, j]
@inline getindex(A::SparseMatrixCSC, i::ColIndexCSC) = (@inbounds ret = i.stored ? A.nzval[i.cscindex] : zero(eltype(A)); ret)
@inline getindex(A::SubSparseMatrixCSC, i::ColIndexCSC) = A.parent[i]
# @inline function getindex(a::AbstractVector, i::ColIndexCSC)
# @boundscheck 1 <= i.rowval <= length(a)
# @inbounds ret = a[i.rowval]
# ret
# end

@inline setindex!(A::SparseMatrixCSC, val, i::ColIndexCSC, j::Integer) = (@inbounds A.nzval[i.cscindex] = val; val)
@inline setindex!(A::SubSparseMatrixCSC, val, i::ColIndexCSC, j::Integer) = A.parent[i,j] = val
@inline setindex!(A::SparseMatrixCSC, val, i::ColIndexCSC) = (@inbounds A.nzval[i.cscindex] = val; val)
@inline setindex!(A::SubSparseMatrixCSC, val, i::ColIndexCSC) = A.parent[i] = val
# @inline function setindex!(a::AbstractVector, val, i::ColIndexCSC)
# @boundscheck 1 <= i.rowval <= length(a) || throw(BoundsError(a, i.rowval))
# @inbounds a[i.rowval] = val
Expand Down Expand Up @@ -99,10 +99,8 @@ done(iter::ColIteratorCSC{true}, s) = done(iter.cscrange, s)
next{S<:SparseMatrixCSC}(iter::ColIteratorCSC{true,S}, s) = (@inbounds row = iter.A.rowval[s]; idx = ColIndexCSC(row, true, s); (idx, s+1))
next{S<:SubSparseMatrixCSC}(iter::ColIteratorCSC{true,S}, s) = (@inbounds row = iter.A.parent.rowval[s]; idx = ColIndexCSC(row, true, s); (idx, s+1))

value(iter::ColIteratorCSC, i) = iter.A[i, iter.col]

# nextstored{S<:SparseMatrixCSC}(iter::ColIteratorCSC{S}, s, index::Integer) =

each{A<:SparseMatrixCSC,I}(w::ArrayIndexingWrapper{A,I,true,false}) = ColIteratorCSC{false}(w.data, w.indexes...)
each{A<:SparseMatrixCSC,I}(w::ArrayIndexingWrapper{A,I,true,true}) = ColIteratorCSC{true}(w.data, w.indexes...)
each{A<:SparseMatrixCSC,I}(w::ArrayIndexingWrapper{A,I,false}) = ValueIterator(ColIteratorCSC{true}(w.data, w.indexes...))
each{A<:SparseMatrixCSC,I,isstored}(w::ArrayIndexingWrapper{A,I,false,isstored}) = ValueIterator(w.data, ColIteratorCSC{isstored}(w.data, w.indexes...))
34 changes: 34 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
A = [1 5 -5;
0 3 2]

@test each(index(A)) == CartesianRange((1:2, 1:3))
@test each(index(A, :, 1:2)) == CartesianRange((1:2, 1:2))
@test each(index(A, :, 2:3)) == CartesianRange((1:2, 2:3))

k = 0
for j in inds(A, 2)
for v in each(A, :, j)
@test v == A[k+=1]
end
end

k = 0
for j in inds(A, 2)
for I in each(index(A, :, j))
@test A[I] == A[k+=1]
end
end

k = 0
for j in inds(A, 2)
for I in eachindex(stored(A, :, j))
@test A[I] == A[k+=1]
end
end

k = 0
for j in inds(A, 2)
for v in each(stored(A, :, j))
@test v == A[k+=1]
end
end
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
using ArrayIterationPlayground
using Base.Test

A = zeros(2,3)
@test inds(A, 1) == 1:2
@test inds(A, 2) == 1:3
@test inds(A, 3) == 1:1
@test inds(A) == (1:2, 1:3)

include("dense.jl")
include("sparse.jl")
16 changes: 8 additions & 8 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,28 @@ A = sparse([1,4,3],[1,1,2],[0.2,0.4,0.6])
Af = full(A)

k = 0
for j = 1:2
for i in eachindex(stored(A, :, j))
@test A[i,j] == A.nzval[k+=1]
for j = inds(A, 2)
for I in eachindex(stored(A, :, j))
@test A[I] == A.nzval[k+=1]
end
end

k = 0
for j = 1:2
for j = inds(A, 2)
for v in each(stored(A, :, j))
@test v == A.nzval[k+=1]
end
end

k = 0
for j = 1:2
for i in each(index(A, :, j))
@test A[i,j] == Af[k+=1]
for j = inds(A, 2)
for I in each(index(A, :, j))
@test A[I] == Af[k+=1]
end
end

k = 0
for j = 1:2
for j = inds(A, 2)
for v in each(A, :, j)
@test v == Af[k+=1]
end
Expand Down

0 comments on commit cd31dd3

Please sign in to comment.