Skip to content

Commit

Permalink
Add a multidimensional (cartesian) iterator
Browse files Browse the repository at this point in the history
Closes #1917, closes #6437
  • Loading branch information
timholy committed Nov 12, 2014
1 parent cad4eaa commit 0314913
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 3 deletions.
2 changes: 1 addition & 1 deletion base/dates/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ function in{T<:TimeType}(x::T, r::StepRange{T})
end

Base.start{T<:TimeType}(r::StepRange{T}) = 0
Base.next{T<:TimeType}(r::StepRange{T}, i) = (r.start+r.step*i,i+1)
Base.next{T<:TimeType}(r::StepRange{T}, i::Int) = (r.start+r.step*i,i+1)
Base.done{T<:TimeType,S<:Period}(r::StepRange{T,S}, i::Integer) = length(r) <= i
2 changes: 2 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ export
cumsum,
cumsum!,
cumsum_kbn,
eachelement,
eachindex,
extrema,
fill!,
fill,
Expand Down
119 changes: 119 additions & 0 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,122 @@
### Multidimensional iterators
module IteratorsMD

import Base: start, done, next, getindex, setindex!
import Base: @nref, @ncall, @nif, @nexprs

export eachelement, eachindex, linearindexing, LinearFast

# Traits for linear indexing
abstract LinearIndexing
immutable LinearFast <: LinearIndexing end
immutable LinearSlow <: LinearIndexing end

linearindexing(::AbstractArray) = LinearSlow()
linearindexing(::Array) = LinearFast()
linearindexing(::BitArray) = LinearFast()
linearindexing(::Range) = LinearFast()

# this generates types like this:
# immutable Subscripts_3 <: Subscripts{3}
# I_1::Int
# I_2::Int
# I_3::Int
# end
# they are used as iterator states
# TODO: when tuples get improved, replace with a tuple-based implementation. See #6437.

abstract Subscripts{N} # the state for all multidimensional iterators
abstract SizeIterator{N} # Iterator that visits the index associated with each element

function gen_iterators(N::Int, with_shared=true)
# Create the types
namestate = symbol("Subscripts_$N")
namesize = symbol("SizeIterator_$N")
fieldnames = [symbol("I_$i") for i = 1:N]
fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N]
exstate = Expr(:type, false, Expr(:(<:), namestate, Expr(:curly, :Subscripts, N)), Expr(:block, fields...))
dimsindexes = Expr[:(dims[$i]) for i = 1:N]
onesN = ones(Int, N)
infsN = fill(typemax(Int), N)
anyzero = Expr(:(||), [:(SZ.I.$(fieldnames[i]) == 0) for i = 1:N]...)
# Some necessary ambiguity resolution
exrange = N != 1 ? nothing : quote
next(R::StepRange, I::Subscripts_1) = R[I.I_1], Subscripts_1(I.I_1+1)
next{T}(R::UnitRange{T}, I::Subscripts_1) = R[I.I_1], Subscripts_1(I.I_1+1)
end
exshared = !with_shared ? nothing : quote
getindex{T}(S::SharedArray{T,$N}, state::$namestate) = S.s[state]
setindex!{T}(S::SharedArray{T,$N}, v, state::$namestate) = S.s[state] = v
end
quote
$exstate
immutable $namesize <: SizeIterator{$N}
I::$namestate
end
$namestate(dims::NTuple{$N,Int}) = $namestate($(dimsindexes...))
_eachindex(dims::NTuple{$N,Int}) = $namesize($namestate(dims))

start{T}(AT::(AbstractArray{T,$N},LinearSlow)) = isempty(AT[1]) ? $namestate($(infsN...)) : $namestate($(onesN...))
start(SZ::$namesize) = $anyzero ? $namestate($(infsN...)) : $namestate($(onesN...))

$exrange

@inline function next{T}(A::AbstractArray{T,$N}, state::$namestate)
@inbounds v = A[state]
newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $namestate, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
v, newstate
end
@inline function next(iter::$namesize, state::$namestate)
newstate = @nif $N d->(getfield(state,d) < getfield(iter.I,d)) d->(@ncall($N, $namestate, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
state, newstate
end

$exshared
getindex{T}(A::AbstractArray{T,$N}, state::$namestate) = @nref $N A d->getfield(state,d)
setindex!{T}(A::AbstractArray{T,$N}, v, state::$namestate) = (@nref $N A d->getfield(state,d)) = v
end
end

# Ambiguity resolution
done(R::StepRange, I::Subscripts{1}) = getfield(I, 1) > length(R)
done(R::UnitRange, I::Subscripts{1}) = getfield(I, 1) > length(R)

Base.start(A::AbstractArray) = start((A,linearindexing(A)))
start(::(AbstractArray,LinearFast)) = 1
done{T,N}(A::AbstractArray{T,N}, I::Subscripts{N}) = getfield(I, N) > size(A, N)
done{N}(iter::SizeIterator{N}, I::Subscripts{N}) = getfield(I, N) > getfield(iter.I, N)

eachindex(A::AbstractArray) = eachindex(size(A))

let implemented = IntSet()
global eachindex
global eachelement
function eachindex{N}(t::NTuple{N,Int})
if !in(N, implemented)
eval(gen_iterators(N))
end
_eachindex(t)
end
function eachelement{T,N}(A::AbstractArray{T,N})
if !in(N, implemented)
eval(gen_iterators(N))
end
A
end
end

# Pre-generate for low dimensions
for N = 1:8
eval(gen_iterators(N, false))
eval(:(eachindex(t::NTuple{$N,Int}) = _eachindex(t)))
eval(:(eachelement{T}(A::AbstractArray{T,$N}) = A))
end

end # IteratorsMD

using .IteratorsMD


### From array.jl

@ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...)
Expand Down
4 changes: 2 additions & 2 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ copy(r::Range) = r
## iteration

start(r::FloatRange) = 0
next{T}(r::FloatRange{T}, i) = (convert(T, (r.start + i*r.step)/r.divisor), i+1)
done(r::FloatRange, i) = (length(r) <= i)
next{T}(r::FloatRange{T}, i::Int) = (convert(T, (r.start + i*r.step)/r.divisor), i+1)
done(r::FloatRange, i::Int) = (length(r) <= i)

# NOTE: For ordinal ranges, we assume start+step might be from a
# lifted domain (e.g. Int8+Int8 => Int); use that for iterating.
Expand Down
8 changes: 8 additions & 0 deletions base/sharedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ end
convert(::Type{Array}, S::SharedArray) = S.s

# # pass through getindex and setindex! - they always work on the complete array unlike DArrays
for N = 1:8
name = symbol("Subscripts_$N")
@eval begin
getindex{T}(S::SharedArray{T,$N}, I::IteratorsMD.$name) = getindex(S.s, I)
setindex!{T}(S::SharedArray{T,$N}, v, I::IteratorsMD.$name) = setindex!(S.s, v, I)
end
end

getindex(S::SharedArray) = getindex(S.s)
getindex(S::SharedArray, I::Real) = getindex(S.s, I)
getindex(S::SharedArray, I::AbstractArray) = getindex(S.s, I)
Expand Down
51 changes: 51 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,3 +925,54 @@ end
b718cbc = 5
@test b718cbc[1.0] == 5
@test_throws InexactError b718cbc[1.1]

# Multidimensional iterators
function mdsum(A)
s = 0.0
for a in eachelement(A)
s += a
end
s
end

function mdsum2(A)
s = 0.0
@inbounds for I in eachindex(A)
s += A[I]
end
s
end

a = [1:5]
@test isa(Base.linearindexing(a), Base.LinearFast)
b = sub(a, :)
@test isa(Base.linearindexing(b), Base.IteratorsMD.LinearSlow)
shp = [5]
for i = 1:10
A = reshape(a, tuple(shp...))
@test mdsum(A) == 15
@test mdsum2(A) == 15
B = sub(A, ntuple(i, i->Colon())...)
@test mdsum(B) == 15
@test mdsum2(B) == 15
unshift!(shp, 1)
end

a = [1:10]
shp = [2,5]
for i = 2:10
A = reshape(a, tuple(shp...))
@test mdsum(A) == 55
@test mdsum2(A) == 55
B = sub(A, ntuple(i, i->Colon())...)
@test mdsum(B) == 55
@test mdsum2(B) == 55
insert!(shp, 2, 1)
end

a = ones(0,5)
b = sub(a, :, :)
@test mdsum(b) == 0
a = ones(5,0)
b = sub(a, :, :)
@test mdsum(b) == 0

0 comments on commit 0314913

Please sign in to comment.