diff --git a/base/iterator.jl b/base/iterator.jl index 833222ab02e19f..ee8337795c3c92 100644 --- a/base/iterator.jl +++ b/base/iterator.jl @@ -341,3 +341,57 @@ function collect{I<:IteratorND}(g::Generator{I}) dest[1] = first return map_to!(g.f, 2, st, dest, g.iter) end + +# flatten an iterator of iterators + +immutable Flatten{I} + it::I +end + +""" + flatten(iter) + +Given an iterator that yields iterators, return an iterator that yields the +elements of those iterators. +Put differently, the elements of the argument iterator are concatenated. Example: + + julia> collect(flatten((1:2, 8:9))) + 4-element Array{Int64,1}: + 1 + 2 + 8 + 9 +""" +flatten(itr) = Flatten(itr) + +eltype{I}(::Type{Flatten{I}}) = eltype(eltype(I)) + +function start(f::Flatten) + local inner, s2 + s = start(f.it) + d = done(f.it, s) + # this is a simple way to make this function type stable + d && throw(ArgumentError("argument to Flatten must contain at least one iterator")) + while !d + inner, s = next(f.it, s) + s2 = start(inner) + !done(inner, s2) && break + d = done(f.it, s) + end + return s, inner, s2 +end + +function next(f::Flatten, state) + s, inner, s2 = state + val, s2 = next(inner, s2) + while done(inner, s2) && !done(f.it, s) + inner, s = next(f.it, s) + s2 = start(inner) + end + return val, (s, inner, s2) +end + +@inline function done(f::Flatten, state) + s, inner, s2 = state + return done(f.it, s) && done(inner, s2) +end diff --git a/test/functional.jl b/test/functional.jl index f1fcc2d6f95f4b..0c813f8366e6a5 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -166,6 +166,19 @@ end @test isempty(collect(Base.product(1:0,1:2))) @test length(Base.product(1:2,1:10,4:6)) == 60 +# flatten +# ------- + +import Base.flatten + +@test collect(flatten(Any[1:2, 4:5])) == Any[1,2,4,5] +@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[10:7, 10:9])])) == Any[1,2] +@test collect(flatten(Any[flatten(Any[1:2, 4:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,4,5,6,7,8,9] +@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,6,7,8,9] +@test collect(flatten(Any[2:1])) == Any[] +@test eltype(flatten(UnitRange{Int8}[1:2, 3:4])) == Int8 +@test_throws ArgumentError collect(flatten(Any[])) + # foreach let a = []