Skip to content

Commit

Permalink
First working implementation of sync(stored(...))
Browse files Browse the repository at this point in the history
This is very specific, so needs to be generalized
  • Loading branch information
timholy committed Apr 29, 2016
1 parent 08a148c commit ee82348
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 16 deletions.
4 changes: 3 additions & 1 deletion src/ArrayIteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ module ArrayIteration
import Base: getindex, setindex!, start, next, done, length, eachindex, show, parent, isless
using Base: ReshapedArray, ReshapedIndex, linearindexing, LinearFast, LinearSlow, LinearIndexing
using Base.PermutedDimsArrays: PermutedDimsArray
using Base.Order

export inds, index, value, stored, each, sync
export Follower, inds, index, value, stored, each, sync

include("types.jl")
include("core.jl")
include("reshaped.jl")
include("sparse.jl")
include("sync_stored.jl")

end # module
42 changes: 28 additions & 14 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ function show(io::IO, iter::ContigCartIterator)
end

parent(W::ArrayIndexingWrapper) = W.data
parent(F::Follower) = F.value

stripwrapper(A::AbstractArray) = A
stripwrapper(A::ArrayIndexingWrapper) = parent(A)

"""
`index(A)`
Expand Down Expand Up @@ -97,7 +101,10 @@ done(vi::ValueIterator, s) = done(vi.iter, s)
next(vi::ValueIterator, s) = ((idx, s) = next(vi.iter, s); (vi.data[idx], s))

start(iter::SyncedIterator) = start(iter.iter)
next(iter::SyncedIterator, state) = mapf(iter.itemfuns, state), next(iter.iter, state)
function next(iter::SyncedIterator, state)
item, newstate = next(iter.iter, state)
mapf(iter.itemfuns, iter.items, item), newstate
end
done(iter::SyncedIterator, state) = done(iter.iter, state)

start(itr::FirstToLastIterator) = (itr.itr, start(itr.itr))
Expand All @@ -107,6 +114,8 @@ function next(itr::FirstToLastIterator, i)
end
done(itr::FirstToLastIterator, i) = done(i[1], i[2])

# SyncedIterator(iter, funcs) = SyncedIterator{typeof(iter), Base.typesof(funcs)}(iter, funcs)

function sync(A::AllElements, B::AllElements)
checksame_inds(A, B)
_sync(checksame_storageorder(A, B), A, B)
Expand All @@ -122,24 +131,21 @@ _sync(::Type{Val{false}}, A, B) = zip(columnmajoriterator(A), columnmajoriterato
_sync(::Type{Val{true}}, As...) = zip(map(each, As)...)
_sync(::Type{Val{false}}, As...) = zip(map(columnmajoriterator, As)...)

sync(A::StoredElements, B::StoredElements) = sync_stored(A, B)
sync(A, B::StoredElements) = sync_stored(A, B)
sync(A::StoredElements, B) = sync_stored(A, B)

#function sync_stored(A, B)
# checksame_inds(A, B)
#end
# For stored, see sync_stored.jl

### Utility methods

"""
`mapf(fs, x)` is similar to `map`, except instead of mapping one
function over many objects, it maps many functions over one
object. `fs` should be a tuple-of-functions.
`mapf(fs, objs, x)` is similar to `map(f, a, b)`, except instead of mapping one
function over many objects, it maps many function/object pairs over one
`x`. `fs` should be a tuple-of-functions, and `objs` a tuple-of-containers.
"""
@inline mapf(fs::Tuple, x) = _mapf((), x, fs...)
_mapf(out, x) = out
@inline _mapf(out, x, f, fs...) = _mapf((out..., f(x)), x, fs...)
mapf{N}(fs::NTuple{N}, objs::NTuple{N}, x) = _mapf((), fs, objs, x)
_mapf(out, ::Tuple{}, ::Tuple{}, x) = out
@inline function _mapf(out, fs, objs, x)
f, obj = fs[1], objs[1]
ret = _mapf((out..., f(obj, x)), Base.tail(fs), Base.tail(objs), x)
end

storageorder(::Array) = FirstToLast()
storageorder{T,N,AA,perm}(::PermutedDimsArray{T,N,AA,perm}) = OtherOrder{perm}()
Expand Down Expand Up @@ -175,6 +181,13 @@ _extent_inds(out, A, d) = out
@inline _extent_inds(out, A, d, ::Int, indexes...) = _extent_inds(out, A, d+1, indexes...)
@inline _extent_inds(out, A, d, i, indexes...) = _extent_inds((out..., inds(A, d)), A, d+1, indexes...)

# extent_dims indicates which dimensions have extended size
extent_dims{T,N}(A::AbstractArray{T,N}) = ntuple(identity,Val{N})
extent_dims(W::ArrayIndexingWrapper) = _extent_dims((), 1, W.indexes...)
_extent_dims(out, d::Integer) = out
@inline _extent_dims(out, d, i1::Union{UnitRange{Int},Colon}, indexes...) = _extent_dims((out..., d), d+1, indexes...)
@inline _extent_dims(out, d, i1, indexes...) = _extent_dims(out, d+1, indexes...)

columnmajoriterator(A::AbstractArray) = columnmajoriterator(linearindexing(A), A)
columnmajoriterator(::LinearFast, A) = A
columnmajoriterator(::LinearSlow, A) = FirstToLastIterator(A, CartesianRange(size(A)))
Expand Down Expand Up @@ -207,6 +220,7 @@ function _contiguous_iterator(W, ::LinearFast)
end
_contiguous_iterator(W, ::LinearSlow) = CartesianRange(ranges(W))

# Return the "corners" of an iteration range
function firstlast(W)
A = parent(W)
f = firstlast(first, A, W.indexes)
Expand Down
7 changes: 7 additions & 0 deletions src/sparse.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
### Sparse-array iterators

typealias SubSparseArray{I,T,N,P<:AbstractSparseArray} SubArray{T,N,P,I,false}

## SparseMatrixCSC

typealias SubSparseMatrixCSC{I,T,N,P<:SparseMatrixCSC} SubArray{T,N,P,I,false}
Expand All @@ -15,6 +17,11 @@ immutable IndexCSC
cscindex::Int # for stored value, the index into rowval & nzval
end

function getindex(I::IndexCSC, d)
@boundscheck d==1 || d==2 || Base.throw_boundserror(I, d)
ifelse(d == 1, I.row, I.col)
end

@inline getindex(A::SparseMatrixCSC, i::IndexCSC) = (@inbounds ret = i.stored ? A.nzval[i.cscindex] : zero(eltype(A)); ret)
@inline getindex(A::SubSparseMatrixCSC, i::IndexCSC) = A.parent[i]
# @inline function getindex(a::AbstractVector, i::IndexCSC)
Expand Down
19 changes: 19 additions & 0 deletions src/sync_stored.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#function sync{AA,I}(FA::Follower, B::ArrayIndexingWrapper{AA,I,true})
function sync(FA::Follower, B)
A = parent(FA)
checksame_inds(A, B)
iter, Bind, Bfunc = syncable(B)
SyncedIterator(each(iter), (stripwrapper(A), Bind), (synciterfunc(A, iter), Bfunc))
end

# value-iterator
syncable(A) = index(A), stripwrapper(A), (A, i) -> (@inbounds ret = A[i]; ret)
# syncable{AA,I}(A::ArrayIndexingWrapper{AA,I,false}) = A, parent(A), (A, i) -> (println("2"); @inbounds ret = A[i]; ret)
# index-iterator
syncable{AA,I}(A::ArrayIndexingWrapper{AA,I,true}) = A, parent(A), (A, i) -> i

synciterfunc(A, B) = _synciterfunc(A, extent_dims(B))
# value-iterator
_synciterfunc(A, d::Tuple{Int}) = (A, i) -> A[i[d[1]]]
# index-iterator
_synciterfunc{AA,I}(A::ArrayIndexingWrapper{AA,I,true}, d::Tuple{Int}) = (A, i) -> i[d[1]]
9 changes: 8 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@ immutable ValueIterator{A<:AbstractArray,I}
iter::I
end

immutable SyncedIterator{I,F<:Tuple{Vararg{Function}}}
immutable SyncedIterator{I,O<:Tuple,F<:Tuple{Vararg{Function}}}
iter::I
items::O
itemfuns::F
end

# declare that an array/iterhint should not control which index
# positions are visited, but only follow the lead of other objects
immutable Follower{T}
value::T
end

typealias ArrayOrWrapper Union{AbstractArray,ArrayIndexingWrapper}
typealias AllElements{A,I,isindex} Union{AbstractArray,ArrayIndexingWrapper{A,I,isindex,false}}
typealias StoredElements{A,I,isindex} ArrayIndexingWrapper{A,I,isindex,true}
Expand Down
36 changes: 36 additions & 0 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,39 @@ for j = inds(A, 2)
@test v == Af[k+=1]
end
end

# Sparse matrix-vector multiplication
function matvecmul_ind!(b::AbstractVector, A::AbstractMatrix, x::AbstractVector)
fill!(b, 0)
inds(A, 2) == inds(x, 1) || throw(DimensionMismatch("inds(A, 2) = $(inds(A, 2)) does not agree with inds(x, 1) = $(inds(x, 1))"))
for j in inds(A, 2)
xj = x[j]
for (ib, iA) in sync(Follower(index(b)), index(stored(A, :, j)))
b[ib] += A[iA]*xj
end
end
b
end
function matvecmul_val!(b::AbstractVector, A::AbstractMatrix, x::AbstractVector)
fill!(b, 0)
inds(A, 2) == inds(x, 1) || throw(DimensionMismatch("inds(A, 2) = $(inds(A, 2)) does not agree with inds(x, 1) = $(inds(x, 1))"))
for j in inds(A, 2)
xj = x[j]
for (ib, a) in sync(Follower(index(b)), stored(A, :, j))
b[ib] += a*xj
end
end
b
end

x = [1,-5]
btrue = A*x
b = similar(btrue)
matvecmul_ind!(b, A, x)
@test_approx_eq b btrue
matvecmul_val!(b, A, x)
@test_approx_eq b btrue
@test_throws DimensionMismatch matvecmul_ind!(b, A, [1])
@test_throws DimensionMismatch matvecmul_ind!([0.1,0.2], A, x)
@test_throws DimensionMismatch matvecmul_val!(b, A, [1])
@test_throws DimensionMismatch matvecmul_val!([0.1,0.2], A, x)

0 comments on commit ee82348

Please sign in to comment.