diff --git a/base/iterator.jl b/base/iterator.jl index 3d71068d8280d..4d94146dc39c5 100644 --- a/base/iterator.jl +++ b/base/iterator.jl @@ -223,3 +223,71 @@ next(it::Repeated, state) = (it.x, nothing) done(it::Repeated, state) = false repeated(x, n::Int) = take(repeated(x), n) + +# product + +abstract AbstractProdIterator + +immutable Prod2{I1, I2} <: AbstractProdIterator + a::I1 + b::I2 +end + +""" + product(iters...) + +Returns an iterator over the product of several iterators. Each generated element is +a tuple whose `i`th element comes from the `i`th argument iterator. The first iterator +changes the fastest. Example: + + julia> collect(product(1:2,3:5)) + 6-element Array{Tuple{Int64,Int64},1}: + (1,3) + (2,3) + (1,4) + (2,4) + (1,5) + (2,5) +""" +product(a) = Zip1(a) +product(a, b) = Prod2(a, b) +eltype{I1,I2}(::Type{Prod2{I1,I2}}) = Tuple{eltype(I1), eltype(I2)} +length(p::AbstractProdIterator) = length(p.a)*length(p.b) + +function start(p::AbstractProdIterator) + s1, s2 = start(p.a), start(p.b) + s1, s2, Nullable{eltype(p.b)}(), (done(p.a,s1) || done(p.b,s2)) +end + +@inline function prod_next(p, st) + s1, s2 = st[1], st[2] + v1, s1 = next(p.a, s1) + + nv2 = st[3] + if isnull(nv2) + v2, s2 = next(p.b, s2) + else + v2 = nv2.value + end + + if done(p.a, s1) + return (v1,v2), (start(p.a), s2, oftype(nv2,nothing), done(p.b,s2)) + end + return (v1,v2), (s1, s2, Nullable(v2), false) +end + +@inline next(p::Prod2, st) = prod_next(p, st) +@inline done(p::AbstractProdIterator, st) = st[4] + +immutable Prod{I1, I2<:AbstractProdIterator} <: AbstractProdIterator + a::I1 + b::I2 +end + +product(a, b, c...) = Prod(a, product(b, c...)) +eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2)) + +@inline function next{I1,I2}(p::Prod{I1,I2}, st) + x = prod_next(p, st) + ((x[1][1],x[1][2]...), x[2]) +end diff --git a/test/functional.jl b/test/functional.jl index 6c60e62a96191..1e63186762b14 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -155,6 +155,17 @@ let i = 0 end end +# product +# ------- + +@test isempty(Base.product(1:2,1:0)) +@test isempty(Base.product(1:2,1:0,1:10)) +@test isempty(Base.product(1:2,1:10,1:0)) +@test isempty(Base.product(1:0,1:2,1:10)) +@test collect(Base.product(1:2,3:4)) == [(1,3),(2,3),(1,4),(2,4)] +@test isempty(collect(Base.product(1:0,1:2))) +@test length(Base.product(1:2,1:10,4:6)) == 60 + # foreach let a = []