From a37bf0fde81a9d714e0829aedb524fb1921cf1f1 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Tue, 26 Jan 2016 13:54:39 -0500 Subject: [PATCH] add `Flatten` iterator --- base/iterator.jl | 54 ++++++++++++++++++++++++++++++++++++++++++++++ test/functional.jl | 13 +++++++++++ 2 files changed, 67 insertions(+) diff --git a/base/iterator.jl b/base/iterator.jl index 4d94146dc39c56..3fa8b899ce3057 100644 --- a/base/iterator.jl +++ b/base/iterator.jl @@ -291,3 +291,57 @@ eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2)) x = prod_next(p, st) ((x[1][1],x[1][2]...), x[2]) end + +# flatten an iterator of iterators + +immutable Flatten{I} + it::I +end + +""" + flatten(iter) + +Given an iterator that yields iterators, return an iterator that yields the +elements of those iterators. +Put differently, the elements of the argument iterator are concatenated. Example: + + julia> collect(flatten((1:2, 8:9))) + 4-element Array{Int64,1}: + 1 + 2 + 8 + 9 +""" +flatten(itr) = Flatten(itr) + +eltype{I}(::Type{Flatten{I}}) = eltype(eltype(I)) + +function start(f::Flatten) + local inner, s2 + s = start(f.it) + d = done(f.it, s) + # this is a simple way to make this function type stable + d && error("argument to Flatten must contain at least one iterator") + while !d + inner, s = next(f.it, s) + s2 = start(inner) + !done(inner, s2) && break + d = done(f.it, s) + end + return s, inner, s2 +end + +function next(f::Flatten, state) + s, inner, s2 = state + val, s2 = next(inner, s2) + while done(inner, s2) && !done(f.it, s) + inner, s = next(f.it, s) + s2 = start(inner) + end + return val, (s, inner, s2) +end + +@inline function done(f::Flatten, state) + s, inner, s2 = state + return done(f.it, s) && done(inner, s2) +end diff --git a/test/functional.jl b/test/functional.jl index 1e63186762b14a..be0f84893c66a0 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -166,6 +166,19 @@ end @test isempty(collect(Base.product(1:0,1:2))) @test length(Base.product(1:2,1:10,4:6)) == 60 +# flatten +# ------- + +import Base.flatten + +@test collect(flatten(Any[1:2, 4:5])) == Any[1,2,4,5] +@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[10:7, 10:9])])) == Any[1,2] +@test collect(flatten(Any[flatten(Any[1:2, 4:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,4,5,6,7,8,9] +@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,6,7,8,9] +@test collect(flatten(Any[2:1])) == Any[] +@test eltype(flatten(UnitRange{Int8}[1:2, 3:4])) == Int8 +@test_throws ErrorException collect(flatten(Any[])) + # foreach let a = []