Skip to content

Commit

Permalink
add Flatten iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffBezanson committed Mar 23, 2016
1 parent 8972e70 commit fb3dd6e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
54 changes: 54 additions & 0 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,57 @@ function collect{I<:IteratorND}(g::Generator{I})
dest[1] = first
return map_to!(g.f, 2, st, dest, g.iter)
end

# flatten an iterator of iterators

immutable Flatten{I}
it::I
end

"""
flatten(iter)
Given an iterator that yields iterators, return an iterator that yields the
elements of those iterators.
Put differently, the elements of the argument iterator are concatenated. Example:
julia> collect(flatten((1:2, 8:9)))
4-element Array{Int64,1}:
1
2
8
9
"""
flatten(itr) = Flatten(itr)

eltype{I}(::Type{Flatten{I}}) = eltype(eltype(I))

function start(f::Flatten)
local inner, s2
s = start(f.it)
d = done(f.it, s)
# this is a simple way to make this function type stable
d && throw(ArgumentError("argument to Flatten must contain at least one iterator"))
while !d
inner, s = next(f.it, s)
s2 = start(inner)
!done(inner, s2) && break
d = done(f.it, s)
end
return s, inner, s2
end

function next(f::Flatten, state)
s, inner, s2 = state
val, s2 = next(inner, s2)
while done(inner, s2) && !done(f.it, s)
inner, s = next(f.it, s)
s2 = start(inner)
end
return val, (s, inner, s2)
end

@inline function done(f::Flatten, state)
s, inner, s2 = state
return done(f.it, s) && done(inner, s2)
end
13 changes: 13 additions & 0 deletions test/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,19 @@ end
@test isempty(collect(Base.product(1:0,1:2)))
@test length(Base.product(1:2,1:10,4:6)) == 60

# flatten
# -------

import Base.flatten

@test collect(flatten(Any[1:2, 4:5])) == Any[1,2,4,5]
@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[10:7, 10:9])])) == Any[1,2]
@test collect(flatten(Any[flatten(Any[1:2, 4:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,4,5,6,7,8,9]
@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,6,7,8,9]
@test collect(flatten(Any[2:1])) == Any[]
@test eltype(flatten(UnitRange{Int8}[1:2, 3:4])) == Int8
@test_throws ErrorException collect(flatten(Any[]))

# foreach
let
a = []
Expand Down

0 comments on commit fb3dd6e

Please sign in to comment.