From 0314913de2d8c6aa9027ee880e3367926ea8fe4e Mon Sep 17 00:00:00 2001 From: timholy Date: Sat, 20 Sep 2014 17:17:41 -0500 Subject: [PATCH] Add a multidimensional (cartesian) iterator Closes #1917, closes #6437 --- base/dates/ranges.jl | 2 +- base/exports.jl | 2 + base/multidimensional.jl | 119 +++++++++++++++++++++++++++++++++++++++ base/range.jl | 4 +- base/sharedarray.jl | 8 +++ test/arrayops.jl | 51 +++++++++++++++++ 6 files changed, 183 insertions(+), 3 deletions(-) diff --git a/base/dates/ranges.jl b/base/dates/ranges.jl index 144c3552567bf..719c680705fa4 100644 --- a/base/dates/ranges.jl +++ b/base/dates/ranges.jl @@ -29,5 +29,5 @@ function in{T<:TimeType}(x::T, r::StepRange{T}) end Base.start{T<:TimeType}(r::StepRange{T}) = 0 -Base.next{T<:TimeType}(r::StepRange{T}, i) = (r.start+r.step*i,i+1) +Base.next{T<:TimeType}(r::StepRange{T}, i::Int) = (r.start+r.step*i,i+1) Base.done{T<:TimeType,S<:Period}(r::StepRange{T,S}, i::Integer) = length(r) <= i diff --git a/base/exports.jl b/base/exports.jl index 5b153f7b3ae9b..f42e76b075a4b 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -515,6 +515,8 @@ export cumsum, cumsum!, cumsum_kbn, + eachelement, + eachindex, extrema, fill!, fill, diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 29c451c8ac5fe..034f7cb5c539b 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -1,3 +1,122 @@ +### Multidimensional iterators +module IteratorsMD + +import Base: start, done, next, getindex, setindex! +import Base: @nref, @ncall, @nif, @nexprs + +export eachelement, eachindex, linearindexing, LinearFast + +# Traits for linear indexing +abstract LinearIndexing +immutable LinearFast <: LinearIndexing end +immutable LinearSlow <: LinearIndexing end + +linearindexing(::AbstractArray) = LinearSlow() +linearindexing(::Array) = LinearFast() +linearindexing(::BitArray) = LinearFast() +linearindexing(::Range) = LinearFast() + +# this generates types like this: +# immutable Subscripts_3 <: Subscripts{3} +# I_1::Int +# I_2::Int +# I_3::Int +# end +# they are used as iterator states +# TODO: when tuples get improved, replace with a tuple-based implementation. See #6437. + +abstract Subscripts{N} # the state for all multidimensional iterators +abstract SizeIterator{N} # Iterator that visits the index associated with each element + +function gen_iterators(N::Int, with_shared=true) + # Create the types + namestate = symbol("Subscripts_$N") + namesize = symbol("SizeIterator_$N") + fieldnames = [symbol("I_$i") for i = 1:N] + fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N] + exstate = Expr(:type, false, Expr(:(<:), namestate, Expr(:curly, :Subscripts, N)), Expr(:block, fields...)) + dimsindexes = Expr[:(dims[$i]) for i = 1:N] + onesN = ones(Int, N) + infsN = fill(typemax(Int), N) + anyzero = Expr(:(||), [:(SZ.I.$(fieldnames[i]) == 0) for i = 1:N]...) + # Some necessary ambiguity resolution + exrange = N != 1 ? nothing : quote + next(R::StepRange, I::Subscripts_1) = R[I.I_1], Subscripts_1(I.I_1+1) + next{T}(R::UnitRange{T}, I::Subscripts_1) = R[I.I_1], Subscripts_1(I.I_1+1) + end + exshared = !with_shared ? nothing : quote + getindex{T}(S::SharedArray{T,$N}, state::$namestate) = S.s[state] + setindex!{T}(S::SharedArray{T,$N}, v, state::$namestate) = S.s[state] = v + end + quote + $exstate + immutable $namesize <: SizeIterator{$N} + I::$namestate + end + $namestate(dims::NTuple{$N,Int}) = $namestate($(dimsindexes...)) + _eachindex(dims::NTuple{$N,Int}) = $namesize($namestate(dims)) + + start{T}(AT::(AbstractArray{T,$N},LinearSlow)) = isempty(AT[1]) ? $namestate($(infsN...)) : $namestate($(onesN...)) + start(SZ::$namesize) = $anyzero ? $namestate($(infsN...)) : $namestate($(onesN...)) + + $exrange + + @inline function next{T}(A::AbstractArray{T,$N}, state::$namestate) + @inbounds v = A[state] + newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $namestate, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1))) + v, newstate + end + @inline function next(iter::$namesize, state::$namestate) + newstate = @nif $N d->(getfield(state,d) < getfield(iter.I,d)) d->(@ncall($N, $namestate, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1))) + state, newstate + end + + $exshared + getindex{T}(A::AbstractArray{T,$N}, state::$namestate) = @nref $N A d->getfield(state,d) + setindex!{T}(A::AbstractArray{T,$N}, v, state::$namestate) = (@nref $N A d->getfield(state,d)) = v + end +end + +# Ambiguity resolution +done(R::StepRange, I::Subscripts{1}) = getfield(I, 1) > length(R) +done(R::UnitRange, I::Subscripts{1}) = getfield(I, 1) > length(R) + +Base.start(A::AbstractArray) = start((A,linearindexing(A))) +start(::(AbstractArray,LinearFast)) = 1 +done{T,N}(A::AbstractArray{T,N}, I::Subscripts{N}) = getfield(I, N) > size(A, N) +done{N}(iter::SizeIterator{N}, I::Subscripts{N}) = getfield(I, N) > getfield(iter.I, N) + +eachindex(A::AbstractArray) = eachindex(size(A)) + +let implemented = IntSet() +global eachindex +global eachelement +function eachindex{N}(t::NTuple{N,Int}) + if !in(N, implemented) + eval(gen_iterators(N)) + end + _eachindex(t) +end +function eachelement{T,N}(A::AbstractArray{T,N}) + if !in(N, implemented) + eval(gen_iterators(N)) + end + A +end +end + +# Pre-generate for low dimensions +for N = 1:8 + eval(gen_iterators(N, false)) + eval(:(eachindex(t::NTuple{$N,Int}) = _eachindex(t))) + eval(:(eachelement{T}(A::AbstractArray{T,$N}) = A)) +end + +end # IteratorsMD + +using .IteratorsMD + + ### From array.jl @ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...) diff --git a/base/range.jl b/base/range.jl index 1cf7f5f4d431a..7c21201509bff 100644 --- a/base/range.jl +++ b/base/range.jl @@ -235,8 +235,8 @@ copy(r::Range) = r ## iteration start(r::FloatRange) = 0 -next{T}(r::FloatRange{T}, i) = (convert(T, (r.start + i*r.step)/r.divisor), i+1) -done(r::FloatRange, i) = (length(r) <= i) +next{T}(r::FloatRange{T}, i::Int) = (convert(T, (r.start + i*r.step)/r.divisor), i+1) +done(r::FloatRange, i::Int) = (length(r) <= i) # NOTE: For ordinal ranges, we assume start+step might be from a # lifted domain (e.g. Int8+Int8 => Int); use that for iterating. diff --git a/base/sharedarray.jl b/base/sharedarray.jl index 2c22a36cc9ead..22147d11d4294 100644 --- a/base/sharedarray.jl +++ b/base/sharedarray.jl @@ -207,6 +207,14 @@ end convert(::Type{Array}, S::SharedArray) = S.s # # pass through getindex and setindex! - they always work on the complete array unlike DArrays +for N = 1:8 + name = symbol("Subscripts_$N") + @eval begin + getindex{T}(S::SharedArray{T,$N}, I::IteratorsMD.$name) = getindex(S.s, I) + setindex!{T}(S::SharedArray{T,$N}, v, I::IteratorsMD.$name) = setindex!(S.s, v, I) + end +end + getindex(S::SharedArray) = getindex(S.s) getindex(S::SharedArray, I::Real) = getindex(S.s, I) getindex(S::SharedArray, I::AbstractArray) = getindex(S.s, I) diff --git a/test/arrayops.jl b/test/arrayops.jl index 143f7688655e0..5a700fb6a8153 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -925,3 +925,54 @@ end b718cbc = 5 @test b718cbc[1.0] == 5 @test_throws InexactError b718cbc[1.1] + +# Multidimensional iterators +function mdsum(A) + s = 0.0 + for a in eachelement(A) + s += a + end + s +end + +function mdsum2(A) + s = 0.0 + @inbounds for I in eachindex(A) + s += A[I] + end + s +end + +a = [1:5] +@test isa(Base.linearindexing(a), Base.LinearFast) +b = sub(a, :) +@test isa(Base.linearindexing(b), Base.IteratorsMD.LinearSlow) +shp = [5] +for i = 1:10 + A = reshape(a, tuple(shp...)) + @test mdsum(A) == 15 + @test mdsum2(A) == 15 + B = sub(A, ntuple(i, i->Colon())...) + @test mdsum(B) == 15 + @test mdsum2(B) == 15 + unshift!(shp, 1) +end + +a = [1:10] +shp = [2,5] +for i = 2:10 + A = reshape(a, tuple(shp...)) + @test mdsum(A) == 55 + @test mdsum2(A) == 55 + B = sub(A, ntuple(i, i->Colon())...) + @test mdsum(B) == 55 + @test mdsum2(B) == 55 + insert!(shp, 2, 1) +end + +a = ones(0,5) +b = sub(a, :, :) +@test mdsum(b) == 0 +a = ones(5,0) +b = sub(a, :, :) +@test mdsum(b) == 0